From 04df2a4b1c14d334badc17e7788434cd6354ea98 Mon Sep 17 00:00:00 2001
From: Martin Teichmann <martin.teichmann@xfel.eu>
Date: Tue, 29 Dec 2020 13:45:18 +0000
Subject: [PATCH] make the array map work

---
 bpf.py       |  7 +++--
 ebpf.py      | 80 +++++++++++++++++++++++++++++++++++++++++++++++++---
 ebpf_test.py | 21 +++++++++++++-
 3 files changed, 101 insertions(+), 7 deletions(-)

diff --git a/bpf.py b/bpf.py
index f7057e4..fb4eea4 100644
--- a/bpf.py
+++ b/bpf.py
@@ -50,7 +50,7 @@ def lookup_elem(fd, key, size):
     value = create_string_buffer(size)
     ret, _ = bpf(1, "IQQQ", fd, addrof(key), addrof(value), 0)
     if ret == 0:
-        return value.raw
+        return value
     else:
         return None
 
@@ -100,4 +100,7 @@ def prog_test_run(fd, data_in, data_out, ctx_in, ctx_out,
 if __name__ == "__main__":
     fd = create_map(1, 4, 4, 10)
     update_elem(fd, b"asdf", b"ckde", 0)
-    print(lookup_elem(fd, b"asdf", 4))
+    ret = lookup_elem(fd, b"asdf", 4)
+    ret[2:4] = b"kk"
+    update_elem(fd, b"asdf", ret, 0)
+    print(lookup_elem(fd, b"asdf", 4).raw)
diff --git a/ebpf.py b/ebpf.py
index 9238663..3f4e9cf 100644
--- a/ebpf.py
+++ b/ebpf.py
@@ -422,6 +422,8 @@ class Register(Expression):
 
 
 class Memory(Expression):
+    bits_to_opcode = {32: Opcode.W, 16: Opcode.H, 8: Opcode.B, 64: Opcode.DW}
+
     def __init__(self, ebpf, bits, address, signed=False):
         self.ebpf = ebpf
         self.bits = bits
@@ -446,8 +448,6 @@ class Memory(Expression):
 
 
 class LocalVar:
-    bits_to_opcode = {32: Opcode.W, 16: Opcode.H, 8: Opcode.B, 64: Opcode.DW}
-
     def __init__(self, bits=32, signed=False):
         self.bits = bits
         self.signed = signed
@@ -463,11 +463,11 @@ class LocalVar:
         if instance is None:
             return self
         else:
-            return Memory(instance, self.bits_to_opcode[self.bits],
+            return Memory(instance, Memory.bits_to_opcode[self.bits],
                           instance.r10 + self.addr, self.signed)
 
     def __set__(self, instance, value):
-        bits = self.bits_to_opcode[self.bits]
+        bits = Memory.bits_to_opcode[self.bits]
         if isinstance(value, int):
             instance.append(Opcode.ST + bits, 10, 0, self.addr, value)
         else:
@@ -591,6 +591,78 @@ class HashMap(Map):
             setattr(ebpf, v.name, ebpf.__class__.__dict__[v.name].default)
 
 
+class ArrayGlobalVarDesc:
+    def __init__(self, map, position, size, signed):
+        self.map = map
+        self.position = position
+        self.signed = signed
+        self.size = size
+        self.fmt = {1: "B", 2: "H", 4: "I", 8: "Q"}[size]
+        if signed:
+            self.fmt = self.fmt.lower()
+
+    def __get__(self, ebpf, owner):
+        if ebpf is None:
+            return self
+        if ebpf.loaded:
+            data = ebpf.__dict__[self.map.name].data[
+                    self.position:self.position + self.size]
+            return unpack(self.fmt, data)[0]
+        return Memory(ebpf, Memory.bits_to_opcode[self.size * 8],
+                      ebpf.r0 + self.position, self.signed)
+
+    def __set_name__(self, owner, name):
+        self.name = name
+
+    def __set__(self, ebpf, value):
+        if ebpf.loaded:
+            ebpf.__dict__[self.map.name].data[
+                    self.position:self.position + self.size] = \
+                            pack(self.fmt, value)
+        else:
+            getattr(ebpf, f"m{self.size * 8}")[ebpf.r0 + self.position] = value
+
+
+class ArrayMapAccess:
+    def __init__(self, fd, size):
+        self.fd = fd
+        self.size = size
+
+    def read(self):
+        self.data = bpf.lookup_elem(self.fd, b"\0\0\0\0", self.size)
+
+    def write(self):
+        bpf.update_elem(self.fd, b"\0\0\0\0", self.data, 0)
+
+
+class ArrayMap(Map):
+    position = 0
+
+    def __init__(self):
+        self.vars = []
+
+    def globalVar(self, signed=False, size=4):
+        ret = ArrayGlobalVarDesc(self, self.position, size, signed)
+        self.position = (self.position + 2 * size - 1) & -size
+        self.vars.append(ret)
+        return ret
+
+    def __set_name__(self, owner, name):
+        self.name = name
+
+    def init(self, ebpf):
+        fd = bpf.create_map(2, 4, self.position, 1)
+        setattr(ebpf, self.name, ArrayMapAccess(fd, self.position))
+        with ebpf.save_registers(None), ebpf.get_stack(4) as stack:
+            ebpf.append(Opcode.ST, 10, 0, stack, 0)
+            ebpf.r1 = ebpf.get_fd(fd)
+            ebpf.r2 = ebpf.r10 + stack
+            ebpf.call(1)
+            with ebpf.If(ebpf.r0 == 0):
+                ebpf.exit()
+        ebpf.owners.add(0)
+
+
 class PseudoFd(Expression):
     def __init__(self, ebpf, fd):
         self.ebpf = ebpf
diff --git a/ebpf_test.py b/ebpf_test.py
index 5484ac3..471f783 100644
--- a/ebpf_test.py
+++ b/ebpf_test.py
@@ -2,7 +2,8 @@ from unittest import TestCase, main
 
 from . import ebpf
 from .ebpf import (
-    AssembleError, EBPF, HashMap, Opcode, OpcodeFlags, Opcode as O, LocalVar)
+    ArrayMap, AssembleError, EBPF, HashMap, Opcode, OpcodeFlags,
+    Opcode as O, LocalVar)
 from .bpf import ProgType, prog_test_run
 
 
@@ -487,6 +488,24 @@ class KernelTests(TestCase):
         prog_test_run(fd, 1000, 1000, 0, 0, 1)
         self.assertEqual(e.a, 31)
 
+    def test_arraymap(self):
+        class Global(EBPF):
+            map = ArrayMap()
+            a = map.globalVar()
+
+        e = Global(ProgType.XDP, "GPL")
+        e.a += 7
+        e.exit()
+
+        fd = e.load()
+        prog_test_run(fd, 1000, 1000, 0, 0, 1)
+        e.map.read()
+        e.a *= 2
+        e.map.write()
+        prog_test_run(fd, 1000, 1000, 0, 0, 1)
+        e.map.read()
+        self.assertEqual(e.a, 21)
+
     def test_minimal(self):
         class Global(EBPF):
             map = HashMap()
-- 
GitLab