From dcea6fa85ea5cacf5a5429fe84d2f24100f90545 Mon Sep 17 00:00:00 2001
From: Martin Teichmann <martin.teichmann@gmail.com>
Date: Fri, 10 Feb 2023 09:33:13 +0000
Subject: [PATCH] unload XDP programs after use

this involves restructuring quite a bit, probably with compatibility
problems.
---
 ebpfcat/bpf.py      |  3 +++
 ebpfcat/ebpfcat.py  | 44 +++++++++++++++++++++++++++++++++++---------
 ebpfcat/ethercat.py |  2 +-
 ebpfcat/xdp.py      | 31 ++++++++++++++++++++++++++-----
 4 files changed, 65 insertions(+), 15 deletions(-)

diff --git a/ebpfcat/bpf.py b/ebpfcat/bpf.py
index 2da7527..ee07c4d 100644
--- a/ebpfcat/bpf.py
+++ b/ebpfcat/bpf.py
@@ -123,6 +123,9 @@ def update_elem(fd, key, value, flags):
         addr = addrof(value)
     return bpf(2, "IQQQ", fd, addrof(key), addr, flags)[0]
 
+def delete_elem(fd, key):
+    return bpf(3, "IQ", fd, addrof(key))[0]
+
 def prog_load(prog_type, insns, license,
               log_level=0, log_size=4096, kern_version=0, flags=0,
               name="", ifindex=0, attach_type=0):
diff --git a/ebpfcat/ebpfcat.py b/ebpfcat/ebpfcat.py
index 1db24b1..a3a3c3d 100644
--- a/ebpfcat/ebpfcat.py
+++ b/ebpfcat/ebpfcat.py
@@ -17,6 +17,8 @@
 
 """The high-level API for EtherCAT loops"""
 from asyncio import ensure_future, gather, wait_for, TimeoutError
+from contextlib import asynccontextmanager, contextmanager
+import os
 from struct import pack, unpack, calcsize, pack_into, unpack_from
 from time import time
 from .arraymap import ArrayMap, ArrayGlobalVarDesc
@@ -24,7 +26,8 @@ from .ethercat import ECCmd, EtherCat, Packet, Terminal
 from .ebpf import FuncId, MemoryDesc, SubProgram, prandom
 from .xdp import XDP, XDPExitCode
 from .bpf import (
-    ProgType, MapType, create_map, update_elem, prog_test_run, lookup_elem)
+    ProgType, MapType, create_map, delete_elem, update_elem, prog_test_run,
+    lookup_elem)
 
 
 class PacketDesc:
@@ -311,14 +314,19 @@ class FastEtherCat(SimpleEtherCat):
         self.programs = create_map(MapType.PROG_ARRAY, 4, 4, self.MAX_PROGS)
         self.sync_groups = {}
 
-    def register_sync_group(self, sg, packet):
+    @contextmanager
+    def register_sync_group(self, sg):
         index = len(self.sync_groups)
         while index in self.sync_groups:
             index = (index + 1) % self.MAX_PROGS
         fd, _ = sg.load(log_level=1)
         update_elem(self.programs, pack("<I", index), pack("<I", fd), 0)
+        os.close(fd)
         self.sync_groups[index] = sg
-        return index
+        try:
+            yield index
+        finally:
+            delete_elem(self.programs, pack("<I", index))
 
     async def connect(self):
         await super().connect()
@@ -326,6 +334,18 @@ class FastEtherCat(SimpleEtherCat):
         self.ebpf.programs = self.programs
         self.fd = await self.ebpf.attach(self.addr[0])
 
+    @asynccontextmanager
+    async def run(self):
+        await super().connect()
+        self.ebpf = EtherXDP()
+        self.ebpf.programs = self.programs
+        async with self.ebpf.run(self.addr[0]):
+            try:
+                yield
+            finally:
+                for v in self.sync_groups.values():
+                    v.cancel()
+
 
 class SyncGroupBase:
     missed_counter = 0
@@ -405,22 +425,28 @@ class FastSyncGroup(SyncGroupBase, XDP):
         self.exit(XDPExitCode.TX)
 
     async def run(self):
-        self.ec.send_packet(self.asm_packet)
-        self.ec.send_packet(self.asm_packet)
-        await super().run()
+        with self.ec.register_sync_group(self) as self.packet_index:
+            self.asm_packet = self.packet.assemble(self.packet_index)
+            self.ec.send_packet(self.asm_packet)
+            self.ec.send_packet(self.asm_packet)
+            await super().run()
 
     def update_devices(self, data):
         if data[3] & 1:
             self.current_data = data
+        elif self.current_data is None:
+            return self.asm_packet
         for dev in self.devices:
             dev.fast_update()
         return self.asm_packet
 
     def start(self):
         self.allocate()
-        self.packet_index = self.ec.register_sync_group(self, self.packet)
-        self.asm_packet = self.packet.assemble(self.packet_index)
-        ensure_future(self.run())
+        self.task = ensure_future(self.run())
+        return self.task
+
+    def cancel(self):
+        self.task.cancel()
 
     def allocate(self):
         self.packet = Packet()
diff --git a/ebpfcat/ethercat.py b/ebpfcat/ethercat.py
index d626644..09f5c47 100644
--- a/ebpfcat/ethercat.py
+++ b/ebpfcat/ethercat.py
@@ -262,11 +262,11 @@ class EtherCat(Protocol):
         :param network: the name of the network adapter, like "eth0"
         """
         self.addr = (network, 0x88A4, 0, 0, b"\xff\xff\xff\xff\xff\xff")
-        self.send_queue = Queue()
         self.wait_futures = {}
 
     async def connect(self):
         """connect to the EtherCAT loop"""
+        self.send_queue = Queue()
         await get_event_loop().create_datagram_endpoint(
             lambda: self, family=AF_PACKET, proto=0xA488)
 
diff --git a/ebpfcat/xdp.py b/ebpfcat/xdp.py
index daa4c50..7d462c9 100644
--- a/ebpfcat/xdp.py
+++ b/ebpfcat/xdp.py
@@ -18,7 +18,8 @@
 """support for XDP programs"""
 from asyncio import DatagramProtocol, Future, get_event_loop
 from enum import Enum
-from contextlib import contextmanager
+from contextlib import asynccontextmanager, contextmanager
+import os
 from socket import AF_NETLINK, NETLINK_ROUTE, if_nametoindex
 import socket
 from struct import pack, unpack
@@ -49,6 +50,7 @@ class XDRFD(DatagramProtocol):
         sock.setsockopt(270, 11, 1)
         sock.bind((0, 0))
         self.transport = transport
+        # this was adopted from xdp1_user.c
         p = pack("IHHIIBxHiIiHHHHiHHI",
                 # NLmsghdr
                 52,  # length of if struct
@@ -147,13 +149,32 @@ class XDP(EBPF):
 
         self.packetSize = PacketSize(self)
 
-    async def attach(self, network):
-        """attach this program to a `network`"""
-        ifindex = if_nametoindex(network)
-        fd, _ = self.load(log_level=1)
+    async def _netlink(self, ifindex, fd):
         future = Future()
         transport, proto = await get_event_loop().create_datagram_endpoint(
                 lambda: XDRFD(ifindex, fd, future),
                 family=AF_NETLINK, proto=NETLINK_ROUTE)
         await future
         transport.get_extra_info("socket").close()
+
+    async def attach(self, network):
+        """attach this program to a `network`"""
+        ifindex = if_nametoindex(network)
+        fd, _ = self.load(log_level=1)
+        await self._netlink(ifindex, fd)
+
+    async def detach(self, network):
+        """attach this program from a `network`"""
+        ifindex = if_nametoindex(network)
+        await self._netlink(ifindex, -1)
+
+    @asynccontextmanager
+    async def run(self, network):
+        ifindex = if_nametoindex(network)
+        fd, _ = self.load(log_level=1)
+        await self._netlink(ifindex, fd)
+        os.close(fd)
+        try:
+            yield
+        finally:
+            await self._netlink(ifindex, -1)
-- 
GitLab