diff --git a/ebpf.py b/ebpf.py index e085ca28ca6245d699cb2ecfe9f2f6bfe1a90315..96209e71a205398cb57b2fcc1d25db029c38054e 100644 --- a/ebpf.py +++ b/ebpf.py @@ -22,42 +22,55 @@ def augassign(opcode): def comparison(uposop, unegop, sposop=None, snegop=None): + if sposop is None: + sposop = uposop + snegop = unegop def ret(self, value): - if self.signed and sposop is not None: - return Comparison(self.no, value, sposop, snegop) - else: - return Comparison(self.no, value, uposop, unegop) + return Comparison(self.ebpf, self, value, + (uposop, unegop, sposop, snegop)) return ret class Comparison: - def __init__(self, dst, src, posop, negop): - self.dst = dst - self.src = src - self.posop = posop - self.negop = negop + def __init__(self, ebpf, left, right, opcode): + self.ebpf = ebpf + self.left = left + self.right = right + self.opcode = opcode + + def calculate(self, negative): + self.dst, _, lsigned, lfree = self.left.calculate(None, None, None) + if not isinstance(self.right, int): + self.src, _, rsigned, rfree = \ + self.right.calculate(None, None, None) + if rsigned != rsigned: + raise AssembleError("need same signedness for comparison") + else: + rfree = False + self.origin = len(self.ebpf.opcodes) + self.ebpf.opcodes.append(None) + self.opcode = self.opcode[negative + 2 * lsigned] + if lfree: + self.ebpf.owners.discard(self.dst) + if rfree: + self.ebpf.owners.discard(self.src) + self.owners = self.ebpf.owners.copy() def target(self): assert self.ebpf.opcodes[self.origin] is None - if isinstance(self.src, int): + if isinstance(self.right, int): inst = Instruction( self.opcode, self.dst, 0, - len(self.ebpf.opcodes) - self.origin - 1, self.src) - elif isinstance(self.src, Register): + len(self.ebpf.opcodes) - self.origin - 1, self.right) + else: inst = Instruction( - self.opcode + 8, self.dst, self.src.no, + self.opcode + 8, self.dst, self.src, len(self.ebpf.opcodes) - self.origin - 1, 0) - else: - return NotImplemented self.ebpf.opcodes[self.origin] = inst self.ebpf.owners, self.owners = \ self.ebpf.owners & self.owners, self.ebpf.owners def __enter__(self): - self.origin = len(self.ebpf.opcodes) - self.ebpf.opcodes.append(None) - if self.opcode != 5: # Else branch - self.owners = self.ebpf.owners.copy() return self def __exit__(self, exc_type, exc, tb): @@ -66,7 +79,9 @@ class Comparison: def Else(self): op, dst, src, off, imm = self.ebpf.opcodes[self.origin] self.ebpf.opcodes[self.origin] = Instruction(op, dst, src, off+1, imm) - self.src = self.dst = 0 + self.origin = len(self.ebpf.opcodes) + self.ebpf.opcodes.append(None) + self.right = self.dst = 0 self.opcode = 5 return self @@ -91,6 +106,14 @@ class Expression: __mod__ = binary(0x94) __rxor__ = __xor__ = binary(0xa4, True) + __eq__ = comparison(0x15, 0x55) + __gt__ = comparison(0x25, 0xb5, 0x65, 0xd5) + __ge__ = comparison(0x35, 0xa5, 0x75, 0xc5) + __lt__ = comparison(0xa5, 0x35, 0xc5, 0x75) + __le__ = comparison(0xb5, 0x25, 0xd5, 0x65) + __ne__ = comparison(0x55, 0x15) + __and__ = __rand__ = comparison(0x45, None) + class Binary(Expression): def __init__(self, ebpf, left, right, operator): @@ -115,7 +138,7 @@ class Binary(Expression): self.ebpf.append(operator + (3 if long is None else 3 * long), dst, 0, 0, self.right) else: - src, long, signed, rfree = self.right.calculate(None, long, signed) + src, long, _, rfree = self.right.calculate(None, long, None) self.ebpf.append(operator + 3 * long + 8, dst, src, 0, 0) if rfree: self.ebpf.owners.discard(src) @@ -184,25 +207,18 @@ class Register(Expression): else: return super().__sub__(value) - __eq__ = comparison(0x15, 0x55) - __gt__ = comparison(0x25, 0xb5, 0x65, 0xd5) - __ge__ = comparison(0x35, 0xa5, 0x75, 0xc5) - __lt__ = comparison(0xa5, 0x35, 0xc5, 0x75) - __le__ = comparison(0xb5, 0x25, 0xd5, 0x65) - __ne__ = comparison(0x55, 0x15) - __and__ = __rand__ = comparison(0x45, None) - - def calculate(self, dst, long, signed, force=False): if long is not None and long != self.long: raise AssembleError("cannot compile") + if signed is not None and signed != self.signed: + raise AssembleError("cannot compile") if self.no not in self.ebpf.owners: raise AssembleError("register has no value") if dst != self.no and force: self.ebpf.append(0xbc + 3 * self.long, dst, self.no, 0, 0) - return dst, self.long, signed, False + return dst, self.long, self.signed, False else: - return self.no, self.long, signed, False + return self.no, self.long, self.signed, False class Memory(Expression): @@ -331,31 +347,24 @@ class EBPF: log_level, log_size, self.kern_version) def jumpIf(self, comp): - comp.origin = len(self.opcodes) - comp.ebpf = self - comp.owners = self.owners.copy() - comp.opcode = comp.posop - self.opcodes.append(None) + comp.calculate(False) return comp def jump(self): - comp = Comparison(0, 0, None, None) + comp = Comparison(self, None, 0, 5) comp.origin = len(self.opcodes) - comp.ebpf = self + comp.dst = 0 comp.owners = self.owners.copy() self.owners = set(range(11)) - comp.opcode = 5 self.opcodes.append(None) return comp def If(self, comp): - comp.opcode = comp.negop - comp.ebpf = self + comp.calculate(True) return comp def isZero(self, comp): - comp.opcode = comp.negop - comp.ebpf = self + comp.calculate(False) return comp def get_fd(self, fd): diff --git a/ebpf_test.py b/ebpf_test.py index 7f3145afeeac4d2da800e23bf8956f22c996d60f..91cfcd8387c143e6f851125e7e7634ee1a70080e 100644 --- a/ebpf_test.py +++ b/ebpf_test.py @@ -112,6 +112,7 @@ class Tests(TestCase): def test_jump(self): e = EBPF() + e.owners = set(range(11)) target = e.jump() e.r0 = 1 target.target() @@ -199,6 +200,7 @@ class Tests(TestCase): def test_with(self): e = EBPF() + e.owners = set(range(11)) with e.If(e.r2 > 3) as cond: e.r2 = 5 with cond.Else(): @@ -209,6 +211,30 @@ class Tests(TestCase): Instruction(opcode=0x5, dst=0, src=0, off=1, imm=0), Instruction(opcode=0xb7, dst=6, src=0, off=0, imm=7)]) + def test_comp_binary(self): + e = EBPF() + e.owners = {1, 2, 3, 5} + with e.If(e.r1 + e.r3 > 3) as cond: + e.r0 = 5 + with cond.Else(): + e.r0 = 7 + + tgt = e.jumpIf(e.r0 < e.r2 + e.r5) + e.r0 = 8 + tgt.target() + + self.assertEqual(e.opcodes, [ + Instruction(opcode=191, dst=0, src=1, off=0, imm=0), + Instruction(opcode=15, dst=0, src=3, off=0, imm=0), + Instruction(opcode=181, dst=0, src=0, off=2, imm=3), + Instruction(opcode=183, dst=0, src=0, off=0, imm=5), + Instruction(opcode=5, dst=0, src=0, off=1, imm=0), + Instruction(opcode=183, dst=0, src=0, off=0, imm=7), + Instruction(opcode=191, dst=4, src=2, off=0, imm=0), + Instruction(opcode=15, dst=4, src=5, off=0, imm=0), + Instruction(opcode=173, dst=0, src=4, off=1, imm=0), + Instruction(opcode=183, dst=0, src=0, off=0, imm=8)]) + def test_huge(self): e = EBPF() e.r3 = 0x1234567890 @@ -346,12 +372,12 @@ class Tests(TestCase): class KernelTests(TestCase): def test_minimal(self): e = EBPF(ProgType.XDP, "GPL") - e.r2 = 2 - e.r3 = -16 - e.r4 = 4 - e.r5 = 5 - e.m32[e.r10 - 16] = 0 - e.r5 = (e.r2 * e.r3) + e.m32[e.r10 + e.r3] + e.r3 = 3 + e.r4 = 5 + e.r5 = 7 + tgt = e.jumpIf(e.r3 < e.r4 + e.r5) + e.r0 = 8 + tgt.target() e.exit() print(e.load(log_level=1)[1]) self.fail()