From a010cb0c17aac3b3b2035a825f86139707458c2e Mon Sep 17 00:00:00 2001
From: Martin Teichmann <martin.teichmann@xfel.eu>
Date: Mon, 21 Dec 2020 17:53:08 +0000
Subject: [PATCH] could test an ebpf program

---
 bpf.py     | 48 +++++++++++++++++++++++++++++++++++++-----------
 ebpf.py    |  2 +-
 ebpfcat.py | 41 +++++++++++++++++++++++++++++++++++++++++
 xdp.py     | 52 ++++++++++++++++++++++++++++++++++++++++++++++++++++
 4 files changed, 131 insertions(+), 12 deletions(-)
 create mode 100644 ebpfcat.py
 create mode 100644 xdp.py

diff --git a/bpf.py b/bpf.py
index 0c2cb07..3123e35 100644
--- a/bpf.py
+++ b/bpf.py
@@ -1,6 +1,6 @@
-from ctypes import CDLL, c_int, get_errno, cast, c_void_p, create_string_buffer
+from ctypes import CDLL, c_int, get_errno, cast, c_void_p, create_string_buffer, c_char_p
 from enum import Enum
-from struct import pack
+from struct import pack, unpack
 
 from os import strerror
 
@@ -37,21 +37,27 @@ def addrof(ptr):
 
 def bpf(cmd, fmt, *args):
     attr = pack(fmt, *args)
+    attr = create_string_buffer(attr, len(attr))
+    print(unpack(fmt, attr.raw))
     ret = libc.syscall(386, c_int(cmd), attr, len(attr))
+    print(unpack(fmt, attr.raw))
     if ret == -1:
         raise OSError(get_errno(), strerror(get_errno()))
-    return ret
+    return ret, unpack(fmt, attr.raw)
 
 def create_map(map_type, key_size, value_size, max_entries):
-    return bpf(0, "IIII", map_type, key_size, value_size, max_entries)
+    return bpf(0, "IIII", map_type, key_size, value_size, max_entries)[0]
 
 def lookup_elem(fd, key, size):
     value = create_string_buffer(size)
-    bpf(1, "IQQQ", fd, addrof(key), addrof(value), 0)
-    return value.value
+    ret, _ = bpf(1, "IQQQ", fd, addrof(key), addrof(value), 0)
+    if ret == 0:
+        return value.raw
+    else:
+        return None
 
 def update_elem(fd, key, value, flags):
-    return bpf(2, "IQQQ", fd, addrof(key), addrof(value), flags)
+    return bpf(2, "IQQQ", fd, addrof(key), addrof(value), flags)[0]
 
 def prog_load(prog_type, insns, license,
               log_level=0, log_size=4096, kern_version=0):
@@ -63,15 +69,35 @@ def prog_load(prog_type, insns, license,
         log_buf = addrof(the_logbuf)
     license = license.encode("utf8")
     try:
-        bpf(5, "IIQQIIQI", prog_type.value, int(len(insns) // 8),
-            addrof(insns), addrof(license), log_level, log_size, log_buf,
-            kern_version)
+        fd, _ = bpf(5, "IIQQIIQI", prog_type.value, int(len(insns) // 8),
+                    addrof(insns), addrof(license), log_level, log_size,
+                    log_buf, kern_version)
     except OSError as e:
         if log_level != 0:
             raise BPFError(e.errno, the_logbuf.value.decode("utf8"))
         raise
     if log_level != 0:
-        return the_logbuf.value.decode("utf8")
+        return fd, the_logbuf.value.decode("utf8")
+    else:
+        return fd
+
+def prog_test_run(fd, data_in, data_out, ctx_in, ctx_out,
+                  repeat=1):
+    if isinstance(data_in, int):
+        data_in = create_string_buffer(data_in)
+    else:
+        data_in = create_string_buffer(data_in, len(data_in))
+    if isinstance(ctx_in, int):
+        ctx_in = create_string_buffer(ctx_in)
+    else:
+        ctx_in = create_string_buffer(ctx_in, len(ctx_in))
+    data_out = create_string_buffer(data_out)
+    ctx_out = create_string_buffer(ctx_out)
+    ret, (_, retval, _, _, _, _, _, duration, _, _, _, _) = bpf(
+            10, "IIIIQQIIIIQQ20x", fd, 0, len(data_in), len(data_out),
+            addrof(data_in), addrof(data_out), repeat, 0, 0, 0, 0, 0)
+            #len(ctx_in), len(ctx_out), addrof(ctx_in), addrof(ctx_out))
+    return ret, retval, duration, data_out.value, ctx_out.value
 
 if __name__ == "__main__":
     fd = create_map(1, 4, 4, 10)
diff --git a/ebpf.py b/ebpf.py
index 1d60a17..b8eb903 100644
--- a/ebpf.py
+++ b/ebpf.py
@@ -254,7 +254,7 @@ class EBPF:
         self.append(0x95, 0, 0, 0, 0)
 
 
-for i in range(10):
+for i in range(11):
     setattr(EBPF, f"r{i}", RegisterDesc(i, True))
 
 for i in range(10):
diff --git a/ebpfcat.py b/ebpfcat.py
new file mode 100644
index 0000000..80528d4
--- /dev/null
+++ b/ebpfcat.py
@@ -0,0 +1,41 @@
+from .xdp import set_link_xdp_fd
+from .ebpf import EBPF
+from .bpf import ProgType, create_map, update_elem, prog_test_run, lookup_elem
+
+def script():
+    fd = create_map(1, 4, 4, 7)
+    update_elem(fd, b"AAAA", b"BBBB", 0)
+
+    e = EBPF(ProgType.XDP, "GPL")
+    e.r1 = e.get_fd(fd)
+    e.r2 = e.r10
+    e.r2 += -8
+    e.m32[e.r10 - 8] = 0x41414141
+    e.call(1)
+    with e.If(e.r0 != 0):
+        e.r1 = e.get_fd(fd)
+        e.r2 = e.r10
+        e.r2 += -8
+        e.r3 = e.m32[e.r0]
+        e.r3 += 1
+        e.m32[e.r10 - 16] = e.r3
+        e.r3 = e.r10
+        e.r3 += -16
+        e.r4 = 0
+        e.call(2)
+    e.r0 = 2  # XDP_PASS
+    e.exit()
+    return fd, e
+
+async def install_ebpf(network):
+    map_fd, e = script()
+    fd, disas = e.load(log_level=1)
+    prog_test_run(fd, 512, 512, 512, 512, repeat=10)
+    print("bla", lookup_elem(map_fd, b"AAAAA", 4))
+    await set_link_xdp_fd("eth0", fd)
+    return map_fd
+
+if __name__ == "__main__":
+    from asyncio import get_event_loop
+    loop = get_event_loop()
+    loop.run_until_complete(install_ebpf("eth0"))
diff --git a/xdp.py b/xdp.py
new file mode 100644
index 0000000..e90e7ac
--- /dev/null
+++ b/xdp.py
@@ -0,0 +1,52 @@
+from asyncio import DatagramProtocol, Future, get_event_loop
+from socket import AF_NETLINK, NETLINK_ROUTE, if_nametoindex
+from struct import pack, unpack
+
+async def set_link_xdp_fd(network, fd):
+    ifindex = if_nametoindex(network)
+    future = Future()
+    transport, proto = await get_event_loop().create_datagram_endpoint(
+            lambda: XDRFD(ifindex, fd, future),
+            family=AF_NETLINK, proto=NETLINK_ROUTE)
+
+class XDRFD(DatagramProtocol):
+    def __init__(self, ifindex, fd, future):
+        self.ifindex = ifindex
+        self.fd = fd
+        self.seq = None
+        self.future = future
+
+    def connection_made(self, transport):
+        transport.get_extra_info("socket").bind((0, 0))
+        self.transport = transport
+        p = pack("IHHIIBxHiIIHHHHi",
+                16,  # length of if struct
+                19,  # RTM_SETLINK
+                5,  # REQ | ACK
+                1,  # sequence number
+                0,  # pid
+                0,  # AF_UNSPEC
+                0,  # type
+                self.ifindex,
+                0,  #flags
+                0,  #change
+                0x802B,  # NLA_F_NESTED | IFLA_XDP
+                12,  # length of field
+                1,  # IFLA_XDP_FD
+                8,  # length of field
+                self.fd)
+        transport.sendto(p, (0, 0))
+
+    def datagram_received(self, data, addr):
+        pos = 0
+        while (pos < len(data)):
+            ln, type, flags, seq, pid = unpack("IHHII", data[pos : pos+16])
+            if type == 3:  # DONE
+                self.future.set_result(0)
+            elif type == 2:  # ERROR
+                self.future.set_result(-1)
+            elif flags & 2 == 0:  # not a multipart message
+                self.future.set_result(0)
+            pos += ln
+                
+        
-- 
GitLab