diff --git a/ebpf.py b/ebpf.py index 96209e71a205398cb57b2fcc1d25db029c38054e..a95d5d6b4dc615190bb3568abccfa596cb171ee2 100644 --- a/ebpf.py +++ b/ebpf.py @@ -37,8 +37,9 @@ class Comparison: self.left = left self.right = right self.opcode = opcode + self.invert = None - def calculate(self, negative): + def compare(self, negative): self.dst, _, lsigned, lfree = self.left.calculate(None, None, None) if not isinstance(self.right, int): self.src, _, rsigned, rfree = \ @@ -75,10 +76,24 @@ class Comparison: def __exit__(self, exc_type, exc, tb): self.target() + if self.invert is not None: + olen = len(self.ebpf.opcodes) + assert self.ebpf.opcodes[self.invert].opcode == 5 + self.ebpf.opcodes[self.invert:self.invert] = \ + self.ebpf.opcodes[self.origin+1:] + del self.ebpf.opcodes[olen-1:] + op, dst, src, off, imm = self.ebpf.opcodes[self.invert - 1] + self.ebpf.opcodes[self.invert - 1] = \ + Instruction(op, dst, src, + len(self.ebpf.opcodes) - self.origin + 1, imm) def Else(self): op, dst, src, off, imm = self.ebpf.opcodes[self.origin] - self.ebpf.opcodes[self.origin] = Instruction(op, dst, src, off+1, imm) + if op == 5: + self.invert = self.origin + else: + self.ebpf.opcodes[self.origin] = \ + Instruction(op, dst, src, off+1, imm) self.origin = len(self.ebpf.opcodes) self.ebpf.opcodes.append(None) self.right = self.dst = 0 @@ -86,10 +101,14 @@ class Comparison: return self +class OrComparison(Comparison): + def compare(self, negative): + self.left.compare(negative) + self.right.compare(negative) + + def binary(opcode, symetric=False): def ret(self, value): - #if symetric and isinstance(value, Register): - # return Binary(self.ebpf, value, self, opcode) return Binary(self.ebpf, self, value, opcode) return ret @@ -100,7 +119,6 @@ class Expression: __rmul__ = __mul__ = binary(0x24, True) __truediv__ = binary(0x34) __ror__ = __or__ = binary(0x44, True) - __rand__ = __and__ = binary(0x54, True) __lshift__ = binary(0x64) __rshift__ = binary(0x74) __mod__ = binary(0x94) @@ -112,7 +130,11 @@ class Expression: __lt__ = comparison(0xa5, 0x35, 0xc5, 0x75) __le__ = comparison(0xb5, 0x25, 0xd5, 0x65) __ne__ = comparison(0x55, 0x15) - __and__ = __rand__ = comparison(0x45, None) + + def __and__(self, value): + return AndExpression(self.ebpf, self, value) + + __rand__ = __and__ class Binary(Expression): @@ -164,6 +186,25 @@ class Sum(Binary): return super().__sub__(value) +class AndExpression(Binary, Comparison): + __and__ = __rand__ = comparison(0x45, None) + __rand__ = __and__ = binary(0x54, True) + + def __init__(self, ebpf, left, right): + Binary.__init__(self, ebpf, left, right, 0x54) + Comparison.__init__(self, ebpf, left, right, 0x45) + self.opcode = (0x45, None, 0x45, None) + + def compare(self, negative): + super().compare(False) + if negative: + origin = len(self.ebpf.opcodes) + self.ebpf.opcodes.append(None) + self.target() + self.origin = origin + self.right = self.dst = 0 + self.opcode = 5 + class Register(Expression): offset = 0 @@ -347,7 +388,7 @@ class EBPF: log_level, log_size, self.kern_version) def jumpIf(self, comp): - comp.calculate(False) + comp.compare(False) return comp def jump(self): @@ -360,11 +401,7 @@ class EBPF: return comp def If(self, comp): - comp.calculate(True) - return comp - - def isZero(self, comp): - comp.calculate(False) + comp.compare(True) return comp def get_fd(self, fd): diff --git a/ebpf_test.py b/ebpf_test.py index 91cfcd8387c143e6f851125e7e7634ee1a70080e..a50dbc69ef4c6a396df696c3498b65f8dc7a4697 100644 --- a/ebpf_test.py +++ b/ebpf_test.py @@ -211,6 +211,25 @@ class Tests(TestCase): Instruction(opcode=0x5, dst=0, src=0, off=1, imm=0), Instruction(opcode=0xb7, dst=6, src=0, off=0, imm=7)]) + def test_with_inversion(self): + e = EBPF() + with e.If(e.r1 & 1) as cond: + e.r0 = 2 + with e.If(e.r1 & 7) as cond: + e.r0 = 2 + e.r1 = 4 + with cond.Else(): + e.r0 = 3 + self.assertEqual(e.opcodes, [ + Instruction(opcode=69, dst=1, src=0, off=1, imm=1), + Instruction(opcode=5, dst=0, src=0, off=1, imm=0), + Instruction(opcode=183, dst=0, src=0, off=0, imm=2), + Instruction(opcode=69, dst=1, src=0, off=2, imm=7), + Instruction(opcode=183, dst=0, src=0, off=0, imm=3), + Instruction(opcode=5, dst=0, src=0, off=2, imm=0), + Instruction(opcode=183, dst=0, src=0, off=0, imm=2), + Instruction(opcode=183, dst=1, src=0, off=0, imm=4)]) + def test_comp_binary(self): e = EBPF() e.owners = {1, 2, 3, 5} @@ -256,6 +275,7 @@ class Tests(TestCase): e.sr0 = e.sr1 >> 2 e.sr0 = e.sr1 >> e.r2 e.w0 = e.w1 + e.w2 + e.r0 = e.r1 & e.r2 # attention, special case self.assertEqual(e.opcodes, [ Instruction(opcode=191, dst=0, src=1, off=0, imm=0), Instruction(opcode=47, dst=0, src=2, off=0, imm=0), @@ -277,7 +297,9 @@ class Tests(TestCase): Instruction(opcode=191, dst=0, src=1, off=0, imm=0), Instruction(opcode=207, dst=0, src=2, off=0, imm=0), Instruction(opcode=188, dst=0, src=1, off=0, imm=0), - Instruction(opcode=12, dst=0, src=2, off=0, imm=0)]) + Instruction(opcode=12, dst=0, src=2, off=0, imm=0), + Instruction(opcode=191, dst=0, src=1, off=0, imm=0), + Instruction(opcode=95, dst=0, src=2, off=0, imm=0)]) def test_jump_data(self): @@ -372,12 +394,11 @@ class Tests(TestCase): class KernelTests(TestCase): def test_minimal(self): e = EBPF(ProgType.XDP, "GPL") - e.r3 = 3 - e.r4 = 5 - e.r5 = 7 - tgt = e.jumpIf(e.r3 < e.r4 + e.r5) - e.r0 = 8 - tgt.target() + with e.If(e.r1 & 1111111) as cond: + e.r0 = 2 + e.r1 = 4 + with cond.Else(): + e.r0 = 3 e.exit() print(e.load(log_level=1)[1]) self.fail()