From 69d7feb651c4c3443665fb271207ec6364bc7250 Mon Sep 17 00:00:00 2001
From: Martin Teichmann <martin.teichmann@gmail.com>
Date: Sat, 25 Feb 2023 15:53:45 +0000
Subject: [PATCH] fix setting bits from data

---
 ebpfcat/ebpf.py      | 62 +++++++++++++++++++++++---------------------
 ebpfcat/ebpf_test.py | 10 +++++++
 2 files changed, 42 insertions(+), 30 deletions(-)

diff --git a/ebpfcat/ebpf.py b/ebpfcat/ebpf.py
index b74a64e..106f0be 100644
--- a/ebpfcat/ebpf.py
+++ b/ebpfcat/ebpf.py
@@ -968,38 +968,40 @@ class Memory(Expression):
 
     def _set(self, value):
         opcode = Opcode.STX
-        if isinstance(self.fmt, tuple):
-            pos, bits = self.fmt
-            self.fmt = "B"
-            if bits == 1:
-                try:
-                    if value:
-                        value = self | (1 << pos)
-                    else:
-                        value = self & ~(1 << pos)
-                except AssembleError:
-                    with ebpf.wtmp:
+        with ExitStack() as exitStack:
+            if isinstance(self.fmt, tuple):
+                pos, bits = self.fmt
+                self.fmt = "B"
+                if bits == 1:
+                    try:
+                        if value:
+                            value = self | (1 << pos)
+                        else:
+                            value = self & ~(1 << pos)
+                    except AssembleError:
+                        exitStack.enter_context(self.ebpf.wtmp)
                         with value as Else:
-                            ebpf.wtmp = self | (1 << pos)
+                            self.ebpf.wtmp = self | (1 << pos)
                         with Else:
-                            ebpf.wtmp = self & ~(1 << pos)
-            else:
-                mask = ((1 << bits) - 1) << pos
-                value = (mask & (value << pos) | ~mask & self)
-        elif isinstance(value, IAdd) and len(self.fmt) == 1:
-            value = value.value
-            opcode = Opcode.XADD
-        elif not isinstance(value, Expression):
-            if self.fmt == "x":
-                value = Constant(self.ebpf, value)
-            else:
-                value = Constant(self.ebpf,
-                                 *unpack(self.fmt, pack(self.fmt[-1], value)))
-        if self.fmt == "x" and not value.fixed:
-            value *= Expression.FIXED_BASE
-        elif self.fmt != "x" and value.fixed:
-            value /= Expression.FIXED_BASE
-        with ExitStack() as exitStack:
+                            self.ebpf.wtmp = self & ~(1 << pos)
+                        value = self.ebpf.wtmp
+                else:
+                    mask = ((1 << bits) - 1) << pos
+                    value = (mask & (value << pos) | ~mask & self)
+            elif isinstance(value, IAdd) and len(self.fmt) == 1:
+                value = value.value
+                opcode = Opcode.XADD
+            elif not isinstance(value, Expression):
+                if self.fmt == "x":
+                    value = Constant(self.ebpf, value)
+                else:
+                    value = Constant(
+                        self.ebpf,
+                        *unpack(self.fmt, pack(self.fmt[-1], value)))
+            if self.fmt == "x" and not value.fixed:
+                value *= Expression.FIXED_BASE
+            elif self.fmt != "x" and value.fixed:
+                value /= Expression.FIXED_BASE
             if isinstance(self.address, Sum):
                 dst = self.address.left.no
                 offset = self.address.right.value
diff --git a/ebpfcat/ebpf_test.py b/ebpfcat/ebpf_test.py
index aa81111..01b61ac 100644
--- a/ebpfcat/ebpf_test.py
+++ b/ebpfcat/ebpf_test.py
@@ -316,6 +316,8 @@ class Tests(TestCase):
         with e.b:
             e.a = 0
 
+        e.a = e.b
+
         self.assertEqual(e.opcodes, [
            Instruction(opcode=O.LD+O.B, dst=0, src=10, off=-1, imm=0),
             Instruction(opcode=O.JSET, dst=0, src=0, off=1, imm=32),
@@ -343,6 +345,14 @@ class Tests(TestCase):
             Instruction(opcode=O.JMP, dst=0, src=0, off=3, imm=0),
             Instruction(opcode=O.LD+O.B, dst=0, src=10, off=-1, imm=0),
             Instruction(opcode=O.AND, dst=0, src=0, off=0, imm=-33),
+            Instruction(opcode=O.STX+O.B, dst=10, src=0, off=-1, imm=0),
+            Instruction(opcode=O.LD+O.B, dst=2, src=10, off=-2, imm=0),
+            Instruction(opcode=O.JSET, dst=2, src=0, off=3, imm=120),
+            Instruction(opcode=O.LD+O.B, dst=0, src=10, off=-1, imm=0),
+            Instruction(opcode=O.AND, dst=0, src=0, off=0, imm=-33),
+            Instruction(opcode=O.JMP, dst=0, src=0, off=2, imm=0),
+            Instruction(opcode=O.LD+O.B, dst=0, src=10, off=-1, imm=0),
+            Instruction(opcode=O.OR, dst=0, src=0, off=0, imm=32),
             Instruction(opcode=O.B+O.STX, dst=10, src=0, off=-1, imm=0)])
 
     def test_local_subprog(self):
-- 
GitLab