diff --git a/ebpfcat/ebpf.py b/ebpfcat/ebpf.py index 1fb6969b2a31330540cc87dfc1176a24d594fda3..a26e13f21c33e460e64e462ed067a44477001c43 100644 --- a/ebpfcat/ebpf.py +++ b/ebpfcat/ebpf.py @@ -333,14 +333,14 @@ class SimpleComparison(Comparison): def compare(self, negative): with self.left.calculate(None, None, None) as (self.dst, _, lsigned): with ExitStack() as exitStack: - if not isinstance(self.right, int): + if isinstance(self.right, int): + rsigned = (self.right < 0) + else: self.src, _, rsigned = exitStack.enter_context( self.right.calculate(None, None, None)) - if rsigned != rsigned: - raise AssembleError("need same signedness for comparison") self.origin = len(self.ebpf.opcodes) self.ebpf.opcodes.append(None) - self.opcode = self.opcode[negative + 2 * lsigned] + self.opcode = self.opcode[negative + 2 * (lsigned or rsigned)] self.owners = self.ebpf.owners.copy() def target(self): @@ -476,28 +476,29 @@ class Binary(Expression): dst = None with self.ebpf.get_free_register(dst) as dst: with self.left.calculate(dst, long, signed, True) \ - as (dst, l_long, signed): - pass + as (dst, l_long, l_signed): + signed = signed or l_signed if self.operator is Opcode.RSH and signed: # >>= operator = Opcode.ARSH else: operator = self.operator if isinstance(self.right, int): + r_signed = self.right < 0 self.ebpf.append(operator + Opcode.LONG * long, dst, 0, 0, self.right) else: with self.right.calculate(None, long, None) as \ - (src, r_long, _): + (src, r_long, r_signed): self.ebpf.append( operator + Opcode.REG + Opcode.LONG * ((r_long or l_long) if long is None else long), dst, src, 0, 0) if orig_dst is None or orig_dst == dst: - yield dst, long, signed + yield dst, long, signed or r_signed return self.ebpf.append(Opcode.MOV + Opcode.REG + Opcode.LONG * long, orig_dst, dst, 0, 0) - yield orig_dst, long, signed + yield orig_dst, long, signed or r_signed def contains(self, no): return self.left.contains(no) or (not isinstance(self.right, int) diff --git a/ebpfcat/ebpf_test.py b/ebpfcat/ebpf_test.py index f85dbb85063672a8f5cc5c0a32b6bf08a2e43229..2f772216ca92151842abb71be18693a7d44d6229 100644 --- a/ebpfcat/ebpf_test.py +++ b/ebpfcat/ebpf_test.py @@ -414,6 +414,19 @@ class Tests(TestCase): Instruction(opcode=O.MOV+O.REG, dst=1, src=2, off=0, imm=0), Instruction(opcode=O.REG+O.ADD, dst=1, src=3, off=0, imm=0)]) + def test_mixed_compare(self): + e = EBPF() + e.owners = {0, 1, 2, 3} + with e.r1 > e.sr2: + pass + with (e.r1 + e.sr2) > 3: + pass + self.assertEqual(e.opcodes, [ + Instruction(opcode=O.JSLE+O.REG, dst=1, src=2, off=0, imm=0), + Instruction(opcode=O.MOV+O.LONG+O.REG, dst=4, src=1, off=0, imm=0), + Instruction(opcode=O.ADD+O.LONG+O.REG, dst=4, src=2, off=0, imm=0), + Instruction(opcode=O.JSLE, dst=4, src=0, off=0, imm=3)]) + def test_reverse_binary(self): e = EBPF()