From 6eb85bb1ec595302668635d4c7d4bb6bebbe8d63 Mon Sep 17 00:00:00 2001
From: Martin Teichmann <martin.teichmann@xfel.eu>
Date: Mon, 18 Jan 2021 16:20:56 +0000
Subject: [PATCH] support SubPrograms

they can already have ArrayMaps, but not much more
---
 arraymap.py  | 48 +++++++++++++++++++++++++----------------
 ebpf.py      | 61 ++++++++++++++++++++++++++++++++++------------------
 ebpf_test.py | 21 +++++++++++++++---
 3 files changed, 87 insertions(+), 43 deletions(-)

diff --git a/arraymap.py b/arraymap.py
index 56f3536..633b0cb 100644
--- a/arraymap.py
+++ b/arraymap.py
@@ -1,13 +1,12 @@
 from struct import pack, unpack
 
-from .ebpf import FuncId, Map, Memory, Opcode
+from .ebpf import FuncId, Map, Memory, Opcode, SubProgram
 from .bpf import create_map, lookup_elem, MapType, update_elem
 
 
 class ArrayGlobalVarDesc:
-    def __init__(self, map, position, size, signed):
+    def __init__(self, map, size, signed):
         self.map = map
-        self.position = position
         self.signed = signed
         self.size = size
         self.fmt = {1: "B", 2: "H", 4: "I", 8: "Q"}[size]
@@ -17,23 +16,28 @@ class ArrayGlobalVarDesc:
     def __get__(self, ebpf, owner):
         if ebpf is None:
             return self
+        position = ebpf.__dict__[self.name]
+        if isinstance(ebpf, SubProgram):
+            ebpf = ebpf.ebpf
         if ebpf.loaded:
             data = ebpf.__dict__[self.map.name].data[
-                    self.position:self.position + self.size]
+                    position : position+self.size]
             return unpack(self.fmt, data)[0]
         return Memory(ebpf, Memory.bits_to_opcode[self.size * 8],
-                      ebpf.r0 + self.position, self.signed)
+                      ebpf.r0 + position, self.signed)
 
     def __set_name__(self, owner, name):
         self.name = name
 
     def __set__(self, ebpf, value):
+        position = ebpf.__dict__[self.name]
+        if isinstance(ebpf, SubProgram):
+            ebpf = ebpf.ebpf
         if ebpf.loaded:
             ebpf.__dict__[self.map.name].data[
-                    self.position:self.position + self.size] = \
-                            pack(self.fmt, value)
+                    position : position + self.size] = pack(self.fmt, value)
         else:
-            getattr(ebpf, f"m{self.size * 8}")[ebpf.r0 + self.position] = value
+            getattr(ebpf, f"m{self.size * 8}")[ebpf.r0 + position] = value
 
 
 class ArrayMapAccess:
@@ -49,23 +53,29 @@ class ArrayMapAccess:
 
 
 class ArrayMap(Map):
-    position = 0
-
-    def __init__(self):
-        self.vars = []
-
     def globalVar(self, signed=False, size=4):
-        ret = ArrayGlobalVarDesc(self, self.position, size, signed)
-        self.position = (self.position + 2 * size - 1) & -size
-        self.vars.append(ret)
-        return ret
+        return ArrayGlobalVarDesc(self, size, signed)
+
+    def add_program(self, owner, prog):
+        position = getattr(owner, self.name)
+        for k, v in prog.__class__.__dict__.items():
+            if not isinstance(v, ArrayGlobalVarDesc):
+                continue
+            prog.__dict__[k] = position
+            position = (position + 2 * v.size - 1) & -v.size
+        setattr(owner, self.name, position)
 
     def __set_name__(self, owner, name):
         self.name = name
 
     def init(self, ebpf):
-        fd = create_map(MapType.ARRAY, 4, self.position, 1)
-        setattr(ebpf, self.name, ArrayMapAccess(fd, self.position))
+        setattr(ebpf, self.name, 0)
+        self.add_program(ebpf, ebpf)
+        for prog in ebpf.subprograms:
+            self.add_program(ebpf, prog)
+        size = getattr(ebpf, self.name)
+        fd = create_map(MapType.ARRAY, 4, size, 1)
+        setattr(ebpf, self.name, ArrayMapAccess(fd, size))
         with ebpf.save_registers(list(range(6))), ebpf.get_stack(4) as stack:
             ebpf.append(Opcode.ST, 10, 0, stack, 0)
             ebpf.r1 = ebpf.get_fd(fd)
diff --git a/ebpf.py b/ebpf.py
index 9049ee4..f19904b 100644
--- a/ebpf.py
+++ b/ebpf.py
@@ -628,11 +628,37 @@ class Memory(Expression):
         return self.address.contains(no)
 
 
-class LocalVar:
+class MemoryDesc:
     def __init__(self, bits=32, signed=False):
         self.bits = bits
         self.signed = signed
 
