Skip to content
Snippets Groups Projects
Commit 5f33f033 authored by Martin Teichmann's avatar Martin Teichmann
Browse files

first starts for global variables

parent a33f2efb
No related branches found
No related tags found
No related merge requests found
from collections import namedtuple from collections import namedtuple
from contextlib import contextmanager
from struct import pack from struct import pack
from enum import Enum from enum import Enum
from .bpf import prog_load from .bpf import create_map, prog_load
Instruction = namedtuple("Instruction", Instruction = namedtuple("Instruction",
["opcode", "dst", "src", "off", "imm"]) ["opcode", "dst", "src", "off", "imm"])
...@@ -556,6 +557,71 @@ class MemoryDesc: ...@@ -556,6 +557,71 @@ class MemoryDesc:
return Memory(self.ebpf, self.bits, addr) return Memory(self.ebpf, self.bits, addr)
class GlobalVar(Expression):
def __init__(self, count, hashMap, signed):
self.count = count
self.hashMap = hashMap
self.signed = signed
def __get__(self, instance, owner):
return self
def __set_name__(self, owner, name):
self.name = name
def calculate(self, dst, long, signed, force):
with self.ebpf.save_registers(dst), self.ebpf.get_stack(4) as stack:
self.ebpf.append(Opcode.ST, 10, 0, stack, self.count)
self.ebpf.r1 = self.ebpf.get_fd(self.fd)
self.ebpf.r2 = self.ebpf.r10 + stack
self.ebpf.call(1)
with self.ebpf.If(self.ebpf.r0 == 0):
self.ebpf.exit()
if dst is None:
dst = self.ebpf.get_free_register()
free = True
else:
free = False
self.ebpf.append(Opcode.LD, dst, 0, 0, 0)
return dst, False, self.signed, free
def __set__(self, ebpf, value):
with ebpf.get_stack(8) as stack:
src, _, _, free = value.calculate(None, False, self.signed)
ebpf.append(Opcode.STX, 10, src, stack + 4, 0)
if free:
ebpf.owners.discard(src)
ebpf.append(Opcode.ST, 10, 0, stack, self.count)
with ebpf.save_registers(None):
ebpf.r1 = self.ebpf.get_fd(self.fd)
ebpf.r2 = self.ebpf.r10 + stack
ebpf.r3 = self.ebpf.r10 + (stack + 4)
ebpf.r4 = 3
self.ebpf.call(2)
class HashMap:
count = 0
def __init__(self):
self.vars = []
def globalVar(self, signed=False):
self.count += 1
ret = GlobalVar(self.count, self, signed)
self.vars.append(ret)
return ret
def __set_name__(self, owner, name):
owner._add_init_hook(self._init)
def _init(self, ebpf):
fd = create_map(1, 1, 4, self.count)
for v in self.vars:
var = getattr(ebpf, v.name)
var.ebpf = ebpf
var.fd = fd
class PseudoFd(Expression): class PseudoFd(Expression):
def __init__(self, ebpf, fd): def __init__(self, ebpf, fd):
self.ebpf = ebpf self.ebpf = ebpf
...@@ -612,6 +678,10 @@ class EBPF: ...@@ -612,6 +678,10 @@ class EBPF:
self.owners = {1, 10} self.owners = {1, 10}
if self._init_hooks is not None:
for hook in self._init_hooks:
hook(self)
def append(self, opcode, dst, src, off, imm): def append(self, opcode, dst, src, off, imm):
self.opcodes.append(Instruction(opcode, dst, src, off, imm)) self.opcodes.append(Instruction(opcode, dst, src, off, imm))
...@@ -666,6 +736,37 @@ class EBPF: ...@@ -666,6 +736,37 @@ class EBPF:
self.append(Opcode.DW, no, 0, 0, value & 0xffffffff) self.append(Opcode.DW, no, 0, 0, value & 0xffffffff)
self.append(Opcode.W, 0, 0, 0, value >> 32) self.append(Opcode.W, 0, 0, 0, value >> 32)
@contextmanager
def save_registers(self, dst):
oldowners = self.owners.copy()
self.owners |= set(range(6))
save = []
for i in range(5):
if i in oldowners and i != dst:
tmp = self.get_free_register()
self.owners.add(tmp)
self.append(Opcode.MOV+Opcode.LONG+Opcode.REG, tmp, i, 0, 0)
save.append((tmp, i))
yield
for tmp, i in save:
self.append(Opcode.MOV+Opcode.LONG+Opcode.REG, i, tmp, 0, 0)
self.owners = oldowners
@contextmanager
def get_stack(self, size):
oldstack = self.stack
self.stack = (self.stack - size) & -size
yield self.stack
self.stack = oldstack
_init_hooks = None
@classmethod
def _add_init_hook(cls, hook):
if cls._init_hooks is None:
cls._init_hooks = []
cls._init_hooks.append(hook)
for i in range(11): for i in range(11):
setattr(EBPF, f"r{i}", RegisterDesc(i, True)) setattr(EBPF, f"r{i}", RegisterDesc(i, True))
......
...@@ -2,7 +2,7 @@ from unittest import TestCase, main ...@@ -2,7 +2,7 @@ from unittest import TestCase, main
from . import ebpf from . import ebpf
from .ebpf import ( from .ebpf import (
AssembleError, EBPF, Opcode, OpcodeFlags, Opcode as O, LocalVar) AssembleError, EBPF, HashMap, Opcode, OpcodeFlags, Opcode as O, LocalVar)
from .bpf import ProgType from .bpf import ProgType
...@@ -469,12 +469,14 @@ class Tests(TestCase): ...@@ -469,12 +469,14 @@ class Tests(TestCase):
class KernelTests(TestCase): class KernelTests(TestCase):
def test_minimal(self): def test_minimal(self):
e = EBPF(ProgType.XDP, "GPL") class Global(EBPF):
with e.If((e.r1 == 0x111111) & (e.r10 == 0x22222)) as cond: map = HashMap()
e.r0 = 333333 a = map.globalVar()
with cond.Else():
e.r0 = 444444 e = Global(ProgType.XDP, "GPL")
e.a += 1
e.exit() e.exit()
print(e.opcodes)
print(e.load(log_level=1)[1]) print(e.load(log_level=1)[1])
self.fail() self.fail()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment