From 81a268e92e6dc323e794da6fe7b73d898aeeb5d1 Mon Sep 17 00:00:00 2001
From: Martin Teichmann <martin.teichmann@xfel.eu>
Date: Mon, 4 Jan 2021 21:46:35 +0000
Subject: [PATCH] factor out maps from ebpf.py

---
 arraymap.py  |  76 ++++++++++++++++++++++++
 ebpf.py      | 160 ---------------------------------------------------
 ebpf_test.py |   5 +-
 hashmap.py   |  93 ++++++++++++++++++++++++++++++
 4 files changed, 172 insertions(+), 162 deletions(-)
 create mode 100644 arraymap.py
 create mode 100644 hashmap.py

diff --git a/arraymap.py b/arraymap.py
new file mode 100644
index 0000000..a7cc881
--- /dev/null
+++ b/arraymap.py
@@ -0,0 +1,76 @@
+from struct import pack, unpack
+
+from .ebpf import Map, Memory, Opcode
+from . import bpf
+
+
+class ArrayGlobalVarDesc:
+    def __init__(self, map, position, size, signed):
+        self.map = map
+        self.position = position
+        self.signed = signed
+        self.size = size
+        self.fmt = {1: "B", 2: "H", 4: "I", 8: "Q"}[size]
+        if signed:
+            self.fmt = self.fmt.lower()
+
+    def __get__(self, ebpf, owner):
+        if ebpf is None:
+            return self
+        if ebpf.loaded:
+            data = ebpf.__dict__[self.map.name].data[
+                    self.position:self.position + self.size]
+            return unpack(self.fmt, data)[0]
+        return Memory(ebpf, Memory.bits_to_opcode[self.size * 8],
+                      ebpf.r0 + self.position, self.signed)
+
+    def __set_name__(self, owner, name):
+        self.name = name
+
+    def __set__(self, ebpf, value):
+        if ebpf.loaded:
+            ebpf.__dict__[self.map.name].data[
+                    self.position:self.position + self.size] = \
+                            pack(self.fmt, value)
+        else:
+            getattr(ebpf, f"m{self.size * 8}")[ebpf.r0 + self.position] = value
+
+
+class ArrayMapAccess:
+    def __init__(self, fd, size):
+        self.fd = fd
+        self.size = size
+
+    def read(self):
+        self.data = bpf.lookup_elem(self.fd, b"\0\0\0\0", self.size)
+
+    def write(self):
+        bpf.update_elem(self.fd, b"\0\0\0\0", self.data, 0)
+
+
+class ArrayMap(Map):
+    position = 0
+
+    def __init__(self):
+        self.vars = []
+
+    def globalVar(self, signed=False, size=4):
+        ret = ArrayGlobalVarDesc(self, self.position, size, signed)
+        self.position = (self.position + 2 * size - 1) & -size
+        self.vars.append(ret)
+        return ret
+
+    def __set_name__(self, owner, name):
+        self.name = name
+
+    def init(self, ebpf):
+        fd = bpf.create_map(2, 4, self.position, 1)
+        setattr(ebpf, self.name, ArrayMapAccess(fd, self.position))
+        with ebpf.save_registers(list(range(6))), ebpf.get_stack(4) as stack:
+            ebpf.append(Opcode.ST, 10, 0, stack, 0)
+            ebpf.r1 = ebpf.get_fd(fd)
+            ebpf.r2 = ebpf.r10 + stack
+            ebpf.call(1)
+            with ebpf.If(ebpf.r0 == 0):
+                ebpf.exit()
+        ebpf.owners.add(0)
diff --git a/ebpf.py b/ebpf.py
index ccd39ed..ca33709 100644
--- a/ebpf.py
+++ b/ebpf.py
@@ -524,73 +524,6 @@ class MemoryDesc:
         return Memory(self.ebpf, self.bits, addr)
 
 
