diff --git a/ebpfcat/bpf.py b/ebpfcat/bpf.py index c25145532c00bfe79819282135c7d82f56f6f01a..14bab08bac201208a06889995cdccd5798d25fcf 100644 --- a/ebpfcat/bpf.py +++ b/ebpfcat/bpf.py @@ -149,7 +149,7 @@ def prog_load(prog_type, insns, license, if log_level != 0: return fd, the_logbuf.value.decode("utf8") else: - return fd + return fd, None def obj_pin(pathname, fd): pn = pathname.encode("utf8") diff --git a/ebpfcat/ebpf.py b/ebpfcat/ebpf.py index 25b901912557ba854560acaa461cba59e6696acd..87ed97fdb80885dba5e49f392f29c3f2cc4b3ca4 100644 --- a/ebpfcat/ebpf.py +++ b/ebpfcat/ebpf.py @@ -19,6 +19,7 @@ __all__ = ["EBPF", "LocalVar", "prandom", "ktime"] +import os from abc import ABC, abstractmethod from collections import namedtuple from contextlib import contextmanager, ExitStack @@ -1318,16 +1319,24 @@ class EBPF: def load(self, log_level=0, log_size=10 * 4096): """load the program into the kernel""" - ret = bpf.prog_load(self.prog_type, self.assemble(), self.license, - log_level, log_size, self.kern_version, - name=self.name) + fd, log = bpf.prog_load(self.prog_type, self.assemble(), self.license, + log_level, log_size, self.kern_version, + name=self.name) self.loaded = True + self.file_descriptor = fd for v in self.__class__.__dict__.values(): if isinstance(v, Map): v.load(self) - return ret + return log + + def close(self): + os.close(self.file_descriptor) + self.file_descriptor = None + + def test_run(self, *args, **kwargs): + return bpf.prog_test_run(self.file_descriptor, *args, **kwargs) def jumpIf(self, comp): """jump if `comp` is true to a later defined `target`""" diff --git a/ebpfcat/ebpf_test.py b/ebpfcat/ebpf_test.py index bc8823dae36bc9c43c2d000ad5ee2559645d53ca..139e192853a9c88eb43dfd181bb0611647940c87 100644 --- a/ebpfcat/ebpf_test.py +++ b/ebpfcat/ebpf_test.py @@ -24,7 +24,7 @@ from .ebpf import ( SubProgram, ktime) from .hashmap import HashMap from .xdp import XDP, PacketVar -from .bpf import ProgType, prog_test_run +from .bpf import ProgType opcodes = list((v.value, v) for v in Opcode) @@ -1062,10 +1062,10 @@ class KernelTests(TestCase): e.a += 7 e.exit() - fd, _ = e.load(log_level=1) - prog_test_run(fd, 1000, 1000, 0, 0, 1) + e.load(log_level=1) + e.test_run(1000, 1000, 0, 0, 1) e.a *= 2 - prog_test_run(fd, 1000, 1000, 0, 0, 1) + e.test_run(1000, 1000, 0, 0, 1) self.assertEqual(e.a, 31) self.assertEqual(e.b, 24) @@ -1096,8 +1096,8 @@ class KernelTests(TestCase): e.r0 = 55 e.exit() - fd, _ = e.load(log_level=1) - prog_test_run(fd, 1000, 1000, 100, 100, 1) + e.load(log_level=1) + e.test_run(1000, 1000, 100, 100, 1) self.assertEqual(e.ar, 7) self.assertEqual(e.aw, 11) self.assertEqual(s1.br, 33) @@ -1113,7 +1113,7 @@ class KernelTests(TestCase): self.assertEqual(s1.bw, 36) self.assertEqual(s2.br, 165) self.assertEqual(s2.bw, 36) - prog_test_run(fd, 1000, 1000, 0, 0, 1) + e.test_run(1000, 1000, 0, 0, 1) self.assertEqual(e.ar, 18) self.assertEqual(e.aw, 22) self.assertEqual(s1.br, 36) diff --git a/ebpfcat/ebpfcat.py b/ebpfcat/ebpfcat.py index 45b5cfe9c43437ad430c01fa2c2197e858f94b75..491e097bc17d05ffa178b7e83d4776a60c827969 100644 --- a/ebpfcat/ebpfcat.py +++ b/ebpfcat/ebpfcat.py @@ -424,9 +424,10 @@ class FastEtherCat(SimpleEtherCat): index = len(self.sync_groups) while index in self.sync_groups: index = (index + 1) % self.MAX_PROGS - fd, _ = sg.load(log_level=1) - update_elem(self.programs, pack("<I", index), pack("<I", fd), 0) - os.close(fd) + sg.load() + update_elem(self.programs, pack("<I", index), + pack("<I", sg.file_descriptor), 0) + sg.close() self.sync_groups[index] = sg try: yield index @@ -437,7 +438,7 @@ class FastEtherCat(SimpleEtherCat): await super().connect() self.ebpf = EtherXDP() self.ebpf.programs = self.programs - self.fd = await self.ebpf.attach(self.addr[0]) + await self.ebpf.attach(self.addr[0]) @asynccontextmanager async def run(self): diff --git a/ebpfcat/xdp.py b/ebpfcat/xdp.py index fe73266fbd37e7bf0e77bb8166ada5523e22c59b..99c91061e10ccd43a174f1d03a42040808b1c271 100644 --- a/ebpfcat/xdp.py +++ b/ebpfcat/xdp.py @@ -254,8 +254,8 @@ class XDP(EBPF): like ``"eth0"`` :param flags: one of the :class:`XDPFlags` """ ifindex = if_nametoindex(network) - fd, _ = self.load(log_level=self.ebpf_log_level) - await self._netlink(ifindex, fd, flags) + self.load(log_level=self.ebpf_log_level) + await self._netlink(ifindex, self.file_descriptor, flags) async def detach(self, network, flags=XDPFlags.SKB_MODE): """detach this program from a ``network`` @@ -277,11 +277,11 @@ class XDP(EBPF): like ``"eth0"`` :param flags: one of the :class:`XDPFlags` """ ifindex = if_nametoindex(network) - fd, _ = self.load(log_level=self.ebpf_log_level) + self.load(log_level=self.ebpf_log_level) try: - await self._netlink(ifindex, fd, flags) + await self._netlink(ifindex, self.file_descriptor, flags) finally: - os.close(fd) + self.close() try: yield finally: