From 3f2b353e1ddce9fded999292e0ca74de4f980c94 Mon Sep 17 00:00:00 2001
From: Martin Teichmann <martin.teichmann@xfel.eu>
Date: Thu, 17 Oct 2024 11:27:22 +0000
Subject: [PATCH] do not return file descriptor for load

we returned two different things whether we were keeping logs
upon loading or not. This makes it hard to program around.

Now only return logs, and store the file descriptor internally.
---
 ebpfcat/bpf.py       |  2 +-
 ebpfcat/ebpf.py      | 17 +++++++++++++----
 ebpfcat/ebpf_test.py | 14 +++++++-------
 ebpfcat/ebpfcat.py   |  9 +++++----
 ebpfcat/xdp.py       | 10 +++++-----
 5 files changed, 31 insertions(+), 21 deletions(-)

diff --git a/ebpfcat/bpf.py b/ebpfcat/bpf.py
index c251455..14bab08 100644
--- a/ebpfcat/bpf.py
+++ b/ebpfcat/bpf.py
@@ -149,7 +149,7 @@ def prog_load(prog_type, insns, license,
     if log_level != 0:
         return fd, the_logbuf.value.decode("utf8")
     else:
-        return fd
+        return fd, None
 
 def obj_pin(pathname, fd):
     pn = pathname.encode("utf8")
diff --git a/ebpfcat/ebpf.py b/ebpfcat/ebpf.py
index 25b9019..87ed97f 100644
--- a/ebpfcat/ebpf.py
+++ b/ebpfcat/ebpf.py
@@ -19,6 +19,7 @@
 
 __all__ = ["EBPF", "LocalVar", "prandom", "ktime"]
 
+import os
 from abc import ABC, abstractmethod
 from collections import namedtuple
 from contextlib import contextmanager, ExitStack
@@ -1318,16 +1319,24 @@ class EBPF:
 
     def load(self, log_level=0, log_size=10 * 4096):
         """load the program into the kernel"""
-        ret = bpf.prog_load(self.prog_type, self.assemble(), self.license,
-                            log_level, log_size, self.kern_version,
-                            name=self.name)
+        fd, log = bpf.prog_load(self.prog_type, self.assemble(), self.license,
+                                log_level, log_size, self.kern_version,
+                                name=self.name)
         self.loaded = True
+        self.file_descriptor = fd
 
         for v in self.__class__.__dict__.values():
             if isinstance(v, Map):
                 v.load(self)
 
-        return ret
+        return log
+
+    def close(self):
+        os.close(self.file_descriptor)
+        self.file_descriptor = None
+
+    def test_run(self, *args, **kwargs):
+        return bpf.prog_test_run(self.file_descriptor, *args, **kwargs)
 
     def jumpIf(self, comp):
         """jump if `comp` is true to a later defined `target`"""
diff --git a/ebpfcat/ebpf_test.py b/ebpfcat/ebpf_test.py
index bc8823d..139e192 100644
--- a/ebpfcat/ebpf_test.py
+++ b/ebpfcat/ebpf_test.py
@@ -24,7 +24,7 @@ from .ebpf import (
     SubProgram, ktime)
 from .hashmap import HashMap
 from .xdp import XDP, PacketVar
-from .bpf import ProgType, prog_test_run
+from .bpf import ProgType
 
 
 opcodes = list((v.value, v) for v in Opcode)
@@ -1062,10 +1062,10 @@ class KernelTests(TestCase):
         e.a += 7
         e.exit()
 
-        fd, _ = e.load(log_level=1)
-        prog_test_run(fd, 1000, 1000, 0, 0, 1)
+        e.load(log_level=1)
+        e.test_run(1000, 1000, 0, 0, 1)
         e.a *= 2
-        prog_test_run(fd, 1000, 1000, 0, 0, 1)
+        e.test_run(1000, 1000, 0, 0, 1)
         self.assertEqual(e.a, 31)
         self.assertEqual(e.b, 24)
 
@@ -1096,8 +1096,8 @@ class KernelTests(TestCase):
         e.r0 = 55
         e.exit()
 
-        fd, _ = e.load(log_level=1)
-        prog_test_run(fd, 1000, 1000, 100, 100, 1)
+        e.load(log_level=1)
+        e.test_run(1000, 1000, 100, 100, 1)
         self.assertEqual(e.ar, 7)
         self.assertEqual(e.aw, 11)
         self.assertEqual(s1.br, 33)
@@ -1113,7 +1113,7 @@ class KernelTests(TestCase):
         self.assertEqual(s1.bw, 36)
         self.assertEqual(s2.br, 165)
         self.assertEqual(s2.bw, 36)
-        prog_test_run(fd, 1000, 1000, 0, 0, 1)
+        e.test_run(1000, 1000, 0, 0, 1)
         self.assertEqual(e.ar, 18)
         self.assertEqual(e.aw, 22)
         self.assertEqual(s1.br, 36)
diff --git a/ebpfcat/ebpfcat.py b/ebpfcat/ebpfcat.py
index 45b5cfe..491e097 100644
--- a/ebpfcat/ebpfcat.py
+++ b/ebpfcat/ebpfcat.py
@@ -424,9 +424,10 @@ class FastEtherCat(SimpleEtherCat):
         index = len(self.sync_groups)
         while index in self.sync_groups:
             index = (index + 1) % self.MAX_PROGS
-        fd, _ = sg.load(log_level=1)
-        update_elem(self.programs, pack("<I", index), pack("<I", fd), 0)
-        os.close(fd)
+        sg.load()
+        update_elem(self.programs, pack("<I", index),
+                    pack("<I", sg.file_descriptor), 0)
+        sg.close()
         self.sync_groups[index] = sg
         try:
             yield index
@@ -437,7 +438,7 @@ class FastEtherCat(SimpleEtherCat):
         await super().connect()
         self.ebpf = EtherXDP()
         self.ebpf.programs = self.programs
-        self.fd = await self.ebpf.attach(self.addr[0])
+        await self.ebpf.attach(self.addr[0])
 
     @asynccontextmanager
     async def run(self):
diff --git a/ebpfcat/xdp.py b/ebpfcat/xdp.py
index fe73266..99c9106 100644
--- a/ebpfcat/xdp.py
+++ b/ebpfcat/xdp.py
@@ -254,8 +254,8 @@ class XDP(EBPF):
            like ``"eth0"``
         :param flags: one of the :class:`XDPFlags` """
         ifindex = if_nametoindex(network)
-        fd, _ = self.load(log_level=self.ebpf_log_level)
-        await self._netlink(ifindex, fd, flags)
+        self.load(log_level=self.ebpf_log_level)
+        await self._netlink(ifindex, self.file_descriptor, flags)
 
     async def detach(self, network, flags=XDPFlags.SKB_MODE):
         """detach this program from a ``network``
@@ -277,11 +277,11 @@ class XDP(EBPF):
            like ``"eth0"``
         :param flags: one of the :class:`XDPFlags` """
         ifindex = if_nametoindex(network)
-        fd, _ = self.load(log_level=self.ebpf_log_level)
+        self.load(log_level=self.ebpf_log_level)
         try:
-            await self._netlink(ifindex, fd, flags)
+            await self._netlink(ifindex, self.file_descriptor, flags)
         finally:
-            os.close(fd)
+            self.close()
         try:
             yield
         finally:
-- 
GitLab