-class HashGlobalVar(Expression):
-    def __init__(self, ebpf, count, signed):
-        self.ebpf = ebpf
-        self.count = count
-        self.signed = signed
-
-    @contextmanager
-    def get_address(self, dst, long, signed, force=False):
-        if long:
-            raise AssembleError("HashMap is only for words")
-        if signed != self.signed:
-            raise AssembleError("HashMap variable has wrong signedness")
-        with self.ebpf.save_registers([i for i in range(6) if i != dst]), \
-                self.ebpf.get_stack(4) as stack:
-            self.ebpf.append(Opcode.ST, 10, 0, stack, self.count)
-            self.ebpf.r1 = self.ebpf.get_fd(self.fd)
-            self.ebpf.r2 = self.ebpf.r10 + stack
-            self.ebpf.call(1)
-            with self.ebpf.If(self.ebpf.r0 == 0):
-                self.ebpf.exit()
-            print("bla", dst, force)
-            if dst != 0 and force:
-                self.ebpf.append(Opcode.MOV + Opcode.LONG + Opcode.REG, dst, 0, 0, 0)
-            else:
-                dst = 0
-        yield dst, Opcode.W
-
-
-class HashGlobalVarDesc:
-    def __init__(self, count, signed, default=0):
-        self.count = count
-        self.signed = signed
-        self.default = default
-
-    def __get__(self, instance, owner):
-        if instance is None:
-            return self
-        if instance.loaded:
-            fd = instance.__dict__[self.name].fd
-            ret = bpf.lookup_elem(fd, pack("B", self.count), 4)
-            return unpack("i" if self.signed else "I", ret)[0]
-        ret = instance.__dict__.get(self.name, None)
-        if ret is None:
-            ret = HashGlobalVar(instance, self.count, self.signed)
-            instance.__dict__[self.name] = ret
-        return ret
-
-    def __set_name__(self, owner, name):
-        self.name = name
-
-    def __set__(self, ebpf, value):
-        if ebpf.loaded:
-            fd = ebpf.__dict__[self.name].fd
-            bpf.update_elem(fd, pack("B", self.count),
-                            pack("i" if self.signed else "I", value), 0)
-            return
-        with ebpf.save_registers([3]):
-            with value.get_address(3, False, self.signed, True):
-                with ebpf.save_registers([0, 1, 2, 4, 5]), \
-                        ebpf.get_stack(4) as stack:
-                    ebpf.r1 = ebpf.get_fd(ebpf.__dict__[self.name].fd)
-                    ebpf.append(Opcode.ST, 10, 0, stack, self.count)
-                    ebpf.r2 = ebpf.r10 + stack
-                    ebpf.r4 = 0
-                    ebpf.call(2)
-
-
 class Map:
     def init(self, ebpf):
         pass
@@ -598,99 +531,6 @@ class Map:
     def load(self, ebpf):
         pass
 
-class HashMap(Map):
-    count = 0
-
-    def __init__(self):
-        self.vars = []
-
-    def globalVar(self, signed=False, default=0):
-        self.count += 1
-        ret = HashGlobalVarDesc(self.count, signed, default)
-        self.vars.append(ret)
-        return ret
-
-    def init(self, ebpf):
-        fd = bpf.create_map(1, 1, 4, self.count)
-        for v in self.vars:
-            getattr(ebpf, v.name).fd = fd
-
-    def load(self, ebpf):
-        for v in self.vars:
-            setattr(ebpf, v.name, ebpf.__class__.__dict__[v.name].default)
-
-
-class ArrayGlobalVarDesc:
-    def __init__(self, map, position, size, signed):
-        self.map = map
-        self.position = position
-        self.signed = signed
-        self.size = size
-        self.fmt = {1: "B", 2: "H", 4: "I", 8: "Q"}[size]
-        if signed:
-            self.fmt = self.fmt.lower()
-
-    def __get__(self, ebpf, owner):
-        if ebpf is None:
-            return self
-        if ebpf.loaded:
-            data = ebpf.__dict__[self.map.name].data[
-                    self.position:self.position + self.size]
-            return unpack(self.fmt, data)[0]
-        return Memory(ebpf, Memory.bits_to_opcode[self.size * 8],
-                      ebpf.r0 + self.position, self.signed)
-
-    def __set_name__(self, owner, name):
-        self.name = name
-
-    def __set__(self, ebpf, value):
-        if ebpf.loaded:
-            ebpf.__dict__[self.map.name].data[
-                    self.position:self.position + self.size] = \
-                            pack(self.fmt, value)
-        else:
-            getattr(ebpf, f"m{self.size * 8}")[ebpf.r0 + self.position] = value
-
-
-class ArrayMapAccess:
-    def __init__(self, fd, size):
-        self.fd = fd
-        self.size = size
-
-    def read(self):
-        self.data = bpf.lookup_elem(self.fd, b"\0\0\0\0", self.size)
-
-    def write(self):
-        bpf.update_elem(self.fd, b"\0\0\0\0", self.data, 0)
-
-
-class ArrayMap(Map):
-    position = 0
-
-    def __init__(self):
-        self.vars = []
-
-    def globalVar(self, signed=False, size=4):
-        ret = ArrayGlobalVarDesc(self, self.position, size, signed)
-        self.position = (self.position + 2 * size - 1) & -size
-        self.vars.append(ret)
-        return ret
-
-    def __set_name__(self, owner, name):
-        self.name = name
-
-    def init(self, ebpf):
-        fd = bpf.create_map(2, 4, self.position, 1)
-        setattr(ebpf, self.name, ArrayMapAccess(fd, self.position))
-        with ebpf.save_registers(list(range(6))), ebpf.get_stack(4) as stack:
-            ebpf.append(Opcode.ST, 10, 0, stack, 0)
-            ebpf.r1 = ebpf.get_fd(fd)
-            ebpf.r2 = ebpf.r10 + stack
-            ebpf.call(1)
-            with ebpf.If(ebpf.r0 == 0):
-                ebpf.exit()
-        ebpf.owners.add(0)
-
 
 class PseudoFd(Expression):
     def __init__(self, ebpf, fd):
