diff --git a/ebpfcat/ebpf.py b/ebpfcat/ebpf.py index 5c227ac441b4c2c58c9b4f945fa642ae158734fb..7d3df0cb91ae3119a7c1be76a7a819281bbd0e18 100644 --- a/ebpfcat/ebpf.py +++ b/ebpfcat/ebpf.py @@ -284,9 +284,9 @@ class Comparison: self.ebpf.opcodes[self.origin] = Instruction(op, dst, src, off+1, imm) def Else(self): - self.retarget_one() self.else_origin = len(self.ebpf.opcodes) self.ebpf.opcodes.append(None) + self.target(True) return self def __and__(self, value): @@ -322,8 +322,8 @@ class SimpleComparison(Comparison): self.opcode = self.opcode[negative + 2 * (lsigned or rsigned)] self.owners = self.ebpf.owners.copy() - def target(self): - assert self.ebpf.opcodes[self.origin] is None + def target(self, retarget=False): + assert retarget or self.ebpf.opcodes[self.origin] is None if self.opcode == Opcode.JMP: inst = Instruction(Opcode.JMP, 0, 0, len(self.ebpf.opcodes) - self.origin - 1, 0) @@ -336,8 +336,9 @@ class SimpleComparison(Comparison): self.opcode + Opcode.REG, self.dst, self.src, len(self.ebpf.opcodes) - self.origin - 1, 0) self.ebpf.opcodes[self.origin] = inst - self.ebpf.owners, self.owners = \ - self.ebpf.owners & self.owners, self.ebpf.owners + if not retarget: + self.ebpf.owners, self.owners = \ + self.ebpf.owners & self.owners, self.ebpf.owners class AndOrComparison(Comparison): @@ -357,17 +358,10 @@ class AndOrComparison(Comparison): self.left.target() self.owners = self.ebpf.owners.copy() - def target(self): + def target(self, retarget=False): if self.is_and == self.negative: - self.left.target() - self.right.target() - - def Else(self): - self.left.retarget_one() - self.right.retarget_one() - self.else_origin = len(self.ebpf.opcodes) - self.ebpf.opcodes.append(None) - return self + self.left.target(retarget) + self.right.target(retarget) class InvertComparison(Comparison): diff --git a/ebpfcat/ebpf_test.py b/ebpfcat/ebpf_test.py index 3e41544cc41ac92feee095fa12b62ba9776f1a19..a62a5a9f4403465adebfb541070f15c8129e73d6 100644 --- a/ebpfcat/ebpf_test.py +++ b/ebpfcat/ebpf_test.py @@ -361,7 +361,17 @@ class Tests(TestCase): e.r3 = 7 e.r4 = 3 self.maxDiff = None - self.assertEqual(e.opcodes, []) + self.assertEqual(e.opcodes, [ + Instruction(opcode=O.JGT, dst=2, src=0, off=1, imm=3), + Instruction(opcode=O.JLE, dst=3, src=0, off=1, imm=2), + Instruction(opcode=O.MOV+O.LONG, dst=1, src=0, off=0, imm=5), + Instruction(opcode=O.JGT, dst=2, src=0, off=1, imm=2), + Instruction(opcode=O.JLE, dst=1, src=0, off=3, imm=2), + Instruction(opcode=O.MOV+O.LONG, dst=2, src=0, off=0, imm=5), + Instruction(opcode=O.MOV+O.LONG, dst=5, src=0, off=0, imm=4), + Instruction(opcode=O.JMP, dst=0, src=0, off=2, imm=0), + Instruction(opcode=O.MOV+O.LONG, dst=3, src=0, off=0, imm=7), + Instruction(opcode=O.MOV+O.LONG, dst=4, src=0, off=0, imm=3)]) def test_comp_binary(self): e = EBPF()