From e925aaaa2be13aa7c43f6515c4678eccfef65e85 Mon Sep 17 00:00:00 2001
From: Martin Teichmann <martin.teichmann@xfel.eu>
Date: Tue, 29 Dec 2020 16:07:10 +0000
Subject: [PATCH] more generalized memory access

---
 ebpf.py | 37 +++++++++++++++++++++++--------------
 1 file changed, 23 insertions(+), 14 deletions(-)

diff --git a/ebpf.py b/ebpf.py
index 0705b47..b904b87 100644
--- a/ebpf.py
+++ b/ebpf.py
@@ -271,12 +271,12 @@ class Expression:
     @contextmanager
     def calculate(self, dst, long, signed, force=False):
         with self.ebpf.get_free_register(dst) as dst:
-            with self.get_address(dst, long, signed):
-                self.ebpf.append(Opcode.LD, dst, dst, 0, 0)
-                yield dst, False, self.signed
+            with self.get_address(dst, long, signed) as (src, bits):
+                self.ebpf.append(Opcode.LD + bits, dst, src, 0, 0)
+                yield dst, long, self.signed
 
     @contextmanager
-    def get_address(self, dst, long, signed):
+    def get_address(self, dst, long, signed, force=False):
         with self.ebpf.get_stack(4 + 4 * long) as stack:
             with self.calculate(dst, long, signed) as (src, _, _):
                 self.ebpf.append(Opcode.STX + Opcode.DW * long,
@@ -451,14 +451,19 @@ class Memory(Expression):
     def calculate(self, dst, long, signed, force=False):
         if not long and self.bits == Opcode.DW:
             raise AssembleError("cannot compile")
-        with self.ebpf.get_free_register(dst) as dst:
-            if isinstance(self.address, Sum):
+        if isinstance(self.address, Sum):
+            with self.ebpf.get_free_register(dst) as dst:
                 self.ebpf.append(Opcode.LD + self.bits, dst,
                                  self.address.left.no, self.address.right, 0)
-            else:
-                with self.address.calculate(dst, None, None) as (src, _, _):
-                    self.ebpf.append(Opcode.LD + self.bits, dst, src, 0, 0)
-            yield dst, long, self.signed
+                yield dst, long, self.signed
+        else:
+            with super().calculate(dst, long, signed, force) as ret:
+                yield ret
+
+    @contextmanager
+    def get_address(self, dst, long, signed, force=False):
+        with self.address.calculate(dst, None, None) as (src, _, _):
+            yield src, self.bits
 
     def contains(self, no):
         return self.address.contains(no)
@@ -526,7 +531,7 @@ class HashGlobalVar(Expression):
         self.signed = signed
 
     @contextmanager
-    def get_address(self, dst, long, signed):
+    def get_address(self, dst, long, signed, force=False):
         if long:
             raise AssembleError("HashMap is only for words")
         if signed != self.signed:
@@ -539,8 +544,12 @@ class HashGlobalVar(Expression):
             self.ebpf.call(1)
             with self.ebpf.If(self.ebpf.r0 == 0):
                 self.ebpf.exit()
-            self.ebpf.append(Opcode.MOV + Opcode.LONG + Opcode.REG, dst, 0, 0, 0)
-        yield
+            print("bla", dst, force)
+            if dst != 0 and force:
+                self.ebpf.append(Opcode.MOV + Opcode.LONG + Opcode.REG, dst, 0, 0, 0)
+            else:
+                dst = 0
+        yield dst, Opcode.W
 
 
 class HashGlobalVarDesc:
@@ -572,7 +581,7 @@ class HashGlobalVarDesc:
                             pack("i" if self.signed else "I", value), 0)
             return
         with ebpf.save_registers([3]):
-            with value.get_address(3, False, self.signed):
+            with value.get_address(3, False, self.signed, True):
                 with ebpf.save_registers([0, 1, 2, 4, 5]), \
                         ebpf.get_stack(4) as stack:
                     ebpf.r1 = ebpf.get_fd(ebpf.__dict__[self.name].fd)
-- 
GitLab