From 85e8c6987ac5cfb6161f2c6110c22f5a14642248 Mon Sep 17 00:00:00 2001
From: Martin Teichmann <martin.teichmann@gmail.com>
Date: Sun, 26 Feb 2023 10:00:02 +0000
Subject: [PATCH] support in-place addition for fixed values

---
 ebpfcat/ebpf.py      |  6 +++---
 ebpfcat/ebpf_test.py | 12 +++++++++++-
 2 files changed, 14 insertions(+), 4 deletions(-)

diff --git a/ebpfcat/ebpf.py b/ebpfcat/ebpf.py
index cdceca5..933ace8 100644
--- a/ebpfcat/ebpf.py
+++ b/ebpfcat/ebpf.py
@@ -867,13 +867,13 @@ class Memory(Expression):
         self.address = address
 
     def __iadd__(self, value):
-        if self.fmt in "qQiI":
+        if self.fmt in "qQiIx":
             return IAdd(self.ebpf, value)
         else:
             return NotImplemented
 
     def __isub__(self, value):
-        if self.fmt in "qQiI":
+        if self.fmt in "qQiIx":
             return IAdd(self.ebpf, -value)
         else:
             return NotImplemented
@@ -948,7 +948,7 @@ class Memory(Expression):
                 else:
                     mask = ((1 << bits) - 1) << pos
                     value = (mask & (value << pos) | ~mask & self)
-            elif isinstance(value, IAdd) and len(self.fmt) == 1:
+            elif isinstance(value, IAdd):
                 value = value.value
                 opcode = Opcode.XADD
             elif not isinstance(value, Expression):
diff --git a/ebpfcat/ebpf_test.py b/ebpfcat/ebpf_test.py
index 3ee92bb..dbe2cdb 100644
--- a/ebpfcat/ebpf_test.py
+++ b/ebpfcat/ebpf_test.py
@@ -417,6 +417,7 @@ class Tests(TestCase):
             a = LocalVar('I')
             b = LocalVar('q')
             c = LocalVar('h')
+            d = LocalVar('x')
 
         e = Local(ProgType.XDP, "GPL")
         e.a += 3
@@ -429,6 +430,9 @@ class Tests(TestCase):
         e.c += 3
         e.mB[e.r1] += e.r1
 
+        e.d -= 5
+        e.d += e.r1
+
         self.assertEqual(e.opcodes, [
            Instruction(opcode=O.LONG+O.MOV, dst=0, src=0, off=0, imm=3),
            Instruction(opcode=O.XADD+O.W, dst=10, src=0, off=-4, imm=0),
@@ -443,7 +447,13 @@ class Tests(TestCase):
            Instruction(opcode=O.STX+O.REG, dst=10, src=0, off=-18, imm=0),
            Instruction(opcode=O.B+O.LD, dst=0, src=1, off=0, imm=0),
            Instruction(opcode=O.ADD+O.REG, dst=0, src=1, off=0, imm=0),
-           Instruction(opcode=O.STX+O.B, dst=1, src=0, off=0, imm=0)])
+           Instruction(opcode=O.STX+O.B, dst=1, src=0, off=0, imm=0),
+           Instruction(opcode=O.LONG+O.MOV, dst=0, src=0, off=0, imm=-500000),
+           Instruction(opcode=O.XADD+O.DW, dst=10, src=0, off=-32, imm=0),
+           Instruction(opcode=O.REG+O.LONG+O.MOV, dst=0, src=1, off=0, imm=0),
+           Instruction(opcode=O.MUL+O.LONG, dst=0, src=0, off=0, imm=100000),
+           Instruction(opcode=O.XADD+O.DW, dst=10, src=0, off=-32, imm=0),
+        ])
 
 
     def test_jump(self):
-- 
GitLab