diff --git a/ebpf.py b/ebpf.py index 4ff7710e6773296b9b9db9ed42c923be741c72b4..ccd39edd7cbf83866457517832d19ca0f15a3ecf 100644 --- a/ebpf.py +++ b/ebpf.py @@ -706,29 +706,77 @@ class PseudoFd(Expression): class RegisterDesc: - def __init__(self, no, long, signed=False): + def __init__(self, no, array): self.no = no - self.long = long - self.signed = signed + self.array = array def __get__(self, instance, owner=None): if instance is None: return self else: - return Register(self.no, instance, self.long, self.signed) + return getattr(instance, self.array)[self.no] def __set__(self, instance, value): - instance.owners.add(self.no) + getattr(instance, self.array)[self.no] = value + + +class RegisterArray: + def __init__(self, ebpf, long, signed): + self.ebpf = ebpf + self.long = long + self.signed = signed + + def __setitem__(self, no, value): + self.ebpf.owners.add(no) if isinstance(value, int): - instance._load_value(self.no, value) + self.ebpf._load_value(no, value) elif isinstance(value, Expression): - with value.calculate(self.no, self.long, self.signed, True): + with value.calculate(no, self.long, self.signed, True): pass - elif isinstance(value, Instruction): - instance.opcodes.append(value) else: raise AssembleError("cannot compile") - + + def __getitem__(self, no): + return Register(no, self.ebpf, self.long, self.signed) + + + +class Temporary(Register): + def __init__(self, ebpf, long, signed): + super().__init__(None, ebpf, long, signed) + self.nos = [] + self.gfrs = [] + + def __enter__(self): + gfr = self.ebpf.get_free_register(None) + self.nos.append(self.no) + self.no = gfr.__enter__() + self.gfrs.append(gfr) + + def __exit__(self, a, b, c): + gfr = self.gfrs.pop() + gfr.__exit__(a, b, c) + self.no = self.nos.pop() + + +class TemporaryDesc(RegisterDesc): + def __set_name__(self, owner, name): + self.name = name + + def __get__(self, instance, owner=None): + if instance is None: + return self + arr = getattr(instance, self.array) + ret = instance.__dict__.get(self.name, None) + if ret is None: + ret = instance.__dict__[self.name] = \ + Temporary(instance, arr.long, arr.signed) + return ret + + def __set__(self, instance, value): + no = getattr(instance, self.name).no + getattr(instance, self.array)[no] = value + class EBPF: stack = 0 @@ -752,14 +800,17 @@ class EBPF: self.m32 = MemoryDesc(self, Opcode.W) self.m64 = MemoryDesc(self, Opcode.DW) + self.r = RegisterArray(self, True, False) + self.sr = RegisterArray(self, True, True) + self.w = RegisterArray(self, False, False) + self.sw = RegisterArray(self, False, True) + self.owners = {1, 10} for v in self.__class__.__dict__.values(): if isinstance(v, Map): v.init(self) - self.program() - def program(self): pass @@ -767,6 +818,7 @@ class EBPF: self.opcodes.append(Instruction(opcode, dst, src, off, imm)) def assemble(self): + self.program() return b"".join( pack("<BBHI", i.opcode.value, i.dst | i.src << 4, i.off % 0x10000, i.imm % 0x100000000) @@ -809,7 +861,9 @@ class EBPF: self.owners.add(0) self.owners -= set(range(1, 6)) - def exit(self): + def exit(self, no=None): + if no is not None: + self.r0 = no self.append(Opcode.EXIT, 0, 0, 0, 0) @contextmanager @@ -856,12 +910,20 @@ class EBPF: yield self.stack self.stack = oldstack + tmp = TemporaryDesc(None, "r") + stmp = TemporaryDesc(None, "sr") + wtmp = TemporaryDesc(None, "w") + swtmp = TemporaryDesc(None, "sw") + for i in range(11): - setattr(EBPF, f"r{i}", RegisterDesc(i, True)) + setattr(EBPF, f"r{i}", RegisterDesc(i, "r")) + +for i in range(10): + setattr(EBPF, f"sr{i}", RegisterDesc(i, "sr")) for i in range(10): - setattr(EBPF, f"sr{i}", RegisterDesc(i, True, True)) + setattr(EBPF, f"w{i}", RegisterDesc(i, "w")) for i in range(10): - setattr(EBPF, f"w{i}", RegisterDesc(i, False)) + setattr(EBPF, f"sw{i}", RegisterDesc(i, "sw")) diff --git a/ebpf_test.py b/ebpf_test.py index 9e078e382f6303f83111893a821a481f33498d18..90b4ef6d59a574a60d98c727b089d8dfc5c74dbb 100644 --- a/ebpf_test.py +++ b/ebpf_test.py @@ -472,6 +472,42 @@ class Tests(TestCase): with self.assertRaises(AssembleError): e.r8 = e.r2 + def test_temporary(self): + e = EBPF() + e.r0 = 7 + with e.tmp: + e.tmp = 3 + e.r3 = e.tmp + with e.tmp: + e.tmp = 5 + e.r7 = e.tmp + e.tmp = 2 + e.r3 = e.tmp + self.assertEqual(e.opcodes, [ + Instruction(opcode=O.MOV+O.LONG, dst=0, src=0, off=0, imm=7), + Instruction(opcode=O.MOV+O.LONG, dst=2, src=0, off=0, imm=3), + Instruction(opcode=O.MOV+O.LONG+O.REG, dst=3, src=2, off=0, imm=0), + Instruction(opcode=O.MOV+O.LONG, dst=4, src=0, off=0, imm=5), + Instruction(opcode=O.MOV+O.LONG+O.REG, dst=7, src=4, off=0, imm=0), + Instruction(opcode=O.MOV+O.LONG, dst=2, src=0, off=0, imm=2), + Instruction(opcode=O.MOV+O.LONG+O.REG, dst=3, src=2, off=0, imm=0) + ]) + + def test_xdp(self): + e = XDP(license="GPL") + with e.packetSize > 100 as p: + e.r3 = p.H[22] + with p.Else(): + e.r3 = 77 + self.assertEqual(e.opcodes, [ + Instruction(opcode=O.LD+O.W, dst=0, src=1, off=0, imm=0), + Instruction(opcode=O.LD+O.W, dst=2, src=1, off=4, imm=0), + Instruction(opcode=O.LD+O.W, dst=3, src=1, off=0, imm=0), + Instruction(opcode=O.ADD+O.LONG, dst=3, src=0, off=0, imm=100), + Instruction(opcode=O.REG+O.JLE, dst=2, src=3, off=2, imm=0), + Instruction(opcode=O.REG+O.LD, dst=3, src=0, off=22, imm=0), + Instruction(opcode=O.JMP, dst=0, src=0, off=1, imm=0), + Instruction(opcode=O.MOV+O.LONG, dst=3, src=0, off=0, imm=77)]) class KernelTests(TestCase): def test_hashmap(self): diff --git a/xdp.py b/xdp.py index 6c66043d8d92f71b4c25dee1eeeb8a67f962f85f..fd48740a9861ba358667a02a377ecb11c91d9c71 100644 --- a/xdp.py +++ b/xdp.py @@ -1,9 +1,10 @@ from asyncio import DatagramProtocol, Future, get_event_loop +from contextlib import contextmanager from socket import AF_NETLINK, NETLINK_ROUTE, if_nametoindex import socket from struct import pack, unpack -from .ebpf import EBPF, Memory, MemoryDesc, Opcode +from .ebpf import EBPF, Expression, Memory, MemoryDesc, Opcode, Comparison from .bpf import ProgType @@ -58,49 +59,67 @@ class XDRFD(DatagramProtocol): self.future.set_result(0) pos += ln -class Packet(Expression): - def __init__(self, ebpf, bits, addr): + +class PacketArray: + def __init__(self, ebpf, no, memory): + self.ebpf = ebpf + self.no = no + self.memory = memory + + def __getitem__(self, value): + return self.memory[self.ebpf.r[self.no] + value] + + def __setitem__(self, value): + self.memory[self.ebpf.r[self.no]] = value + + +class Packet: + def __init__(self, ebpf, comp, no): + self.ebpf = ebpf + self.comp = comp + self.no = no + + self.B = PacketArray(self.ebpf, self.no, self.ebpf.m8) + self.H = PacketArray(self.ebpf, self.no, self.ebpf.m16) + self.W = PacketArray(self.ebpf, self.no, self.ebpf.m32) + self.DW = PacketArray(self.ebpf, self.no, self.ebpf.m64) + + def Else(self): + return self.comp.Else() + + +class PacketSize: + def __init__(self, ebpf): self.ebpf = ebpf - self.bits = bits - self.address = addr - self.signed = False @contextmanager - def get_address(self, dst, long, signed, force=False): + def __lt__(self, value): + e = self.ebpf + with e.tmp: + e.tmp = e.m32[e.r1] + with e.If(e.m32[e.r1 + 4] < e.m32[e.r1] + value) as comp: + yield Packet(e, comp, e.tmp.no) + + @contextmanager + def __gt__(self, value): e = self.ebpf - bits = Memory.bits_to_opcode[self.bits] - with e.get_free_register(dst) as reg: - e.r[reg] = e.m32[e.r1] + self.address - with e.If(e.r[reg] + int(self.bits // 8) <= e.m32[e.r1 + 4]) as c: - if force and dst != reg: - e.r[dst] = e.r[reg] - reg = dst - with c.Else(): - e.exit(2) - yield reg, bits - - def contains(self, no): - return no == 1 or (not isinstance(self.address, int) - and self.address.contains(no)) - - -class PacketDesc(MemoryDesc): - def __setitem__(self, addr, value): - super().__setitem__(self.ebpf.r9 + addr, value) - - def __getitem__(self, addr): - return Memory(self.ebpf, self.bits, self.ebpf.r9 + addr) + with e.tmp: + e.tmp = e.m32[e.r1] + with e.If(e.m32[e.r1 + 4] > e.m32[e.r1] + value) as comp: + yield Packet(e, comp, e.tmp.no) + + def __le__(self, value): + return self < value + 1 + + def __ge__(self, value): + return self > value - 1 class XDP(EBPF): def __init__(self, **kwargs): super().__init__(prog_type=ProgType.XDP, **kwargs) - self.r9 = self.m32[self.r1] - self.packet8 = MemoryDesc(self, Opcode.B) - self.packet16 = MemoryDesc(self, Opcode.H) - self.packet32 = MemoryDesc(self, Opcode.W) - self.packet64 = MemoryDesc(self, Opcode.DW) + self.packetSize = PacketSize(self) async def attach(self, network): ifindex = if_nametoindex(network)