From dbf16e610e2318dcd3f0a3b2e1422e3663ded36a Mon Sep 17 00:00:00 2001 From: Martin Teichmann <martin.teichmann@xfel.eu> Date: Tue, 29 Dec 2020 08:31:21 +0000 Subject: [PATCH] make HashMap work on user side --- ebpf.py | 68 +++++++++++++++++++++++++++++++++------------------- ebpf_test.py | 17 ++++++++++++- 2 files changed, 59 insertions(+), 26 deletions(-) diff --git a/ebpf.py b/ebpf.py index caae0c9..9238663 100644 --- a/ebpf.py +++ b/ebpf.py @@ -1,9 +1,9 @@ from collections import namedtuple from contextlib import contextmanager, ExitStack -from struct import pack +from struct import pack, unpack from enum import Enum -from .bpf import create_map, prog_load +from . import bpf Instruction = namedtuple("Instruction", ["opcode", "dst", "src", "off", "imm"]) @@ -522,15 +522,19 @@ class HashGlobalVar(Expression): yield dst, False, self.signed - class HashGlobalVarDesc: - def __init__(self, count, signed): + def __init__(self, count, signed, default=0): self.count = count self.signed = signed + self.default = default def __get__(self, instance, owner): if instance is None: return self + if instance.loaded: + fd = instance.__dict__[self.name].fd + ret = bpf.lookup_elem(fd, pack("B", self.count), 4) + return unpack("i" if self.signed else "I", ret)[0] ret = instance.__dict__.get(self.name, None) if ret is None: ret = HashGlobalVar(instance, self.count, self.signed) @@ -541,6 +545,11 @@ class HashGlobalVarDesc: self.name = name def __set__(self, ebpf, value): + if ebpf.loaded: + fd = ebpf.__dict__[self.name].fd + bpf.update_elem(fd, pack("B", self.count), + pack("i" if self.signed else "I", value), 0) + return with ebpf.get_stack(8) as stack: with value.calculate(None, False, self.signed) as (src, _, _): ebpf.append(Opcode.STX, 10, src, stack + 4, 0) @@ -549,30 +558,38 @@ class HashGlobalVarDesc: ebpf.r1 = ebpf.get_fd(ebpf.__dict__[self.name].fd) ebpf.r2 = ebpf.r10 + stack ebpf.r3 = ebpf.r10 + (stack + 4) - ebpf.r4 = 3 + ebpf.r4 = 0 ebpf.call(2) -class HashMap: +class Map: + def init(self, ebpf): + pass + + def load(self, ebpf): + pass + +class HashMap(Map): count = 0 def __init__(self): self.vars = [] - def globalVar(self, signed=False): + def globalVar(self, signed=False, default=0): self.count += 1 - ret = HashGlobalVarDesc(self.count, signed) + ret = HashGlobalVarDesc(self.count, signed, default) self.vars.append(ret) return ret - def __set_name__(self, owner, name): - owner._add_init_hook(self._init) - - def _init(self, ebpf): - fd = create_map(1, 1, 4, self.count) + def init(self, ebpf): + fd = bpf.create_map(1, 1, 4, self.count) for v in self.vars: getattr(ebpf, v.name).fd = fd + def load(self, ebpf): + for v in self.vars: + setattr(ebpf, v.name, ebpf.__class__.__dict__[v.name].default) + class PseudoFd(Expression): def __init__(self, ebpf, fd): @@ -620,6 +637,7 @@ class EBPF: self.prog_type = prog_type self.license = license self.kern_version = kern_version + self.loaded = False self.m8 = MemoryDesc(self, Opcode.B) self.m16 = MemoryDesc(self, Opcode.H) @@ -628,9 +646,9 @@ class EBPF: self.owners = {1, 10} - if self._init_hooks is not None: - for hook in self._init_hooks: - hook(self) + for v in self.__class__.__dict__.values(): + if isinstance(v, Map): + v.init(self) def append(self, opcode, dst, src, off, imm): self.opcodes.append(Instruction(opcode, dst, src, off, imm)) @@ -642,8 +660,15 @@ class EBPF: for i in self.opcodes) def load(self, log_level=0, log_size=4096): - return prog_load(self.prog_type, self.assemble(), self.license, - log_level, log_size, self.kern_version) + ret = bpf.prog_load(self.prog_type, self.assemble(), self.license, + log_level, log_size, self.kern_version) + self.loaded = True + + for v in self.__class__.__dict__.values(): + if isinstance(v, Map): + v.load(self) + + return ret def jumpIf(self, comp): comp.compare(False) @@ -717,13 +742,6 @@ class EBPF: yield self.stack self.stack = oldstack - _init_hooks = None - - @classmethod - def _add_init_hook(cls, hook): - if cls._init_hooks is None: - cls._init_hooks = [] - cls._init_hooks.append(hook) for i in range(11): setattr(EBPF, f"r{i}", RegisterDesc(i, True)) diff --git a/ebpf_test.py b/ebpf_test.py index 3ee19d8..5484ac3 100644 --- a/ebpf_test.py +++ b/ebpf_test.py @@ -3,7 +3,7 @@ from unittest import TestCase, main from . import ebpf from .ebpf import ( AssembleError, EBPF, HashMap, Opcode, OpcodeFlags, Opcode as O, LocalVar) -from .bpf import ProgType +from .bpf import ProgType, prog_test_run opcodes = list((v.value, v) for v in Opcode) @@ -472,6 +472,21 @@ class Tests(TestCase): class KernelTests(TestCase): + def test_hashmap(self): + class Global(EBPF): + map = HashMap() + a = map.globalVar(default=5) + + e = Global(ProgType.XDP, "GPL") + e.a += 7 + e.exit() + + fd = e.load() + prog_test_run(fd, 1000, 1000, 0, 0, 1) + e.a *= 2 + prog_test_run(fd, 1000, 1000, 0, 0, 1) + self.assertEqual(e.a, 31) + def test_minimal(self): class Global(EBPF): map = HashMap() -- GitLab