From d3df05abaa8f145658873a61174a3b4454c7e00c Mon Sep 17 00:00:00 2001
From: Martin Teichmann <martin.teichmann@xfel.eu>
Date: Mon, 18 Jan 2021 16:45:49 +0000
Subject: [PATCH] make local variables work for subprograms

---
 ebpf.py      | 34 ++++++++++++++++++++++------------
 ebpf_test.py | 24 ++++++++++++++++++++++++
 xdp.py       |  2 +-
 3 files changed, 47 insertions(+), 13 deletions(-)

diff --git a/ebpf.py b/ebpf.py
index f19904b..19c900c 100644
--- a/ebpf.py
+++ b/ebpf.py
@@ -633,27 +633,31 @@ class MemoryDesc:
         self.bits = bits
         self.signed = signed
 
-    def __get__(self, ebpf, owner):
-        if ebpf is None:
+    def __get__(self, instance, owner):
+        if instance is None:
             return self
-        elif isinstance(ebpf, SubProgram):
-            ebpf = ebpf.ebpf
+        elif isinstance(instance, SubProgram):
+            ebpf = instance.ebpf
+        else:
+            ebpf = instance
         return Memory(ebpf, Memory.bits_to_opcode[self.bits],
-                      ebpf.r[self.base_register] + self.addr,
+                      ebpf.r[self.base_register] + self.addr(instance),
                       self.signed)
 
-    def __set__(self, ebpf, value):
-        if isinstance(ebpf, SubProgram):
-            ebpf = ebpf.ebpf
+    def __set__(self, instance, value):
+        if isinstance(instance, SubProgram):
+            ebpf = instance.ebpf
+        else:
+            ebpf = instance
         bits = Memory.bits_to_opcode[self.bits]
         if isinstance(value, int):
             ebpf.append(Opcode.ST + bits, self.base_register, 0,
-                        self.addr, value)
+                        self.addr(instance), value)
         else:
             with value.calculate(None, self.bits == 64, self.signed) \
                     as (src, _, _):
                 ebpf.append(Opcode.STX + bits, self.base_register,
-                            src, self.addr, 0)
+                            src, self.addr(instance), 0)
 
 
 class LocalVar(MemoryDesc):
@@ -663,9 +667,15 @@ class LocalVar(MemoryDesc):
         size = int(self.bits // 8)
         owner.stack -= size
         owner.stack &= -size
-        self.addr = owner.stack
+        self.relative_addr = owner.stack
         self.name = name
 
+    def addr(self, instance):
+        if isinstance(instance, SubProgram):
+            return (instance.ebpf.stack & -8) + self.relative_addr
+        else:
+            return self.relative_addr
+
 
 class MemoryMap:
     def __init__(self, ebpf, bits):
@@ -945,4 +955,4 @@ for i in range(10):
 
 
 class SubProgram:
-    pass
+    stack = 0
diff --git a/ebpf_test.py b/ebpf_test.py
index 229588b..b2ac19e 100644
--- a/ebpf_test.py
+++ b/ebpf_test.py
@@ -157,6 +157,30 @@ class Tests(TestCase):
             Instruction(opcode=O.REG+O.STX, dst=10, src=0, off=-4, imm=0),
             Instruction(opcode=O.DW+O.STX, dst=10, src=1, off=-16, imm=0)])
 
+    def test_local_subprog(self):
+        class Local(EBPF):
+            a = LocalVar(32, False)
+
+        class Sub(SubProgram):
+            b = LocalVar(32, False)
+
+            def program(self):
+                self.b *= 3
+
+        s1 = Sub()
+        s2 = Sub()
+        e = Local(ProgType.XDP, "GPL", subprograms=[s1, s2])
+        e.a = 5
+        s1.b = 3
+        e.r3 = s1.b
+        s2.b = 7
+        self.assertEqual(e.opcodes, [
+            Instruction(opcode=O.W+O.ST, dst=10, src=0, off=-4, imm=5),
+            Instruction(opcode=O.W+O.ST, dst=10, src=0, off=-12, imm=3),
+            Instruction(opcode=O.W+O.LD, dst=3, src=10, off=-12, imm=0),
+            Instruction(opcode=O.W+O.ST, dst=10, src=0, off=-12, imm=7)])
+
+
     def test_jump(self):
         e = EBPF()
         e.owners = set(range(11))
diff --git a/xdp.py b/xdp.py
index efc0998..0a9753a 100644
--- a/xdp.py
+++ b/xdp.py
@@ -4,7 +4,7 @@ from socket import AF_NETLINK, NETLINK_ROUTE, if_nametoindex
 import socket
 from struct import pack, unpack
 
-from .ebpf import EBPF, Expression, Memory, MemoryDesc, Opcode, Comparison
+from .ebpf import EBPF, Expression, Memory, Opcode, Comparison
 from .bpf import ProgType
 
 
-- 
GitLab