diff --git a/ebpf.py b/ebpf.py index 8e36a9b8594c8cb9d85b758a160492f4b1e38c4c..ab6d1dd51ddacc6bf2c13024e83f757eef425dfe 100644 --- a/ebpf.py +++ b/ebpf.py @@ -1,8 +1,9 @@ from collections import namedtuple +from contextlib import contextmanager from struct import pack from enum import Enum -from .bpf import prog_load +from .bpf import create_map, prog_load Instruction = namedtuple("Instruction", ["opcode", "dst", "src", "off", "imm"]) @@ -556,6 +557,71 @@ class MemoryDesc: return Memory(self.ebpf, self.bits, addr) +class GlobalVar(Expression): + def __init__(self, count, hashMap, signed): + self.count = count + self.hashMap = hashMap + self.signed = signed + + def __get__(self, instance, owner): + return self + + def __set_name__(self, owner, name): + self.name = name + + def calculate(self, dst, long, signed, force): + with self.ebpf.save_registers(dst), self.ebpf.get_stack(4) as stack: + self.ebpf.append(Opcode.ST, 10, 0, stack, self.count) + self.ebpf.r1 = self.ebpf.get_fd(self.fd) + self.ebpf.r2 = self.ebpf.r10 + stack + self.ebpf.call(1) + with self.ebpf.If(self.ebpf.r0 == 0): + self.ebpf.exit() + if dst is None: + dst = self.ebpf.get_free_register() + free = True + else: + free = False + self.ebpf.append(Opcode.LD, dst, 0, 0, 0) + return dst, False, self.signed, free + + def __set__(self, ebpf, value): + with ebpf.get_stack(8) as stack: + src, _, _, free = value.calculate(None, False, self.signed) + ebpf.append(Opcode.STX, 10, src, stack + 4, 0) + if free: + ebpf.owners.discard(src) + ebpf.append(Opcode.ST, 10, 0, stack, self.count) + with ebpf.save_registers(None): + ebpf.r1 = self.ebpf.get_fd(self.fd) + ebpf.r2 = self.ebpf.r10 + stack + ebpf.r3 = self.ebpf.r10 + (stack + 4) + ebpf.r4 = 3 + self.ebpf.call(2) + + +class HashMap: + count = 0 + + def __init__(self): + self.vars = [] + + def globalVar(self, signed=False): + self.count += 1 + ret = GlobalVar(self.count, self, signed) + self.vars.append(ret) + return ret + + def __set_name__(self, owner, name): + owner._add_init_hook(self._init) + + def _init(self, ebpf): + fd = create_map(1, 1, 4, self.count) + for v in self.vars: + var = getattr(ebpf, v.name) + var.ebpf = ebpf + var.fd = fd + class PseudoFd(Expression): def __init__(self, ebpf, fd): self.ebpf = ebpf @@ -612,6 +678,10 @@ class EBPF: self.owners = {1, 10} + if self._init_hooks is not None: + for hook in self._init_hooks: + hook(self) + def append(self, opcode, dst, src, off, imm): self.opcodes.append(Instruction(opcode, dst, src, off, imm)) @@ -666,6 +736,37 @@ class EBPF: self.append(Opcode.DW, no, 0, 0, value & 0xffffffff) self.append(Opcode.W, 0, 0, 0, value >> 32) + @contextmanager + def save_registers(self, dst): + oldowners = self.owners.copy() + self.owners |= set(range(6)) + save = [] + for i in range(5): + if i in oldowners and i != dst: + tmp = self.get_free_register() + self.owners.add(tmp) + self.append(Opcode.MOV+Opcode.LONG+Opcode.REG, tmp, i, 0, 0) + save.append((tmp, i)) + yield + for tmp, i in save: + self.append(Opcode.MOV+Opcode.LONG+Opcode.REG, i, tmp, 0, 0) + self.owners = oldowners + + @contextmanager + def get_stack(self, size): + oldstack = self.stack + self.stack = (self.stack - size) & -size + yield self.stack + self.stack = oldstack + + _init_hooks = None + + @classmethod + def _add_init_hook(cls, hook): + if cls._init_hooks is None: + cls._init_hooks = [] + cls._init_hooks.append(hook) + for i in range(11): setattr(EBPF, f"r{i}", RegisterDesc(i, True)) diff --git a/ebpf_test.py b/ebpf_test.py index 1ab6791b65a66bfae9ab93e21b107f9283c16961..66291de076c2c0d0909ecdd7fffd3285a7a59546 100644 --- a/ebpf_test.py +++ b/ebpf_test.py @@ -2,7 +2,7 @@ from unittest import TestCase, main from . import ebpf from .ebpf import ( - AssembleError, EBPF, Opcode, OpcodeFlags, Opcode as O, LocalVar) + AssembleError, EBPF, HashMap, Opcode, OpcodeFlags, Opcode as O, LocalVar) from .bpf import ProgType @@ -469,12 +469,14 @@ class Tests(TestCase): class KernelTests(TestCase): def test_minimal(self): - e = EBPF(ProgType.XDP, "GPL") - with e.If((e.r1 == 0x111111) & (e.r10 == 0x22222)) as cond: - e.r0 = 333333 - with cond.Else(): - e.r0 = 444444 + class Global(EBPF): + map = HashMap() + a = map.globalVar() + + e = Global(ProgType.XDP, "GPL") + e.a += 1 e.exit() + print(e.opcodes) print(e.load(log_level=1)[1]) self.fail()