From caf4073de59f1940c3e142a1f2015acbdd56c9b2 Mon Sep 17 00:00:00 2001
From: Martin Teichmann <martin.teichmann@xfel.eu>
Date: Tue, 29 Dec 2020 19:24:30 +0000
Subject: [PATCH] factor out XDP

---
 ebpf.py      |  5 +++++
 ebpf_test.py |  7 ++++---
 xdp.py       | 45 +++++++++++++++++++++++++++++++--------------
 3 files changed, 40 insertions(+), 17 deletions(-)

diff --git a/ebpf.py b/ebpf.py
index b904b87..3e7523d 100644
--- a/ebpf.py
+++ b/ebpf.py
@@ -751,6 +751,11 @@ class EBPF:
             if isinstance(v, Map):
                 v.init(self)
 
+        self.program()
+
+    def program(self):
+        pass
+
     def append(self, opcode, dst, src, off, imm):
         self.opcodes.append(Instruction(opcode, dst, src, off, imm))
 
diff --git a/ebpf_test.py b/ebpf_test.py
index bdd085d..9e078e3 100644
--- a/ebpf_test.py
+++ b/ebpf_test.py
@@ -3,7 +3,8 @@ from unittest import TestCase, main
 from . import ebpf
 from .ebpf import (
     ArrayMap, AssembleError, EBPF, HashMap, Opcode, OpcodeFlags,
-    Opcode as O, LocalVar, XDP)
+    Opcode as O, LocalVar)
+from .xdp import XDP
 from .bpf import ProgType, prog_test_run
 
 
@@ -510,11 +511,11 @@ class KernelTests(TestCase):
         self.assertEqual(e.a, 21)
 
     def test_minimal(self):
-        class Global(EBPF):
+        class Global(XDP):
             map = HashMap()
             a = map.globalVar()
 
-        e = Global(ProgType.XDP, "GPL")
+        e = Global(license="GPL")
         e.a += 1
         e.exit()
         print(e.opcodes)
diff --git a/xdp.py b/xdp.py
index 4ad8384..b9e0145 100644
--- a/xdp.py
+++ b/xdp.py
@@ -3,14 +3,9 @@ from socket import AF_NETLINK, NETLINK_ROUTE, if_nametoindex
 import socket
 from struct import pack, unpack
 
-async def set_link_xdp_fd(network, fd):
-    ifindex = if_nametoindex(network)
-    future = Future()
-    transport, proto = await get_event_loop().create_datagram_endpoint(
-            lambda: XDRFD(ifindex, fd, future),
-            family=AF_NETLINK, proto=NETLINK_ROUTE)
-    await future
-    transport.get_extra_info("socket").close()
+from .ebpf import EBPF, Memory, MemoryDesc, Opcode
+from .bpf import ProgType
+
 
 class XDRFD(DatagramProtocol):
     def __init__(self, ifindex, fd, future):
@@ -47,25 +42,47 @@ class XDRFD(DatagramProtocol):
                 8,
                 3,  # IFLA_XDP_FLAGS,
                 2)
-        print("send", len(p), p)
         transport.sendto(p, (0, 0))
 
     def datagram_received(self, data, addr):
         pos = 0
-        print("received", data)
         while (pos < len(data)):
             ln, type, flags, seq, pid = unpack("IHHII", data[pos : pos+16])
-            print(f"  {ln} {type} {flags:x} {seq} {pid}")
             if type == 3:  # DONE
                 self.future.set_result(0)
             elif type == 2:  # ERROR
                 errno, *args = unpack("iIHHII", data[pos+16 : pos+36])
-                print("ERROR", errno, args)
                 if errno != 0:
                     self.future.set_result(errno)
             if flags & 2 == 0:  # not a multipart message
-                print("not multipart")
                 self.future.set_result(0)
             pos += ln
                 
-        
+
+class PacketDesc(MemoryDesc):
+    def __setitem__(self, addr, value):
+        super().__setitem__(self.ebpf.r9 + addr, value)
+
+    def __getitem__(self, addr):
+        return Memory(self.ebpf, self.bits, self.ebpf.r9 + addr)
+
+
+class XDP(EBPF):
+    def __init__(self, **kwargs):
+        super().__init__(prog_type=ProgType.XDP, **kwargs)
+        self.r9 = self.m32[self.r1]
+
+        self.packet8 = MemoryDesc(self, Opcode.B)
+        self.packet16 = MemoryDesc(self, Opcode.H)
+        self.packet32 = MemoryDesc(self, Opcode.W)
+        self.packet64 = MemoryDesc(self, Opcode.DW)
+
+    async def attach(self, network):
+        ifindex = if_nametoindex(network)
+        fd = self.load()
+        future = Future()
+        transport, proto = await get_event_loop().create_datagram_endpoint(
+                lambda: XDRFD(ifindex, fd, future),
+                family=AF_NETLINK, proto=NETLINK_ROUTE)
+        await future
+        transport.get_extra_info("socket").close()
-- 
GitLab