diff --git a/ebpf.py b/ebpf.py index 7b48072a0473ab76685c1d21d880206dc3b7a79b..ed0b5cc99d2d7a3d278a93d865f630ac960c348d 100644 --- a/ebpf.py +++ b/ebpf.py @@ -280,7 +280,9 @@ class Binary(Expression): self.operator = operator def calculate(self, dst, long, signed, force=False): - if dst is None: + orig_dst = dst + if dst is None or (not isinstance(self.right, int) + and self.right.contains(dst)): dst = self.ebpf.get_free_register() self.ebpf.owners.add(dst) free = True @@ -301,7 +303,17 @@ class Binary(Expression): dst, src, 0, 0) if rfree: self.ebpf.owners.discard(src) - return dst, long, signed, free + if orig_dst is None or orig_dst == dst: + return dst, long, signed, free + else: + self.ebpf.append(Opcode.MOV + Opcode.LONG * long, orig_dst, dst, + 0, 0) + self.ebpf.owners.discard(dst) + return orig_dst, long, signed, False + + def contains(self, no): + return self.left.contains(no) or (not isinstance(self.right, int) + and self.right.contains(no)) class Sum(Binary): @@ -392,6 +404,9 @@ class Register(Expression): else: return self.no, self.long, self.signed, False + def contains(self, no): + return self.no == no + class Memory(Expression): def __init__(self, ebpf, bits, address): @@ -417,6 +432,9 @@ class Memory(Expression): self.ebpf.owners.discard(src) return dst, long, signed, free + def contains(self, no): + return self.address.contains(no) + class MemoryDesc: def __init__(self, ebpf, bits): diff --git a/ebpf_test.py b/ebpf_test.py index aaf03466d361814d443e1aa4e69bdd93be678c06..015e4fb68210e132c05ae514cd46d905ff5600bc 100644 --- a/ebpf_test.py +++ b/ebpf_test.py @@ -382,6 +382,7 @@ class Tests(TestCase): e.r5 = e.m16[e.r10 + e.r3] e.r0 = (e.r1 * e.r3) - (e.r10 * e.r5) e.r5 = (e.r1 * e.r3) + e.m32[e.r10 + e.r0] + e.r5 = e.r3 + e.r5 self.assertEqual(e.opcodes, [ Instruction(opcode=191, dst=3, src=1, off=0, imm=0), Instruction(opcode=191, dst=0, src=10, off=0, imm=0), @@ -408,7 +409,10 @@ class Tests(TestCase): Instruction(opcode=191, dst=2, src=10, off=0, imm=0), Instruction(opcode=15, dst=2, src=0, off=0, imm=0), Instruction(opcode=97, dst=2, src=2, off=0, imm=0), - Instruction(opcode=15, dst=5, src=2, off=0, imm=0)]) + Instruction(opcode=15, dst=5, src=2, off=0, imm=0), + Instruction(opcode=O.LONG+O.MOV+O.REG, dst=2, src=3, off=0, imm=0), + Instruction(opcode=O.LONG+O.ADD+O.REG, dst=2, src=5, off=0, imm=0), + Instruction(opcode=O.LONG+O.MOV, dst=5, src=2, off=0, imm=0)]) with self.assertRaises(AssembleError): e.r8 = e.r2