Skip to content
Snippets Groups Projects
Commit 04df2a4b authored by Martin Teichmann's avatar Martin Teichmann
Browse files

make the array map work

parent dbf16e61
No related branches found
No related tags found
No related merge requests found
......@@ -50,7 +50,7 @@ def lookup_elem(fd, key, size):
value = create_string_buffer(size)
ret, _ = bpf(1, "IQQQ", fd, addrof(key), addrof(value), 0)
if ret == 0:
return value.raw
return value
else:
return None
......@@ -100,4 +100,7 @@ def prog_test_run(fd, data_in, data_out, ctx_in, ctx_out,
if __name__ == "__main__":
fd = create_map(1, 4, 4, 10)
update_elem(fd, b"asdf", b"ckde", 0)
print(lookup_elem(fd, b"asdf", 4))
ret = lookup_elem(fd, b"asdf", 4)
ret[2:4] = b"kk"
update_elem(fd, b"asdf", ret, 0)
print(lookup_elem(fd, b"asdf", 4).raw)
......@@ -422,6 +422,8 @@ class Register(Expression):
class Memory(Expression):
bits_to_opcode = {32: Opcode.W, 16: Opcode.H, 8: Opcode.B, 64: Opcode.DW}
def __init__(self, ebpf, bits, address, signed=False):
self.ebpf = ebpf
self.bits = bits
......@@ -446,8 +448,6 @@ class Memory(Expression):
class LocalVar:
bits_to_opcode = {32: Opcode.W, 16: Opcode.H, 8: Opcode.B, 64: Opcode.DW}
def __init__(self, bits=32, signed=False):
self.bits = bits
self.signed = signed
......@@ -463,11 +463,11 @@ class LocalVar:
if instance is None:
return self
else:
return Memory(instance, self.bits_to_opcode[self.bits],
return Memory(instance, Memory.bits_to_opcode[self.bits],
instance.r10 + self.addr, self.signed)
def __set__(self, instance, value):
bits = self.bits_to_opcode[self.bits]
bits = Memory.bits_to_opcode[self.bits]
if isinstance(value, int):
instance.append(Opcode.ST + bits, 10, 0, self.addr, value)
else:
......@@ -591,6 +591,78 @@ class HashMap(Map):
setattr(ebpf, v.name, ebpf.__class__.__dict__[v.name].default)
class ArrayGlobalVarDesc:
def __init__(self, map, position, size, signed):
self.map = map
self.position = position
self.signed = signed
self.size = size
self.fmt = {1: "B", 2: "H", 4: "I", 8: "Q"}[size]
if signed:
self.fmt = self.fmt.lower()
def __get__(self, ebpf, owner):
if ebpf is None:
return self
if ebpf.loaded:
data = ebpf.__dict__[self.map.name].data[
self.position:self.position + self.size]
return unpack(self.fmt, data)[0]
return Memory(ebpf, Memory.bits_to_opcode[self.size * 8],
ebpf.r0 + self.position, self.signed)
def __set_name__(self, owner, name):
self.name = name
def __set__(self, ebpf, value):
if ebpf.loaded:
ebpf.__dict__[self.map.name].data[
self.position:self.position + self.size] = \
pack(self.fmt, value)
else:
getattr(ebpf, f"m{self.size * 8}")[ebpf.r0 + self.position] = value
class ArrayMapAccess:
def __init__(self, fd, size):
self.fd = fd
self.size = size
def read(self):
self.data = bpf.lookup_elem(self.fd, b"\0\0\0\0", self.size)
def write(self):
bpf.update_elem(self.fd, b"\0\0\0\0", self.data, 0)
class ArrayMap(Map):
position = 0
def __init__(self):
self.vars = []
def globalVar(self, signed=False, size=4):
ret = ArrayGlobalVarDesc(self, self.position, size, signed)
self.position = (self.position + 2 * size - 1) & -size
self.vars.append(ret)
return ret
def __set_name__(self, owner, name):
self.name = name
def init(self, ebpf):
fd = bpf.create_map(2, 4, self.position, 1)
setattr(ebpf, self.name, ArrayMapAccess(fd, self.position))
with ebpf.save_registers(None), ebpf.get_stack(4) as stack:
ebpf.append(Opcode.ST, 10, 0, stack, 0)
ebpf.r1 = ebpf.get_fd(fd)
ebpf.r2 = ebpf.r10 + stack
ebpf.call(1)
with ebpf.If(ebpf.r0 == 0):
ebpf.exit()
ebpf.owners.add(0)
class PseudoFd(Expression):
def __init__(self, ebpf, fd):
self.ebpf = ebpf
......
......@@ -2,7 +2,8 @@ from unittest import TestCase, main
from . import ebpf
from .ebpf import (
AssembleError, EBPF, HashMap, Opcode, OpcodeFlags, Opcode as O, LocalVar)
ArrayMap, AssembleError, EBPF, HashMap, Opcode, OpcodeFlags,
Opcode as O, LocalVar)
from .bpf import ProgType, prog_test_run
......@@ -487,6 +488,24 @@ class KernelTests(TestCase):
prog_test_run(fd, 1000, 1000, 0, 0, 1)
self.assertEqual(e.a, 31)
def test_arraymap(self):
class Global(EBPF):
map = ArrayMap()
a = map.globalVar()
e = Global(ProgType.XDP, "GPL")
e.a += 7
e.exit()
fd = e.load()
prog_test_run(fd, 1000, 1000, 0, 0, 1)
e.map.read()
e.a *= 2
e.map.write()
prog_test_run(fd, 1000, 1000, 0, 0, 1)
e.map.read()
self.assertEqual(e.a, 21)
def test_minimal(self):
class Global(EBPF):
map = HashMap()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment