diff --git a/ebpfcat/ebpf.py b/ebpfcat/ebpf.py index 106f0be033a43d8477113566d7f9fcbda482325d..cdceca590bb6c5cedac4fb9d21f5d606fcd28242 100644 --- a/ebpfcat/ebpf.py +++ b/ebpfcat/ebpf.py @@ -454,6 +454,7 @@ class Expression: __ror__ = __or__ = lambda self, value: self._binary(value, Opcode.OR) __lshift__ = lambda self, value: self._binary(value, Opcode.LSH) + __rlshift__ = lambda self, value: Constant(self.ebpf, value) << self __rxor__ = __xor__ = lambda self, value: self._binary(value, Opcode.XOR) __gt__ = comparison(Opcode.JGT, Opcode.JLE, Opcode.JSGT, Opcode.JSLE) @@ -474,24 +475,11 @@ class Expression: return Binary(self.ebpf, myself, value, opcode, self.signed or value.signed, self.fixed or value.fixed) - def _rsum(self, value, opcode): - value = ensure_expression(self.ebpf, value) - myself = self - if self.fixed != value.fixed: - if self.fixed: - value *= self.FIXED_BASE - else: - myself *= self.FIXED_BASE - - return ReverseBinary( - self.ebpf, value, myself, opcode, - self.signed or value.signed, self.fixed or value.fixed) - __radd__ = __add__ = lambda self, value: self._sum(value, Opcode.ADD) __sub__ = lambda self, value: self._sum(value, Opcode.SUB) - __rsub__ = lambda self, value: self._rsum(value, Opcode.SUB) + __rsub__ = lambda self, value: Constant(self.ebpf, value) - self __mod__ = lambda self, value: self._sum(value, Opcode.MOD) - __rmod__ = lambda self, value: self._rsum(value, Opcode.MOD) + __rmod__ = lambda self, value: Constant(self.ebpf, value) % self def __mul__(self, value): value = ensure_expression(self.ebpf, value) @@ -513,14 +501,10 @@ class Expression: return Binary(self.ebpf, myself, value, Opcode.DIV, self.signed or value.signed, True) - def __rtruediv__(self, value): - value = ensure_expression(self.ebpf, value) - if self.fixed: - value *= self.FIXED_BASE ** 2 - else: - value *= self.FIXED_BASE - return ReverseBinary(self.ebpf, value, self, Opcode.DIV, - self.signed or value.signed, True) + def _reverse(self, op, value): + return op(Constant(self.ebpf, value), self) + + __rtruediv__ = lambda self, value: Constant(self.ebpf, value) / self def __floordiv__(self, value): value = ensure_expression(self.ebpf, value) @@ -534,25 +518,22 @@ class Expression: self.signed or value.signed, False) def __rfloordiv__(self, value): - value = ensure_expression(self.ebpf, value) if self.fixed: - value *= self.FIXED_BASE - return ReverseBinary(self.ebpf, value, self, Opcode.DIV, - self.signed or value.signed, False) + value = Constant(self.ebpf, value) + if not value.fixed: + value *= self.FIXED_BASE + else: + value = Constant(self.ebpf, int(value)) + + return Binary(self.ebpf, value, self, Opcode.DIV, + self.signed or value.signed, False) def __rshift__(self, value): opcode = Opcode.ARSH if self.signed else Opcode.RSH return Binary(self.ebpf, self, ensure_expression(self.ebpf, value), opcode, self.signed, False) - def __rrshift__(self, value): - opcode = Opcode.ARSH if value < 0 else Opcode.RSH - return ReverseBinary(self.ebpf, Constant(self.ebpf, value), self, - opcode, value < 0, False) - - def __rlshift__(self, value): - return ReverseBinary(self.ebpf, Constant(self.ebpf, value), self, - Opcode.LSH, value < 0, False) + __rrshift__ = lambda self, value: Constant(self.ebpf, value) >> self def __and__(self, value): return AndExpression(self.ebpf, self, @@ -674,27 +655,6 @@ class Binary(Expression): and self.right.contains(no)) -class ReverseBinary(Expression): - def __init__(self, ebpf, left, right, operator, signed, fixed): - self.ebpf = ebpf - self.left = left - self.right = right - self.operator = operator - self.signed = signed - self.fixed = fixed - - @contextmanager - def calculate(self, dst, long, force=False): - with self.left.calculate(dst, long) as (dst, _): - with self.right.calculate(None, long) as (src, long): - self.ebpf.append(self.operator + Opcode.LONG * long - + Opcode.REG, dst, src, 0, 0) - yield dst, long - - def contains(self, no): - return self.right.contains(no) - - class Negate(Expression): def __init__(self, ebpf, arg): self.ebpf = ebpf diff --git a/ebpfcat/ebpf_test.py b/ebpfcat/ebpf_test.py index 01b61ac2a1fa07057a092470390a35c4e1e203b2..3ee92bbdad6b3403e58c0a58b6ec07129f91f91e 100644 --- a/ebpfcat/ebpf_test.py +++ b/ebpfcat/ebpf_test.py @@ -192,6 +192,17 @@ class Tests(TestCase): e.r1 = e.r2 // e.x3 e.x4 = e.x5 // e.x6 + e.x1 = 3 / e.r2 + e.x3 = 3.5 / e.r4 + e.x5 = 3 / e.x6 + e.x4 = 4.5 / e.x6 + + e.x1 = 3 // e.r2 + e.x3 = 3.5 // e.r4 + e.x5 = 3 // e.x6 + e.x4 = 4.5 // e.x6 + + self.assertEqual(e.opcodes, [ Instruction(opcode=O.REG+O.MOV+O.LONG, dst=1, src=2, off=0, imm=0), Instruction(opcode=O.ADD+O.LONG, dst=1, src=0, off=0, imm=3), @@ -268,7 +279,30 @@ class Tests(TestCase): Instruction(opcode=O.DIV+O.LONG+O.REG, dst=1, src=3, off=0, imm=0), Instruction(opcode=O.LONG+O.REG+O.MOV, dst=4, src=5, off=0, imm=0), Instruction(opcode=O.DIV+O.LONG+O.REG, dst=4, src=6, off=0, imm=0), - Instruction(opcode=O.MUL+O.LONG, dst=4, src=0, off=0, imm=100000), + Instruction(opcode=O.LONG+O.MUL, dst=4, src=0, off=0, imm=100000), + + Instruction(opcode=O.LONG+O.MOV, dst=1, src=0, off=0, imm=300000), + Instruction(opcode=O.DIV+O.REG+O.LONG, dst=1, src=2, off=0, imm=0), + Instruction(opcode=O.LONG+O.MOV, dst=3, src=0, off=0, imm=350000), + Instruction(opcode=O.DIV+O.REG+O.LONG, dst=3, src=4, off=0, imm=0), + Instruction(opcode=O.DW, dst=5, src=0, off=0, imm=4230196224), + Instruction(opcode=O.W, dst=0, src=0, off=0, imm=6), + Instruction(opcode=O.DIV+O.REG+O.LONG, dst=5, src=6, off=0, imm=0), + Instruction(opcode=O.DW, dst=4, src=0, off=0, imm=2050327040), + Instruction(opcode=O.W, dst=0, src=0, off=0, imm=10), + Instruction(opcode=O.DIV+O.REG+O.LONG, dst=4, src=6, off=0, imm=0), + Instruction(opcode=O.LONG+O.MOV, dst=1, src=0, off=0, imm=3), + Instruction(opcode=O.DIV+O.REG+O.LONG, dst=1, src=2, off=0, imm=0), + Instruction(opcode=O.LONG+O.MUL, dst=1, src=0, off=0, imm=100000), + Instruction(opcode=O.MOV+O.LONG, dst=3, src=0, off=0, imm=3), + Instruction(opcode=O.REG+O.LONG+O.DIV, dst=3, src=4, off=0, imm=0), + Instruction(opcode=O.LONG+O.MUL, dst=3, src=0, off=0, imm=100000), + Instruction(opcode=O.LONG+O.MOV, dst=5, src=0, off=0, imm=300000), + Instruction(opcode=O.DIV+O.REG+O.LONG, dst=5, src=6, off=0, imm=0), + Instruction(opcode=O.LONG+O.MUL, dst=5, src=0, off=0, imm=100000), + Instruction(opcode=O.LONG+O.MOV, dst=4, src=0, off=0, imm=450000), + Instruction(opcode=O.DIV+O.REG+O.LONG, dst=4, src=6, off=0, imm=0), + Instruction(opcode=O.LONG+O.MUL, dst=4, src=0, off=0, imm=100000), ]) def test_local(self):