diff --git a/ebpf_test.py b/ebpf_test.py
index 90b4ef6..cede38f 100644
--- a/ebpf_test.py
+++ b/ebpf_test.py
@@ -1,9 +1,10 @@
 from unittest import TestCase, main
 
 from . import ebpf
+from .arraymap import ArrayMap
 from .ebpf import (
-    ArrayMap, AssembleError, EBPF, HashMap, Opcode, OpcodeFlags,
-    Opcode as O, LocalVar)
+    AssembleError, EBPF, Opcode, OpcodeFlags, Opcode as O, LocalVar)
+from .hashmap import HashMap
 from .xdp import XDP
 from .bpf import ProgType, prog_test_run
 
diff --git a/hashmap.py b/hashmap.py
new file mode 100644
index 0000000..2631290
--- /dev/null
+++ b/hashmap.py
@@ -0,0 +1,93 @@
+from contextlib import contextmanager
+from struct import pack, unpack, unpack
+
+from .ebpf import AssembleError, Expression, Opcode, Map
+from . import bpf
+
+
+class HashGlobalVar(Expression):
+    def __init__(self, ebpf, count, signed):
+        self.ebpf = ebpf
+        self.count = count
+        self.signed = signed
+
+    @contextmanager
+    def get_address(self, dst, long, signed, force=False):
+        if long:
+            raise AssembleError("HashMap is only for words")
+        if signed != self.signed:
+            raise AssembleError("HashMap variable has wrong signedness")
+        with self.ebpf.save_registers([i for i in range(6) if i != dst]), \
+                self.ebpf.get_stack(4) as stack:
+            self.ebpf.append(Opcode.ST, 10, 0, stack, self.count)
+            self.ebpf.r1 = self.ebpf.get_fd(self.fd)
+            self.ebpf.r2 = self.ebpf.r10 + stack
+            self.ebpf.call(1)
+            with self.ebpf.If(self.ebpf.r0 == 0):
+                self.ebpf.exit()
+            if dst != 0 and force:
+                self.ebpf.append(Opcode.MOV + Opcode.LONG + Opcode.REG, dst, 0, 0, 0)
+            else:
+                dst = 0
+        yield dst, Opcode.W
+
+
+class HashGlobalVarDesc:
+    def __init__(self, count, signed, default=0):
+        self.count = count
+        self.signed = signed
+        self.default = default
+
+    def __get__(self, instance, owner):
+        if instance is None:
+            return self
+        if instance.loaded:
+            fd = instance.__dict__[self.name].fd
+            ret = bpf.lookup_elem(fd, pack("B", self.count), 4)
+            return unpack("i" if self.signed else "I", ret)[0]
+        ret = instance.__dict__.get(self.name, None)
+        if ret is None:
+            ret = HashGlobalVar(instance, self.count, self.signed)
+            instance.__dict__[self.name] = ret
+        return ret
+
+    def __set_name__(self, owner, name):
+        self.name = name
+
+    def __set__(self, ebpf, value):
+        if ebpf.loaded:
+            fd = ebpf.__dict__[self.name].fd
+            bpf.update_elem(fd, pack("B", self.count),
+                            pack("i" if self.signed else "I", value), 0)
+            return
+        with ebpf.save_registers([3]):
+            with value.get_address(3, False, self.signed, True):
+                with ebpf.save_registers([0, 1, 2, 4, 5]), \
+                        ebpf.get_stack(4) as stack:
+                    ebpf.r1 = ebpf.get_fd(ebpf.__dict__[self.name].fd)
+                    ebpf.append(Opcode.ST, 10, 0, stack, self.count)
+                    ebpf.r2 = ebpf.r10 + stack
+                    ebpf.r4 = 0
+                    ebpf.call(2)
+
+
+class HashMap(Map):
+    count = 0
+
+    def __init__(self):
+        self.vars = []
+
+    def globalVar(self, signed=False, default=0):
+        self.count += 1
+        ret = HashGlobalVarDesc(self.count, signed, default)
+        self.vars.append(ret)
+        return ret
+
+    def init(self, ebpf):
+        fd = bpf.create_map(1, 1, 4, self.count)
+        for v in self.vars:
+            getattr(ebpf, v.name).fd = fd
+
+    def load(self, ebpf):
+        for v in self.vars:
+            setattr(ebpf, v.name, ebpf.__class__.__dict__[v.name].default)
-- 
GitLab