From 698ef42c6c295e9182c3b2a72d02a3bc5e9f939a Mon Sep 17 00:00:00 2001
From: Martin Teichmann <martin.teichmann@gmail.com>
Date: Sat, 25 Feb 2023 11:26:59 +0000
Subject: [PATCH] unify setting to memory

this modifies the tests to do some multiplication not long -
but it actually seems legit.
---
 ebpfcat/ebpf.py      | 136 +++++++++++++++++++------------------------
 ebpfcat/ebpf_test.py |   4 +-
 2 files changed, 62 insertions(+), 78 deletions(-)

diff --git a/ebpfcat/ebpf.py b/ebpfcat/ebpf.py
index f1949e6..b74a64e 100644
--- a/ebpfcat/ebpf.py
+++ b/ebpfcat/ebpf.py
@@ -966,6 +966,61 @@ class Memory(Expression):
             return Memory(self.ebpf, "B", self.address) & mask != 0
         return super().__ne__(value)
 
+    def _set(self, value):
+        opcode = Opcode.STX
+        if isinstance(self.fmt, tuple):
+            pos, bits = self.fmt
+            self.fmt = "B"
+            if bits == 1:
+                try:
+                    if value:
+                        value = self | (1 << pos)
+                    else:
+                        value = self & ~(1 << pos)
+                except AssembleError:
+                    with ebpf.wtmp:
+                        with value as Else:
+                            ebpf.wtmp = self | (1 << pos)
+                        with Else:
+                            ebpf.wtmp = self & ~(1 << pos)
+            else:
+                mask = ((1 << bits) - 1) << pos
+                value = (mask & (value << pos) | ~mask & self)
+        elif isinstance(value, IAdd) and len(self.fmt) == 1:
+            value = value.value
+            opcode = Opcode.XADD
+        elif not isinstance(value, Expression):
+            if self.fmt == "x":
+                value = Constant(self.ebpf, value)
+            else:
+                value = Constant(self.ebpf,
+                                 *unpack(self.fmt, pack(self.fmt[-1], value)))
+        if self.fmt == "x" and not value.fixed:
+            value *= Expression.FIXED_BASE
+        elif self.fmt != "x" and value.fixed:
+            value /= Expression.FIXED_BASE
+        with ExitStack() as exitStack:
+            if isinstance(self.address, Sum):
+                dst = self.address.left.no
+                offset = self.address.right.value
+            else:
+                dst, _ = exitStack.enter_context(
+                    self.address.calculate(None, True))
+                offset = 0
+            if isinstance(value, Constant) and value.small \
+                    and opcode == Opcode.STX:
+                self.ebpf.append(Opcode.ST + fmt_to_opcode(self.fmt), dst, 0,
+                                 offset, int(value.value))
+                return
+            src, _ = exitStack.enter_context(
+                value.calculate(None, isinstance(self.fmt, str)
+                                      and self.fmt[-1] in 'qQx'))
+            if not isinstance(value, Constant):
+                self.ebpf.append_endian(self.fmt, src)
+            self.ebpf.append(opcode + fmt_to_opcode(self.fmt),
+                             dst, src, offset, 0)
+
+
 
 class MemoryDesc:
     """A base class used by descriptors for memory
@@ -985,49 +1040,10 @@ class MemoryDesc:
                       instance.ebpf.r[self.base_register] + addr)
 
     def __set__(self, instance, value):
-        ebpf = instance.ebpf
         fmt, addr = self.fmt_addr(instance)
-        opcode = Opcode.STX
-        if isinstance(fmt, tuple):
-            before = Memory(ebpf, "B", ebpf.r[self.base_register] + addr)
-            if fmt[1] == 1:
-                try:
-                    if value:
-                        value = before | (1 << fmt[0])
-                    else:
-                        value = before & ~(1 << fmt[0])
-                except AssembleError:
-                    with ebpf.wtmp:
-                        with value as Else:
-                            ebpf.wtmp = before | (1 << fmt[0])
-                        with Else:
-                            ebpf.wtmp = before & ~(1 << fmt[0])
-            else:
-                mask = ((1 << fmt[1]) - 1) << fmt[0]
-                value = (mask & (value << self.fmt[0]) | ~mask & before)
-        elif isinstance(value, IAdd) and len(fmt) == 1:
-            value = value.value
-            opcode = Opcode.XADD
-        elif not isinstance(value, Expression):
-            if fmt == "x":
-                value = Constant(ebpf, value)
-            else:
-                value = Constant(ebpf, *unpack(fmt, pack(fmt[-1], value)))
-        if self.fmt == "x" and not value.fixed:
-            value *= Expression.FIXED_BASE
-        elif self.fmt != "x" and value.fixed:
-            value /= Expression.FIXED_BASE
-        if isinstance(value, Constant) and value.small and opcode == Opcode.STX:
-            ebpf.append(Opcode.ST + fmt_to_opcode(fmt), self.base_register, 0,
-                    addr, int(value.value))
-            return
-        with value.calculate(None, isinstance(fmt, str) and fmt[-1] in 'qQx'
-                            ) as (src, _):
-            if not isinstance(value, Constant):
-                ebpf.append_endian(fmt, src)
-            ebpf.append(opcode + fmt_to_opcode(fmt), self.base_register,
-                        src, addr, 0)
-
+        memory = Memory(instance.ebpf, fmt,
+                        instance.ebpf.r[self.base_register] + addr)
+        memory._set(value)
 
 class LocalVar(MemoryDesc):
     """variables on the stack"""
@@ -1062,40 +1078,8 @@ class MemoryMap:
         self.fmt = fmt
 
     def __setitem__(self, addr, value):
-        with ExitStack() as exitStack:
-            if isinstance(addr, Sum):
-                dst = addr.left.no
-                offset = addr.right.value
-            else:
-                dst, _ = exitStack.enter_context(addr.calculate(None, True))
-                offset = 0
-            if isinstance(value, IAdd):
-                value = value.value
-                if self.fmt == "x":
-                    value = int(value * self.FIXED_BASE)
-                if not isinstance(value, Expression):
-                    with self.ebpf.get_free_register(None) as src:
-                        self.ebpf.r[src] = value
-                        self.ebpf.append(
-                            Opcode.XADD + fmt_to_opcode(self.fmt),
-                            dst, src, offset, 0)
-                    return
-                opcode = Opcode.XADD
-            elif isinstance(value, Expression):
-                opcode = Opcode.STX
-            else:
-                value = Constant(self.ebpf, value)
-                if value.small:
-                    if self.fmt == "x":
-                        value *= self.FIXED_BASE
-                    self.ebpf.append(Opcode.ST + fmt_to_opcode(self.fmt), dst, 0,
-                                     offset, int(value.value))
-                    return
-            with value.calculate(None, None) as (src, _):
-                if not isinstance(value, Constant):
-                    self.ebpf.append_endian(self.fmt, src)
-                self.ebpf.append(opcode + fmt_to_opcode(self.fmt),
-                                 dst, src, offset, 0)
+        memory = Memory(self.ebpf, self.fmt, addr)
+        memory._set(value)
 
     def __getitem__(self, addr):
         if isinstance(addr, Register):
diff --git a/ebpfcat/ebpf_test.py b/ebpfcat/ebpf_test.py
index 0b36f02..aa81111 100644
--- a/ebpfcat/ebpf_test.py
+++ b/ebpfcat/ebpf_test.py
@@ -810,12 +810,12 @@ class Tests(TestCase):
             Instruction(opcode=39, dst=0, src=0, off=0, imm=2),
             Instruction(opcode=31, dst=3, src=0, off=0, imm=0),
             Instruction(opcode=191, dst=0, src=3, off=0, imm=0),
-            Instruction(opcode=O.MUL+O.LONG, dst=0, src=0, off=0, imm=2),
+            Instruction(opcode=O.MUL, dst=0, src=0, off=0, imm=2),
             Instruction(opcode=107, dst=10, src=0, off=-10, imm=0),
             Instruction(opcode=191, dst=0, src=10, off=0, imm=0),
             Instruction(opcode=15, dst=0, src=3, off=0, imm=0),
             Instruction(opcode=191, dst=2, src=3, off=0, imm=0),
-            Instruction(opcode=O.MUL+O.LONG, dst=2, src=0, off=0, imm=2),
+            Instruction(opcode=O.MUL, dst=2, src=0, off=0, imm=2),
             Instruction(opcode=107, dst=0, src=2, off=0, imm=0),
 
             Instruction(opcode=191, dst=5, src=10, off=0, imm=0),
-- 
GitLab