diff --git a/ebpf.py b/ebpf.py index 3f4e9cfef6410ff1c478b17ed7e0476723d36fab..0705b47c50ac1a0f2cd1fce00ff3d278f2286602 100644 --- a/ebpf.py +++ b/ebpf.py @@ -268,6 +268,23 @@ class Expression: def __neg__(self): return Negate(self.ebpf, self) + @contextmanager + def calculate(self, dst, long, signed, force=False): + with self.ebpf.get_free_register(dst) as dst: + with self.get_address(dst, long, signed): + self.ebpf.append(Opcode.LD, dst, dst, 0, 0) + yield dst, False, self.signed + + @contextmanager + def get_address(self, dst, long, signed): + with self.ebpf.get_stack(4 + 4 * long) as stack: + with self.calculate(dst, long, signed) as (src, _, _): + self.ebpf.append(Opcode.STX + Opcode.DW * long, + 10, src, stack, 0) + self.ebpf.append(Opcode.MOV + Opcode.LONG + Opcode.REG, dst, 10, 0, 0) + self.ebpf.append(Opcode.ADD + Opcode.LONG, dst, 0, 0, stack) + yield + class Binary(Expression): def __init__(self, ebpf, left, right, operator): @@ -509,17 +526,21 @@ class HashGlobalVar(Expression): self.signed = signed @contextmanager - def calculate(self, dst, long, signed, force): - with self.ebpf.save_registers(dst), self.ebpf.get_stack(4) as stack: + def get_address(self, dst, long, signed): + if long: + raise AssembleError("HashMap is only for words") + if signed != self.signed: + raise AssembleError("HashMap variable has wrong signedness") + with self.ebpf.save_registers([i for i in range(6) if i != dst]), \ + self.ebpf.get_stack(4) as stack: self.ebpf.append(Opcode.ST, 10, 0, stack, self.count) self.ebpf.r1 = self.ebpf.get_fd(self.fd) self.ebpf.r2 = self.ebpf.r10 + stack self.ebpf.call(1) with self.ebpf.If(self.ebpf.r0 == 0): self.ebpf.exit() - with self.ebpf.get_free_register(dst) as dst: - self.ebpf.append(Opcode.LD, dst, 0, 0, 0) - yield dst, False, self.signed + self.ebpf.append(Opcode.MOV + Opcode.LONG + Opcode.REG, dst, 0, 0, 0) + yield class HashGlobalVarDesc: @@ -550,16 +571,15 @@ class HashGlobalVarDesc: bpf.update_elem(fd, pack("B", self.count), pack("i" if self.signed else "I", value), 0) return - with ebpf.get_stack(8) as stack: - with value.calculate(None, False, self.signed) as (src, _, _): - ebpf.append(Opcode.STX, 10, src, stack + 4, 0) - ebpf.append(Opcode.ST, 10, 0, stack, self.count) - with ebpf.save_registers(None): - ebpf.r1 = ebpf.get_fd(ebpf.__dict__[self.name].fd) - ebpf.r2 = ebpf.r10 + stack - ebpf.r3 = ebpf.r10 + (stack + 4) - ebpf.r4 = 0 - ebpf.call(2) + with ebpf.save_registers([3]): + with value.get_address(3, False, self.signed): + with ebpf.save_registers([0, 1, 2, 4, 5]), \ + ebpf.get_stack(4) as stack: + ebpf.r1 = ebpf.get_fd(ebpf.__dict__[self.name].fd) + ebpf.append(Opcode.ST, 10, 0, stack, self.count) + ebpf.r2 = ebpf.r10 + stack + ebpf.r4 = 0 + ebpf.call(2) class Map: @@ -653,7 +673,7 @@ class ArrayMap(Map): def init(self, ebpf): fd = bpf.create_map(2, 4, self.position, 1) setattr(ebpf, self.name, ArrayMapAccess(fd, self.position)) - with ebpf.save_registers(None), ebpf.get_stack(4) as stack: + with ebpf.save_registers(list(range(6))), ebpf.get_stack(4) as stack: ebpf.append(Opcode.ST, 10, 0, stack, 0) ebpf.r1 = ebpf.get_fd(fd) ebpf.r2 = ebpf.r10 + stack @@ -791,13 +811,13 @@ class EBPF: self.append(Opcode.W, 0, 0, 0, value >> 32) @contextmanager - def save_registers(self, dst): + def save_registers(self, registers): oldowners = self.owners.copy() - self.owners |= set(range(6)) + self.owners |= set(registers) save = [] with ExitStack() as exitStack: - for i in range(5): - if i in oldowners and i != dst: + for i in registers: + if i in oldowners: tmp = exitStack.enter_context(self.get_free_register(None)) self.append(Opcode.MOV+Opcode.LONG+Opcode.REG, tmp, i, 0, 0) diff --git a/ebpf_test.py b/ebpf_test.py index 471f783722ba08e1f5041470f5ce8bb500b390e9..bdd085d84cbf7d95cb54d620320345099f2aed58 100644 --- a/ebpf_test.py +++ b/ebpf_test.py @@ -3,7 +3,7 @@ from unittest import TestCase, main from . import ebpf from .ebpf import ( ArrayMap, AssembleError, EBPF, HashMap, Opcode, OpcodeFlags, - Opcode as O, LocalVar) + Opcode as O, LocalVar, XDP) from .bpf import ProgType, prog_test_run @@ -477,8 +477,10 @@ class KernelTests(TestCase): class Global(EBPF): map = HashMap() a = map.globalVar(default=5) + b = map.globalVar() e = Global(ProgType.XDP, "GPL") + e.b = e.a e.a += 7 e.exit() @@ -487,6 +489,7 @@ class KernelTests(TestCase): e.a *= 2 prog_test_run(fd, 1000, 1000, 0, 0, 1) self.assertEqual(e.a, 31) + self.assertEqual(e.b, 24) def test_arraymap(self): class Global(EBPF):