Skip to content
Snippets Groups Projects
Commit 4239c7b1 authored by Martin Teichmann's avatar Martin Teichmann
Browse files

support packet addressing

this also adds a lot of other details
parent b441cffc
No related branches found
No related tags found
No related merge requests found
...@@ -706,29 +706,77 @@ class PseudoFd(Expression): ...@@ -706,29 +706,77 @@ class PseudoFd(Expression):
class RegisterDesc: class RegisterDesc:
def __init__(self, no, long, signed=False): def __init__(self, no, array):
self.no = no self.no = no
self.long = long self.array = array
self.signed = signed
def __get__(self, instance, owner=None): def __get__(self, instance, owner=None):
if instance is None: if instance is None:
return self return self
else: else:
return Register(self.no, instance, self.long, self.signed) return getattr(instance, self.array)[self.no]
def __set__(self, instance, value): def __set__(self, instance, value):
instance.owners.add(self.no) getattr(instance, self.array)[self.no] = value
class RegisterArray:
def __init__(self, ebpf, long, signed):
self.ebpf = ebpf
self.long = long
self.signed = signed
def __setitem__(self, no, value):
self.ebpf.owners.add(no)
if isinstance(value, int): if isinstance(value, int):
instance._load_value(self.no, value) self.ebpf._load_value(no, value)
elif isinstance(value, Expression): elif isinstance(value, Expression):
with value.calculate(self.no, self.long, self.signed, True): with value.calculate(no, self.long, self.signed, True):
pass pass
elif isinstance(value, Instruction):
instance.opcodes.append(value)
else: else:
raise AssembleError("cannot compile") raise AssembleError("cannot compile")
def __getitem__(self, no):
return Register(no, self.ebpf, self.long, self.signed)
class Temporary(Register):
def __init__(self, ebpf, long, signed):
super().__init__(None, ebpf, long, signed)
self.nos = []
self.gfrs = []
def __enter__(self):
gfr = self.ebpf.get_free_register(None)
self.nos.append(self.no)
self.no = gfr.__enter__()
self.gfrs.append(gfr)
def __exit__(self, a, b, c):
gfr = self.gfrs.pop()
gfr.__exit__(a, b, c)
self.no = self.nos.pop()
class TemporaryDesc(RegisterDesc):
def __set_name__(self, owner, name):
self.name = name
def __get__(self, instance, owner=None):
if instance is None:
return self
arr = getattr(instance, self.array)
ret = instance.__dict__.get(self.name, None)
if ret is None:
ret = instance.__dict__[self.name] = \
Temporary(instance, arr.long, arr.signed)
return ret
def __set__(self, instance, value):
no = getattr(instance, self.name).no
getattr(instance, self.array)[no] = value
class EBPF: class EBPF:
stack = 0 stack = 0
...@@ -752,14 +800,17 @@ class EBPF: ...@@ -752,14 +800,17 @@ class EBPF:
self.m32 = MemoryDesc(self, Opcode.W) self.m32 = MemoryDesc(self, Opcode.W)
self.m64 = MemoryDesc(self, Opcode.DW) self.m64 = MemoryDesc(self, Opcode.DW)
self.r = RegisterArray(self, True, False)
self.sr = RegisterArray(self, True, True)
self.w = RegisterArray(self, False, False)
self.sw = RegisterArray(self, False, True)
self.owners = {1, 10} self.owners = {1, 10}
for v in self.__class__.__dict__.values(): for v in self.__class__.__dict__.values():
if isinstance(v, Map): if isinstance(v, Map):
v.init(self) v.init(self)
self.program()
def program(self): def program(self):
pass pass
...@@ -767,6 +818,7 @@ class EBPF: ...@@ -767,6 +818,7 @@ class EBPF:
self.opcodes.append(Instruction(opcode, dst, src, off, imm)) self.opcodes.append(Instruction(opcode, dst, src, off, imm))
def assemble(self): def assemble(self):
self.program()
return b"".join( return b"".join(
pack("<BBHI", i.opcode.value, i.dst | i.src << 4, pack("<BBHI", i.opcode.value, i.dst | i.src << 4,
i.off % 0x10000, i.imm % 0x100000000) i.off % 0x10000, i.imm % 0x100000000)
...@@ -809,7 +861,9 @@ class EBPF: ...@@ -809,7 +861,9 @@ class EBPF:
self.owners.add(0) self.owners.add(0)
self.owners -= set(range(1, 6)) self.owners -= set(range(1, 6))
def exit(self): def exit(self, no=None):
if no is not None:
self.r0 = no
self.append(Opcode.EXIT, 0, 0, 0, 0) self.append(Opcode.EXIT, 0, 0, 0, 0)
@contextmanager @contextmanager
...@@ -856,12 +910,20 @@ class EBPF: ...@@ -856,12 +910,20 @@ class EBPF:
yield self.stack yield self.stack
self.stack = oldstack self.stack = oldstack
tmp = TemporaryDesc(None, "r")
stmp = TemporaryDesc(None, "sr")
wtmp = TemporaryDesc(None, "w")
swtmp = TemporaryDesc(None, "sw")
for i in range(11): for i in range(11):
setattr(EBPF, f"r{i}", RegisterDesc(i, True)) setattr(EBPF, f"r{i}", RegisterDesc(i, "r"))
for i in range(10):
setattr(EBPF, f"sr{i}", RegisterDesc(i, "sr"))
for i in range(10): for i in range(10):
setattr(EBPF, f"sr{i}", RegisterDesc(i, True, True)) setattr(EBPF, f"w{i}", RegisterDesc(i, "w"))
for i in range(10): for i in range(10):
setattr(EBPF, f"w{i}", RegisterDesc(i, False)) setattr(EBPF, f"sw{i}", RegisterDesc(i, "sw"))
...@@ -472,6 +472,42 @@ class Tests(TestCase): ...@@ -472,6 +472,42 @@ class Tests(TestCase):
with self.assertRaises(AssembleError): with self.assertRaises(AssembleError):
e.r8 = e.r2 e.r8 = e.r2
def test_temporary(self):
e = EBPF()
e.r0 = 7
with e.tmp:
e.tmp = 3
e.r3 = e.tmp
with e.tmp:
e.tmp = 5
e.r7 = e.tmp
e.tmp = 2
e.r3 = e.tmp
self.assertEqual(e.opcodes, [
Instruction(opcode=O.MOV+O.LONG, dst=0, src=0, off=0, imm=7),
Instruction(opcode=O.MOV+O.LONG, dst=2, src=0, off=0, imm=3),
Instruction(opcode=O.MOV+O.LONG+O.REG, dst=3, src=2, off=0, imm=0),
Instruction(opcode=O.MOV+O.LONG, dst=4, src=0, off=0, imm=5),
Instruction(opcode=O.MOV+O.LONG+O.REG, dst=7, src=4, off=0, imm=0),
Instruction(opcode=O.MOV+O.LONG, dst=2, src=0, off=0, imm=2),
Instruction(opcode=O.MOV+O.LONG+O.REG, dst=3, src=2, off=0, imm=0)
])
def test_xdp(self):
e = XDP(license="GPL")
with e.packetSize > 100 as p:
e.r3 = p.H[22]
with p.Else():
e.r3 = 77
self.assertEqual(e.opcodes, [
Instruction(opcode=O.LD+O.W, dst=0, src=1, off=0, imm=0),
Instruction(opcode=O.LD+O.W, dst=2, src=1, off=4, imm=0),
Instruction(opcode=O.LD+O.W, dst=3, src=1, off=0, imm=0),
Instruction(opcode=O.ADD+O.LONG, dst=3, src=0, off=0, imm=100),
Instruction(opcode=O.REG+O.JLE, dst=2, src=3, off=2, imm=0),
Instruction(opcode=O.REG+O.LD, dst=3, src=0, off=22, imm=0),
Instruction(opcode=O.JMP, dst=0, src=0, off=1, imm=0),
Instruction(opcode=O.MOV+O.LONG, dst=3, src=0, off=0, imm=77)])
class KernelTests(TestCase): class KernelTests(TestCase):
def test_hashmap(self): def test_hashmap(self):
......
from asyncio import DatagramProtocol, Future, get_event_loop from asyncio import DatagramProtocol, Future, get_event_loop
from contextlib import contextmanager
from socket import AF_NETLINK, NETLINK_ROUTE, if_nametoindex from socket import AF_NETLINK, NETLINK_ROUTE, if_nametoindex
import socket import socket
from struct import pack, unpack from struct import pack, unpack
from .ebpf import EBPF, Memory, MemoryDesc, Opcode from .ebpf import EBPF, Expression, Memory, MemoryDesc, Opcode, Comparison
from .bpf import ProgType from .bpf import ProgType
...@@ -58,49 +59,67 @@ class XDRFD(DatagramProtocol): ...@@ -58,49 +59,67 @@ class XDRFD(DatagramProtocol):
self.future.set_result(0) self.future.set_result(0)
pos += ln pos += ln
class Packet(Expression):
def __init__(self, ebpf, bits, addr): class PacketArray:
def __init__(self, ebpf, no, memory):
self.ebpf = ebpf
self.no = no
self.memory = memory
def __getitem__(self, value):
return self.memory[self.ebpf.r[self.no] + value]
def __setitem__(self, value):
self.memory[self.ebpf.r[self.no]] = value
class Packet:
def __init__(self, ebpf, comp, no):
self.ebpf = ebpf
self.comp = comp
self.no = no
self.B = PacketArray(self.ebpf, self.no, self.ebpf.m8)
self.H = PacketArray(self.ebpf, self.no, self.ebpf.m16)
self.W = PacketArray(self.ebpf, self.no, self.ebpf.m32)
self.DW = PacketArray(self.ebpf, self.no, self.ebpf.m64)
def Else(self):
return self.comp.Else()
class PacketSize:
def __init__(self, ebpf):
self.ebpf = ebpf self.ebpf = ebpf
self.bits = bits
self.address = addr
self.signed = False
@contextmanager @contextmanager
def get_address(self, dst, long, signed, force=False): def __lt__(self, value):
e = self.ebpf
with e.tmp:
e.tmp = e.m32[e.r1]
with e.If(e.m32[e.r1 + 4] < e.m32[e.r1] + value) as comp:
yield Packet(e, comp, e.tmp.no)
@contextmanager
def __gt__(self, value):
e = self.ebpf e = self.ebpf
bits = Memory.bits_to_opcode[self.bits] with e.tmp:
with e.get_free_register(dst) as reg: e.tmp = e.m32[e.r1]
e.r[reg] = e.m32[e.r1] + self.address with e.If(e.m32[e.r1 + 4] > e.m32[e.r1] + value) as comp:
with e.If(e.r[reg] + int(self.bits // 8) <= e.m32[e.r1 + 4]) as c: yield Packet(e, comp, e.tmp.no)
if force and dst != reg:
e.r[dst] = e.r[reg] def __le__(self, value):
reg = dst return self < value + 1
with c.Else():
e.exit(2) def __ge__(self, value):
yield reg, bits return self > value - 1
def contains(self, no):
return no == 1 or (not isinstance(self.address, int)
and self.address.contains(no))
class PacketDesc(MemoryDesc):
def __setitem__(self, addr, value):
super().__setitem__(self.ebpf.r9 + addr, value)
def __getitem__(self, addr):
return Memory(self.ebpf, self.bits, self.ebpf.r9 + addr)
class XDP(EBPF): class XDP(EBPF):
def __init__(self, **kwargs): def __init__(self, **kwargs):
super().__init__(prog_type=ProgType.XDP, **kwargs) super().__init__(prog_type=ProgType.XDP, **kwargs)
self.r9 = self.m32[self.r1]
self.packet8 = MemoryDesc(self, Opcode.B) self.packetSize = PacketSize(self)
self.packet16 = MemoryDesc(self, Opcode.H)
self.packet32 = MemoryDesc(self, Opcode.W)
self.packet64 = MemoryDesc(self, Opcode.DW)
async def attach(self, network): async def attach(self, network):
ifindex = if_nametoindex(network) ifindex = if_nametoindex(network)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment