diff --git a/ebpfcat/ebpf.py b/ebpfcat/ebpf.py index cdceca590bb6c5cedac4fb9d21f5d606fcd28242..933ace8000276d125b4c60faddd8dcb8bf96660f 100644 --- a/ebpfcat/ebpf.py +++ b/ebpfcat/ebpf.py @@ -867,13 +867,13 @@ class Memory(Expression): self.address = address def __iadd__(self, value): - if self.fmt in "qQiI": + if self.fmt in "qQiIx": return IAdd(self.ebpf, value) else: return NotImplemented def __isub__(self, value): - if self.fmt in "qQiI": + if self.fmt in "qQiIx": return IAdd(self.ebpf, -value) else: return NotImplemented @@ -948,7 +948,7 @@ class Memory(Expression): else: mask = ((1 << bits) - 1) << pos value = (mask & (value << pos) | ~mask & self) - elif isinstance(value, IAdd) and len(self.fmt) == 1: + elif isinstance(value, IAdd): value = value.value opcode = Opcode.XADD elif not isinstance(value, Expression): diff --git a/ebpfcat/ebpf_test.py b/ebpfcat/ebpf_test.py index 3ee92bbdad6b3403e58c0a58b6ec07129f91f91e..dbe2cdb8fe64d83df298a824d8fbfdb2dba1ab71 100644 --- a/ebpfcat/ebpf_test.py +++ b/ebpfcat/ebpf_test.py @@ -417,6 +417,7 @@ class Tests(TestCase): a = LocalVar('I') b = LocalVar('q') c = LocalVar('h') + d = LocalVar('x') e = Local(ProgType.XDP, "GPL") e.a += 3 @@ -429,6 +430,9 @@ class Tests(TestCase): e.c += 3 e.mB[e.r1] += e.r1 + e.d -= 5 + e.d += e.r1 + self.assertEqual(e.opcodes, [ Instruction(opcode=O.LONG+O.MOV, dst=0, src=0, off=0, imm=3), Instruction(opcode=O.XADD+O.W, dst=10, src=0, off=-4, imm=0), @@ -443,7 +447,13 @@ class Tests(TestCase): Instruction(opcode=O.STX+O.REG, dst=10, src=0, off=-18, imm=0), Instruction(opcode=O.B+O.LD, dst=0, src=1, off=0, imm=0), Instruction(opcode=O.ADD+O.REG, dst=0, src=1, off=0, imm=0), - Instruction(opcode=O.STX+O.B, dst=1, src=0, off=0, imm=0)]) + Instruction(opcode=O.STX+O.B, dst=1, src=0, off=0, imm=0), + Instruction(opcode=O.LONG+O.MOV, dst=0, src=0, off=0, imm=-500000), + Instruction(opcode=O.XADD+O.DW, dst=10, src=0, off=-32, imm=0), + Instruction(opcode=O.REG+O.LONG+O.MOV, dst=0, src=1, off=0, imm=0), + Instruction(opcode=O.MUL+O.LONG, dst=0, src=0, off=0, imm=100000), + Instruction(opcode=O.XADD+O.DW, dst=10, src=0, off=-32, imm=0), + ]) def test_jump(self):