diff --git a/ebpf.py b/ebpf.py index b904b87b22df8ec9c5098facf7b7fb066e484f36..3e7523d6a7995f6d9b293929aae8606ee44e3a21 100644 --- a/ebpf.py +++ b/ebpf.py @@ -751,6 +751,11 @@ class EBPF: if isinstance(v, Map): v.init(self) + self.program() + + def program(self): + pass + def append(self, opcode, dst, src, off, imm): self.opcodes.append(Instruction(opcode, dst, src, off, imm)) diff --git a/ebpf_test.py b/ebpf_test.py index bdd085d84cbf7d95cb54d620320345099f2aed58..9e078e382f6303f83111893a821a481f33498d18 100644 --- a/ebpf_test.py +++ b/ebpf_test.py @@ -3,7 +3,8 @@ from unittest import TestCase, main from . import ebpf from .ebpf import ( ArrayMap, AssembleError, EBPF, HashMap, Opcode, OpcodeFlags, - Opcode as O, LocalVar, XDP) + Opcode as O, LocalVar) +from .xdp import XDP from .bpf import ProgType, prog_test_run @@ -510,11 +511,11 @@ class KernelTests(TestCase): self.assertEqual(e.a, 21) def test_minimal(self): - class Global(EBPF): + class Global(XDP): map = HashMap() a = map.globalVar() - e = Global(ProgType.XDP, "GPL") + e = Global(license="GPL") e.a += 1 e.exit() print(e.opcodes) diff --git a/xdp.py b/xdp.py index 4ad838406049431938c294be28eb1a57732bbbe1..b9e014590c837ef1a958774b29ac1f69e874de12 100644 --- a/xdp.py +++ b/xdp.py @@ -3,14 +3,9 @@ from socket import AF_NETLINK, NETLINK_ROUTE, if_nametoindex import socket from struct import pack, unpack -async def set_link_xdp_fd(network, fd): - ifindex = if_nametoindex(network) - future = Future() - transport, proto = await get_event_loop().create_datagram_endpoint( - lambda: XDRFD(ifindex, fd, future), - family=AF_NETLINK, proto=NETLINK_ROUTE) - await future - transport.get_extra_info("socket").close() +from .ebpf import EBPF, Memory, MemoryDesc, Opcode +from .bpf import ProgType + class XDRFD(DatagramProtocol): def __init__(self, ifindex, fd, future): @@ -47,25 +42,47 @@ class XDRFD(DatagramProtocol): 8, 3, # IFLA_XDP_FLAGS, 2) - print("send", len(p), p) transport.sendto(p, (0, 0)) def datagram_received(self, data, addr): pos = 0 - print("received", data) while (pos < len(data)): ln, type, flags, seq, pid = unpack("IHHII", data[pos : pos+16]) - print(f" {ln} {type} {flags:x} {seq} {pid}") if type == 3: # DONE self.future.set_result(0) elif type == 2: # ERROR errno, *args = unpack("iIHHII", data[pos+16 : pos+36]) - print("ERROR", errno, args) if errno != 0: self.future.set_result(errno) if flags & 2 == 0: # not a multipart message - print("not multipart") self.future.set_result(0) pos += ln - + +class PacketDesc(MemoryDesc): + def __setitem__(self, addr, value): + super().__setitem__(self.ebpf.r9 + addr, value) + + def __getitem__(self, addr): + return Memory(self.ebpf, self.bits, self.ebpf.r9 + addr) + + +class XDP(EBPF): + def __init__(self, **kwargs): + super().__init__(prog_type=ProgType.XDP, **kwargs) + self.r9 = self.m32[self.r1] + + self.packet8 = MemoryDesc(self, Opcode.B) + self.packet16 = MemoryDesc(self, Opcode.H) + self.packet32 = MemoryDesc(self, Opcode.W) + self.packet64 = MemoryDesc(self, Opcode.DW) + + async def attach(self, network): + ifindex = if_nametoindex(network) + fd = self.load() + future = Future() + transport, proto = await get_event_loop().create_datagram_endpoint( + lambda: XDRFD(ifindex, fd, future), + family=AF_NETLINK, proto=NETLINK_ROUTE) + await future + transport.get_extra_info("socket").close()