+    def __get__(self, ebpf, owner):
+        if ebpf is None:
+            return self
+        elif isinstance(ebpf, SubProgram):
+            ebpf = ebpf.ebpf
+        return Memory(ebpf, Memory.bits_to_opcode[self.bits],
+                      ebpf.r[self.base_register] + self.addr,
+                      self.signed)
+
+    def __set__(self, ebpf, value):
+        if isinstance(ebpf, SubProgram):
+            ebpf = ebpf.ebpf
+        bits = Memory.bits_to_opcode[self.bits]
+        if isinstance(value, int):
+            ebpf.append(Opcode.ST + bits, self.base_register, 0,
+                        self.addr, value)
+        else:
+            with value.calculate(None, self.bits == 64, self.signed) \
+                    as (src, _, _):
+                ebpf.append(Opcode.STX + bits, self.base_register,
+                            src, self.addr, 0)
+
+
+class LocalVar(MemoryDesc):
+    base_register = 10
+
     def __set_name__(self, owner, name):
         size = int(self.bits // 8)
         owner.stack -= size
@@ -640,23 +666,8 @@ class LocalVar:
         self.addr = owner.stack
         self.name = name
 
-    def __get__(self, instance, owner):
-        if instance is None:
-            return self
-        else:
-            return Memory(instance, Memory.bits_to_opcode[self.bits],
-                          instance.r10 + self.addr, self.signed)
 
-    def __set__(self, instance, value):
-        bits = Memory.bits_to_opcode[self.bits]
-        if isinstance(value, int):
-            instance.append(Opcode.ST + bits, 10, 0, self.addr, value)
-        else:
-            with value.calculate(None, self.bits == 64, self.signed) \
-                    as (src, _, _):
-                instance.append(Opcode.STX + bits, 10, src, self.addr, 0)
-
-class MemoryDesc:
+class MemoryMap:
     def __init__(self, ebpf, bits):
         self.ebpf = ebpf
         self.bits = bits
@@ -794,10 +805,10 @@ class EBPF:
             self.name = name
         self.loaded = False
 
-        self.m8 = MemoryDesc(self, Opcode.B)
-        self.m16 = MemoryDesc(self, Opcode.H)
-        self.m32 = MemoryDesc(self, Opcode.W)
-        self.m64 = MemoryDesc(self, Opcode.DW)
+        self.m8 = MemoryMap(self, Opcode.B)
+        self.m16 = MemoryMap(self, Opcode.H)
+        self.m32 = MemoryMap(self, Opcode.W)
+        self.m64 = MemoryMap(self, Opcode.DW)
 
         self.r = RegisterArray(self, True, False)
         self.sr = RegisterArray(self, True, True)
@@ -806,6 +817,10 @@ class EBPF:
 
         self.owners = {1, 10}
 
+        self.subprograms = subprograms
+        for p in subprograms:
+            p.ebpf = self
+
         for v in self.__class__.__dict__.values():
             if isinstance(v, Map):
                 v.init(self)
@@ -927,3 +942,7 @@ for i in range(10):
 
 for i in range(10):
     setattr(EBPF, f"sw{i}", RegisterDesc(i, "sw"))
+
+
+class SubProgram:
+    pass
diff --git a/ebpf_test.py b/ebpf_test.py
index 50c8c1d..229588b 100644
--- a/ebpf_test.py
+++ b/ebpf_test.py
@@ -3,7 +3,8 @@ from unittest import TestCase, main
 from . import ebpf
 from .arraymap import ArrayMap
 from .ebpf import (
-    AssembleError, EBPF, FuncId, Opcode, OpcodeFlags, Opcode as O, LocalVar)
+    AssembleError, EBPF, FuncId, Opcode, OpcodeFlags, Opcode as O, LocalVar,
+    SubProgram)
 from .hashmap import HashMap
 from .xdp import XDP
 from .bpf import ProgType, prog_test_run
@@ -534,18 +535,32 @@ class KernelTests(TestCase):
             map = ArrayMap()
             a = map.globalVar()
 
-        e = Global(ProgType.XDP, "GPL")
+        class Sub(SubProgram):
+            b = Global.map.globalVar()
+
+            def program(self):
+                self.b -= -33
+
+        s1 = Sub()
+        s2 = Sub()
+        e = Global(ProgType.XDP, "GPL", subprograms=[s1, s2])
         e.a += 7
+        s1.program()
+        s2.program()
         e.exit()
 
-        fd = e.load()
+        fd, _ = e.load(log_level=1)
         prog_test_run(fd, 1000, 1000, 0, 0, 1)
         e.map.read()
         e.a *= 2
+        s1.b = 3
+        s2.b *= 5
         e.map.write()
         prog_test_run(fd, 1000, 1000, 0, 0, 1)
         e.map.read()
         self.assertEqual(e.a, 21)
+        self.assertEqual(s1.b, 36)
+        self.assertEqual(s2.b, 5 * 33 + 33)
 
     def test_minimal(self):
         class Global(XDP):
-- 
GitLab