Skip to content
Snippets Groups Projects
Commit 6eb85bb1 authored by Martin Teichmann's avatar Martin Teichmann
Browse files

support SubPrograms

they can already have ArrayMaps, but not much more
parent 23e2a80c
No related branches found
No related tags found
No related merge requests found
from struct import pack, unpack
from .ebpf import FuncId, Map, Memory, Opcode
from .ebpf import FuncId, Map, Memory, Opcode, SubProgram
from .bpf import create_map, lookup_elem, MapType, update_elem
class ArrayGlobalVarDesc:
def __init__(self, map, position, size, signed):
def __init__(self, map, size, signed):
self.map = map
self.position = position
self.signed = signed
self.size = size
self.fmt = {1: "B", 2: "H", 4: "I", 8: "Q"}[size]
......@@ -17,23 +16,28 @@ class ArrayGlobalVarDesc:
def __get__(self, ebpf, owner):
if ebpf is None:
return self
position = ebpf.__dict__[self.name]
if isinstance(ebpf, SubProgram):
ebpf = ebpf.ebpf
if ebpf.loaded:
data = ebpf.__dict__[self.map.name].data[
self.position:self.position + self.size]
position : position+self.size]
return unpack(self.fmt, data)[0]
return Memory(ebpf, Memory.bits_to_opcode[self.size * 8],
ebpf.r0 + self.position, self.signed)
ebpf.r0 + position, self.signed)
def __set_name__(self, owner, name):
self.name = name
def __set__(self, ebpf, value):
position = ebpf.__dict__[self.name]
if isinstance(ebpf, SubProgram):
ebpf = ebpf.ebpf
if ebpf.loaded:
ebpf.__dict__[self.map.name].data[
self.position:self.position + self.size] = \
pack(self.fmt, value)
position : position + self.size] = pack(self.fmt, value)
else:
getattr(ebpf, f"m{self.size * 8}")[ebpf.r0 + self.position] = value
getattr(ebpf, f"m{self.size * 8}")[ebpf.r0 + position] = value
class ArrayMapAccess:
......@@ -49,23 +53,29 @@ class ArrayMapAccess:
class ArrayMap(Map):
position = 0
def __init__(self):
self.vars = []
def globalVar(self, signed=False, size=4):
ret = ArrayGlobalVarDesc(self, self.position, size, signed)
self.position = (self.position + 2 * size - 1) & -size
self.vars.append(ret)
return ret
return ArrayGlobalVarDesc(self, size, signed)
def add_program(self, owner, prog):
position = getattr(owner, self.name)
for k, v in prog.__class__.__dict__.items():
if not isinstance(v, ArrayGlobalVarDesc):
continue
prog.__dict__[k] = position
position = (position + 2 * v.size - 1) & -v.size
setattr(owner, self.name, position)
def __set_name__(self, owner, name):
self.name = name
def init(self, ebpf):
fd = create_map(MapType.ARRAY, 4, self.position, 1)
setattr(ebpf, self.name, ArrayMapAccess(fd, self.position))
setattr(ebpf, self.name, 0)
self.add_program(ebpf, ebpf)
for prog in ebpf.subprograms:
self.add_program(ebpf, prog)
size = getattr(ebpf, self.name)
fd = create_map(MapType.ARRAY, 4, size, 1)
setattr(ebpf, self.name, ArrayMapAccess(fd, size))
with ebpf.save_registers(list(range(6))), ebpf.get_stack(4) as stack:
ebpf.append(Opcode.ST, 10, 0, stack, 0)
ebpf.r1 = ebpf.get_fd(fd)
......
......@@ -628,11 +628,37 @@ class Memory(Expression):
return self.address.contains(no)
class LocalVar:
class MemoryDesc:
def __init__(self, bits=32, signed=False):
self.bits = bits
self.signed = signed
def __get__(self, ebpf, owner):
if ebpf is None:
return self
elif isinstance(ebpf, SubProgram):
ebpf = ebpf.ebpf
return Memory(ebpf, Memory.bits_to_opcode[self.bits],
ebpf.r[self.base_register] + self.addr,
self.signed)
def __set__(self, ebpf, value):
if isinstance(ebpf, SubProgram):
ebpf = ebpf.ebpf
bits = Memory.bits_to_opcode[self.bits]
if isinstance(value, int):
ebpf.append(Opcode.ST + bits, self.base_register, 0,
self.addr, value)
else:
with value.calculate(None, self.bits == 64, self.signed) \
as (src, _, _):
ebpf.append(Opcode.STX + bits, self.base_register,
src, self.addr, 0)
class LocalVar(MemoryDesc):
base_register = 10
def __set_name__(self, owner, name):
size = int(self.bits // 8)
owner.stack -= size
......@@ -640,23 +666,8 @@ class LocalVar:
self.addr = owner.stack
self.name = name
def __get__(self, instance, owner):
if instance is None:
return self
else:
return Memory(instance, Memory.bits_to_opcode[self.bits],
instance.r10 + self.addr, self.signed)
def __set__(self, instance, value):
bits = Memory.bits_to_opcode[self.bits]
if isinstance(value, int):
instance.append(Opcode.ST + bits, 10, 0, self.addr, value)
else:
with value.calculate(None, self.bits == 64, self.signed) \
as (src, _, _):
instance.append(Opcode.STX + bits, 10, src, self.addr, 0)
class MemoryDesc:
class MemoryMap:
def __init__(self, ebpf, bits):
self.ebpf = ebpf
self.bits = bits
......@@ -794,10 +805,10 @@ class EBPF:
self.name = name
self.loaded = False
self.m8 = MemoryDesc(self, Opcode.B)
self.m16 = MemoryDesc(self, Opcode.H)
self.m32 = MemoryDesc(self, Opcode.W)
self.m64 = MemoryDesc(self, Opcode.DW)
self.m8 = MemoryMap(self, Opcode.B)
self.m16 = MemoryMap(self, Opcode.H)
self.m32 = MemoryMap(self, Opcode.W)
self.m64 = MemoryMap(self, Opcode.DW)
self.r = RegisterArray(self, True, False)
self.sr = RegisterArray(self, True, True)
......@@ -806,6 +817,10 @@ class EBPF:
self.owners = {1, 10}
self.subprograms = subprograms
for p in subprograms:
p.ebpf = self
for v in self.__class__.__dict__.values():
if isinstance(v, Map):
v.init(self)
......@@ -927,3 +942,7 @@ for i in range(10):
for i in range(10):
setattr(EBPF, f"sw{i}", RegisterDesc(i, "sw"))
class SubProgram:
pass
......@@ -3,7 +3,8 @@ from unittest import TestCase, main
from . import ebpf
from .arraymap import ArrayMap
from .ebpf import (
AssembleError, EBPF, FuncId, Opcode, OpcodeFlags, Opcode as O, LocalVar)
AssembleError, EBPF, FuncId, Opcode, OpcodeFlags, Opcode as O, LocalVar,
SubProgram)
from .hashmap import HashMap
from .xdp import XDP
from .bpf import ProgType, prog_test_run
......@@ -534,18 +535,32 @@ class KernelTests(TestCase):
map = ArrayMap()
a = map.globalVar()
e = Global(ProgType.XDP, "GPL")
class Sub(SubProgram):
b = Global.map.globalVar()
def program(self):
self.b -= -33
s1 = Sub()
s2 = Sub()
e = Global(ProgType.XDP, "GPL", subprograms=[s1, s2])
e.a += 7
s1.program()
s2.program()
e.exit()
fd = e.load()
fd, _ = e.load(log_level=1)
prog_test_run(fd, 1000, 1000, 0, 0, 1)
e.map.read()
e.a *= 2
s1.b = 3
s2.b *= 5
e.map.write()
prog_test_run(fd, 1000, 1000, 0, 0, 1)
e.map.read()
self.assertEqual(e.a, 21)
self.assertEqual(s1.b, 36)
self.assertEqual(s2.b, 5 * 33 + 33)
def test_minimal(self):
class Global(XDP):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment