From e16f4697c35b36371e7e660df09f24644a669c95 Mon Sep 17 00:00:00 2001 From: Martin Teichmann <martin.teichmann@xfel.eu> Date: Mon, 28 Dec 2020 20:09:13 +0000 Subject: [PATCH] add reverse binary and negation --- ebpf.py | 79 ++++++++++++++++++++++++++++++++++++++++++++-------- ebpf_test.py | 29 ++++++++++++++++++- 2 files changed, 96 insertions(+), 12 deletions(-) diff --git a/ebpf.py b/ebpf.py index ed0b5cc..4307feb 100644 --- a/ebpf.py +++ b/ebpf.py @@ -242,22 +242,32 @@ class InvertComparison(Comparison): self.value.compare(not negative) -def binary(opcode, symetric=False): +def binary(opcode): def ret(self, value): return Binary(self.ebpf, self, value, opcode) return ret +def rbinary(opcode): + def ret(self, value): + return ReverseBinary(self.ebpf, value, self, opcode) + return ret + class Expression: - __radd__ = __add__ = binary(Opcode.ADD, True) + __radd__ = __add__ = binary(Opcode.ADD) __sub__ = binary(Opcode.SUB) - __rmul__ = __mul__ = binary(Opcode.MUL, True) + __rsub__ = rbinary(Opcode.SUB) + __rmul__ = __mul__ = binary(Opcode.MUL) __truediv__ = binary(Opcode.DIV) - __ror__ = __or__ = binary(Opcode.OR, True) + __rtruediv__ = rbinary(Opcode.DIV) + __ror__ = __or__ = binary(Opcode.OR) __lshift__ = binary(Opcode.LSH) + __rlshift__ = rbinary(Opcode.LSH) __rshift__ = binary(Opcode.RSH) + __rrshift__ = rbinary(Opcode.RSH) __mod__ = binary(Opcode.MOD) - __rxor__ = __xor__ = binary(Opcode.XOR, True) + __rmod__ = rbinary(Opcode.MOD) + __rxor__ = __xor__ = binary(Opcode.XOR) __eq__ = comparison(Opcode.JEQ, Opcode.JNE) __gt__ = comparison(Opcode.JGT, Opcode.JLE, Opcode.JSGT, Opcode.JSLE) @@ -271,6 +281,9 @@ class Expression: __rand__ = __and__ + def __neg__(self): + return Negate(self.ebpf, self) + class Binary(Expression): def __init__(self, ebpf, left, right, operator): @@ -316,6 +329,49 @@ class Binary(Expression): and self.right.contains(no)) +class ReverseBinary(Expression): + def __init__(self, ebpf, left, right, operator): + self.ebpf = ebpf + self.left = left + self.right = right + self.operator = operator + + def calculate(self, dst, long, signed, force=False): + if dst is None: + dst = self.ebpf.get_free_register() + self.ebpf.owners.add(dst) + free = True + else: + free = False + self.ebpf._load_value(dst, self.left) + if self.operator is Opcode.RSH and self.left < 0: # >>= + operator = Opcode.ARSH + else: + operator = self.operator + + src, long, _, rfree = self.right.calculate(None, long, None) + self.ebpf.append(operator + Opcode.LONG * long + Opcode.REG, + dst, src, 0, 0) + return dst, long, signed, free + + def contains(self, no): + return self.right.contains(no) + + +class Negate(Expression): + def __init__(self, ebpf, arg): + self.ebpf = ebpf + self.arg = arg + + def calculate(self, dst, long, signed, force=False): + dst, long, signed, free = self.arg.calculate(dst, long, signed, force) + self.ebpf.append(Opcode.NEG + Opcode.LONG * long, dst, 0, 0, 0) + return dst, long, signed, free + + def contains(self, no): + return self.arg.contains(no) + + class Sum(Binary): def __init__(self, ebpf, left, right): super().__init__(ebpf, left, right, Opcode.ADD) @@ -496,12 +552,7 @@ class RegisterDesc: def __set__(self, instance, value): instance.owners.add(self.no) if isinstance(value, int): - if -0x80000000 <= value < 0x80000000: - instance.append(Opcode.MOV + Opcode.LONG * self.long, - self.no, 0, 0, value) - else: - instance.append(Opcode.DW, self.no, 0, 0, value & 0xffffffff) - instance.append(Opcode.W, 0, 0, 0, value >> 32) + instance._load_value(self.no, value) elif isinstance(value, Expression): value.calculate(self.no, self.long, self.signed, True) elif isinstance(value, Instruction): @@ -571,6 +622,12 @@ class EBPF: return i raise AssembleError("not enough registers") + def _load_value(self, no, value): + if -0x80000000 <= value < 0x80000000: + self.append(Opcode.MOV + Opcode.LONG, no, 0, 0, value) + else: + self.append(Opcode.DW, no, 0, 0, value & 0xffffffff) + self.append(Opcode.W, 0, 0, 0, value >> 32) for i in range(11): setattr(EBPF, f"r{i}", RegisterDesc(i, True)) diff --git a/ebpf_test.py b/ebpf_test.py index 015e4fb..58c51d1 100644 --- a/ebpf_test.py +++ b/ebpf_test.py @@ -48,7 +48,7 @@ class Tests(TestCase): e.w2 += 3 e.w5 += e.w6 self.assertEqual(e.opcodes, - [Instruction(0xb4, 3, 0, 0, 7), + [Instruction(O.MOV+O.LONG, 3, 0, 0, 7), Instruction(0xbc, 4, 1, 0, 0), Instruction(opcode=4, dst=2, src=0, off=0, imm=3), Instruction(opcode=0xc, dst=5, src=6, off=0, imm=0)]) @@ -323,6 +323,33 @@ class Tests(TestCase): Instruction(opcode=191, dst=0, src=1, off=0, imm=0), Instruction(opcode=95, dst=0, src=2, off=0, imm=0)]) + def test_reverse_binary(self): + e = EBPF() + e.owners = {0, 1, 2, 3} + e.r3 = 7 / e.r2 + e.r3 = 7 << e.r2 + e.r3 = 7 % e.r2 + e.r3 = 7 >> e.r2 + e.r3 = -7 >> e.r2 + self.assertEqual(e.opcodes, [ + Instruction(opcode=O.MOV+O.LONG, dst=3, src=0, off=0, imm=7), + Instruction(opcode=O.REG+O.LONG+O.DIV, dst=3, src=2, off=0, imm=0), + Instruction(opcode=O.MOV+O.LONG, dst=3, src=0, off=0, imm=7), + Instruction(opcode=O.LSH+O.REG+O.LONG, dst=3, src=2, off=0, imm=0), + Instruction(opcode=O.MOV+O.LONG, dst=3, src=0, off=0, imm=7), + Instruction(opcode=O.REG+O.MOD+O.LONG, dst=3, src=2, off=0, imm=0), + Instruction(opcode=O.MOV+O.LONG, dst=3, src=0, off=0, imm=7), + Instruction(opcode=O.REG+O.RSH+O.LONG, dst=3, src=2, off=0, imm=0), + Instruction(opcode=O.MOV+O.LONG, dst=3, src=0, off=0, imm=-7), + Instruction(opcode=O.REG+O.LONG+O.ARSH, dst=3, src=2, off=0, imm=0) + ]) + + def test_reverse_binary(self): + e = EBPF() + e.r7 = -e.r1 + self.assertEqual(e.opcodes, [ + Instruction(opcode=O.LONG+O.REG+O.MOV, dst=7, src=1, off=0, imm=0), + Instruction(opcode=O.LONG+O.NEG, dst=7, src=0, off=0, imm=0)]) def test_jump_data(self): e = EBPF() -- GitLab