diff --git a/ebpf.py b/ebpf.py index 19c900c27d818f9908f897e6409a15c095802fe5..d7d6b283d3af201009fc68c896755aae2f20be6b 100644 --- a/ebpf.py +++ b/ebpf.py @@ -209,6 +209,7 @@ class Opcode(Enum): LD = 0x61 ST = 0x62 STX = 0x63 + XADD = 0xc3 def __mul__(self, value): if value: @@ -597,6 +598,11 @@ class Register(Expression): return self.no == no +class IAdd: + def __init__(self, value): + self.value = value + + class Memory(Expression): bits_to_opcode = {32: Opcode.W, 16: Opcode.H, 8: Opcode.B, 64: Opcode.DW} @@ -606,6 +612,12 @@ class Memory(Expression): self.address = address self.signed = signed + def __iadd__(self, value): + return IAdd(value) + + def __isub__(self, value): + return IAdd(-value) + @contextmanager def calculate(self, dst, long, signed, force=False): if not long and self.bits == Opcode.DW: @@ -653,11 +665,22 @@ class MemoryDesc: if isinstance(value, int): ebpf.append(Opcode.ST + bits, self.base_register, 0, self.addr(instance), value) + return + elif isinstance(value, IAdd): + value = value.value + if isinstance(value, int): + with ebpf.get_free_register(None) as src: + ebpf.r[src] = value + ebpf.append(Opcode.XADD + bits, self.base_register, + src, self.addr(instance), 0) + return + opcode = Opcode.XADD else: - with value.calculate(None, self.bits == 64, self.signed) \ - as (src, _, _): - ebpf.append(Opcode.STX + bits, self.base_register, - src, self.addr(instance), 0) + opcode = Opcode.STX + with value.calculate(None, self.bits == 64, self.signed) \ + as (src, _, _): + ebpf.append(opcode + bits, self.base_register, + src, self.addr(instance), 0) class LocalVar(MemoryDesc): @@ -687,16 +710,26 @@ class MemoryMap: if isinstance(addr, Sum): dst = addr.left.no offset = addr.right - afree = False else: dst, _, _ = exitStack.enter_context( addr.calculate(None, None, None)) offset = 0 if isinstance(value, int): self.ebpf.append(Opcode.ST + self.bits, dst, 0, offset, value) + return + elif isinstance(value, IAdd): + value = value.value + if isinstance(value, int): + with self.ebpf.get_free_register(None) as src: + self.ebpf.r[src] = value + self.ebpf.append( + Opcode.XADD + self.bits, dst, src, offset, 0) + return + opcode = Opcode.XADD else: - with value.calculate(None, None, None) as (src, _, _): - self.ebpf.append(Opcode.STX+self.bits, dst, src, offset, 0) + opcode = Opcode.STX + with value.calculate(None, None, None) as (src, _, _): + self.ebpf.append(opcode + self.bits, dst, src, offset, 0) def __getitem__(self, addr): if isinstance(addr, Register): diff --git a/ebpf_test.py b/ebpf_test.py index b2ac19ee6c1f7350b3f26ca66d1f27df47626cba..b3946250abfad852dd3877c225f6448f511dd4a8 100644 --- a/ebpf_test.py +++ b/ebpf_test.py @@ -180,6 +180,22 @@ class Tests(TestCase): Instruction(opcode=O.W+O.LD, dst=3, src=10, off=-12, imm=0), Instruction(opcode=O.W+O.ST, dst=10, src=0, off=-12, imm=7)]) + def test_lock_add(self): + class Local(EBPF): + a = LocalVar(32, False) + + e = Local(ProgType.XDP, "GPL") + e.a += 3 + e.m32[e.r1] += e.r1 + e.a -= 3 + + self.assertEqual(e.opcodes, [ + Instruction(opcode=O.LONG+O.MOV, dst=0, src=0, off=0, imm=3), + Instruction(opcode=O.XADD+O.W, dst=10, src=0, off=-4, imm=0), + Instruction(opcode=O.XADD+O.W, dst=1, src=1, off=0, imm=0), + Instruction(opcode=O.LONG+O.MOV, dst=0, src=0, off=0, imm=-3), + Instruction(opcode=O.XADD+O.W, dst=10, src=0, off=-4, imm=0)]) + def test_jump(self): e = EBPF()