Skip to content
Snippets Groups Projects
Commit 3c78b35b authored by Martin Teichmann's avatar Martin Teichmann
Browse files

generalize get_address

parent 04df2a4b
No related branches found
No related tags found
No related merge requests found
...@@ -268,6 +268,23 @@ class Expression: ...@@ -268,6 +268,23 @@ class Expression:
def __neg__(self): def __neg__(self):
return Negate(self.ebpf, self) return Negate(self.ebpf, self)
@contextmanager
def calculate(self, dst, long, signed, force=False):
with self.ebpf.get_free_register(dst) as dst:
with self.get_address(dst, long, signed):
self.ebpf.append(Opcode.LD, dst, dst, 0, 0)
yield dst, False, self.signed
@contextmanager
def get_address(self, dst, long, signed):
with self.ebpf.get_stack(4 + 4 * long) as stack:
with self.calculate(dst, long, signed) as (src, _, _):
self.ebpf.append(Opcode.STX + Opcode.DW * long,
10, src, stack, 0)
self.ebpf.append(Opcode.MOV + Opcode.LONG + Opcode.REG, dst, 10, 0, 0)
self.ebpf.append(Opcode.ADD + Opcode.LONG, dst, 0, 0, stack)
yield
class Binary(Expression): class Binary(Expression):
def __init__(self, ebpf, left, right, operator): def __init__(self, ebpf, left, right, operator):
...@@ -509,17 +526,21 @@ class HashGlobalVar(Expression): ...@@ -509,17 +526,21 @@ class HashGlobalVar(Expression):
self.signed = signed self.signed = signed
@contextmanager @contextmanager
def calculate(self, dst, long, signed, force): def get_address(self, dst, long, signed):
with self.ebpf.save_registers(dst), self.ebpf.get_stack(4) as stack: if long:
raise AssembleError("HashMap is only for words")
if signed != self.signed:
raise AssembleError("HashMap variable has wrong signedness")
with self.ebpf.save_registers([i for i in range(6) if i != dst]), \
self.ebpf.get_stack(4) as stack:
self.ebpf.append(Opcode.ST, 10, 0, stack, self.count) self.ebpf.append(Opcode.ST, 10, 0, stack, self.count)
self.ebpf.r1 = self.ebpf.get_fd(self.fd) self.ebpf.r1 = self.ebpf.get_fd(self.fd)
self.ebpf.r2 = self.ebpf.r10 + stack self.ebpf.r2 = self.ebpf.r10 + stack
self.ebpf.call(1) self.ebpf.call(1)
with self.ebpf.If(self.ebpf.r0 == 0): with self.ebpf.If(self.ebpf.r0 == 0):
self.ebpf.exit() self.ebpf.exit()
with self.ebpf.get_free_register(dst) as dst: self.ebpf.append(Opcode.MOV + Opcode.LONG + Opcode.REG, dst, 0, 0, 0)
self.ebpf.append(Opcode.LD, dst, 0, 0, 0) yield
yield dst, False, self.signed
class HashGlobalVarDesc: class HashGlobalVarDesc:
...@@ -550,16 +571,15 @@ class HashGlobalVarDesc: ...@@ -550,16 +571,15 @@ class HashGlobalVarDesc:
bpf.update_elem(fd, pack("B", self.count), bpf.update_elem(fd, pack("B", self.count),
pack("i" if self.signed else "I", value), 0) pack("i" if self.signed else "I", value), 0)
return return
with ebpf.get_stack(8) as stack: with ebpf.save_registers([3]):
with value.calculate(None, False, self.signed) as (src, _, _): with value.get_address(3, False, self.signed):
ebpf.append(Opcode.STX, 10, src, stack + 4, 0) with ebpf.save_registers([0, 1, 2, 4, 5]), \
ebpf.append(Opcode.ST, 10, 0, stack, self.count) ebpf.get_stack(4) as stack:
with ebpf.save_registers(None): ebpf.r1 = ebpf.get_fd(ebpf.__dict__[self.name].fd)
ebpf.r1 = ebpf.get_fd(ebpf.__dict__[self.name].fd) ebpf.append(Opcode.ST, 10, 0, stack, self.count)
ebpf.r2 = ebpf.r10 + stack ebpf.r2 = ebpf.r10 + stack
ebpf.r3 = ebpf.r10 + (stack + 4) ebpf.r4 = 0
ebpf.r4 = 0 ebpf.call(2)
ebpf.call(2)
class Map: class Map:
...@@ -653,7 +673,7 @@ class ArrayMap(Map): ...@@ -653,7 +673,7 @@ class ArrayMap(Map):
def init(self, ebpf): def init(self, ebpf):
fd = bpf.create_map(2, 4, self.position, 1) fd = bpf.create_map(2, 4, self.position, 1)
setattr(ebpf, self.name, ArrayMapAccess(fd, self.position)) setattr(ebpf, self.name, ArrayMapAccess(fd, self.position))
with ebpf.save_registers(None), ebpf.get_stack(4) as stack: with ebpf.save_registers(list(range(6))), ebpf.get_stack(4) as stack:
ebpf.append(Opcode.ST, 10, 0, stack, 0) ebpf.append(Opcode.ST, 10, 0, stack, 0)
ebpf.r1 = ebpf.get_fd(fd) ebpf.r1 = ebpf.get_fd(fd)
ebpf.r2 = ebpf.r10 + stack ebpf.r2 = ebpf.r10 + stack
...@@ -791,13 +811,13 @@ class EBPF: ...@@ -791,13 +811,13 @@ class EBPF:
self.append(Opcode.W, 0, 0, 0, value >> 32) self.append(Opcode.W, 0, 0, 0, value >> 32)
@contextmanager @contextmanager
def save_registers(self, dst): def save_registers(self, registers):
oldowners = self.owners.copy() oldowners = self.owners.copy()
self.owners |= set(range(6)) self.owners |= set(registers)
save = [] save = []
with ExitStack() as exitStack: with ExitStack() as exitStack:
for i in range(5): for i in registers:
if i in oldowners and i != dst: if i in oldowners:
tmp = exitStack.enter_context(self.get_free_register(None)) tmp = exitStack.enter_context(self.get_free_register(None))
self.append(Opcode.MOV+Opcode.LONG+Opcode.REG, self.append(Opcode.MOV+Opcode.LONG+Opcode.REG,
tmp, i, 0, 0) tmp, i, 0, 0)
......
...@@ -3,7 +3,7 @@ from unittest import TestCase, main ...@@ -3,7 +3,7 @@ from unittest import TestCase, main
from . import ebpf from . import ebpf
from .ebpf import ( from .ebpf import (
ArrayMap, AssembleError, EBPF, HashMap, Opcode, OpcodeFlags, ArrayMap, AssembleError, EBPF, HashMap, Opcode, OpcodeFlags,
Opcode as O, LocalVar) Opcode as O, LocalVar, XDP)
from .bpf import ProgType, prog_test_run from .bpf import ProgType, prog_test_run
...@@ -477,8 +477,10 @@ class KernelTests(TestCase): ...@@ -477,8 +477,10 @@ class KernelTests(TestCase):
class Global(EBPF): class Global(EBPF):
map = HashMap() map = HashMap()
a = map.globalVar(default=5) a = map.globalVar(default=5)
b = map.globalVar()
e = Global(ProgType.XDP, "GPL") e = Global(ProgType.XDP, "GPL")
e.b = e.a
e.a += 7 e.a += 7
e.exit() e.exit()
...@@ -487,6 +489,7 @@ class KernelTests(TestCase): ...@@ -487,6 +489,7 @@ class KernelTests(TestCase):
e.a *= 2 e.a *= 2
prog_test_run(fd, 1000, 1000, 0, 0, 1) prog_test_run(fd, 1000, 1000, 0, 0, 1)
self.assertEqual(e.a, 31) self.assertEqual(e.a, 31)
self.assertEqual(e.b, 24)
def test_arraymap(self): def test_arraymap(self):
class Global(EBPF): class Global(EBPF):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment