Skip to content
Snippets Groups Projects
Commit d3df05ab authored by Martin Teichmann's avatar Martin Teichmann
Browse files

make local variables work for subprograms

parent 6eb85bb1
No related branches found
No related tags found
No related merge requests found
...@@ -633,27 +633,31 @@ class MemoryDesc: ...@@ -633,27 +633,31 @@ class MemoryDesc:
self.bits = bits self.bits = bits
self.signed = signed self.signed = signed
def __get__(self, ebpf, owner): def __get__(self, instance, owner):
if ebpf is None: if instance is None:
return self return self
elif isinstance(ebpf, SubProgram): elif isinstance(instance, SubProgram):
ebpf = ebpf.ebpf ebpf = instance.ebpf
else:
ebpf = instance
return Memory(ebpf, Memory.bits_to_opcode[self.bits], return Memory(ebpf, Memory.bits_to_opcode[self.bits],
ebpf.r[self.base_register] + self.addr, ebpf.r[self.base_register] + self.addr(instance),
self.signed) self.signed)
def __set__(self, ebpf, value): def __set__(self, instance, value):
if isinstance(ebpf, SubProgram): if isinstance(instance, SubProgram):
ebpf = ebpf.ebpf ebpf = instance.ebpf
else:
ebpf = instance
bits = Memory.bits_to_opcode[self.bits] bits = Memory.bits_to_opcode[self.bits]
if isinstance(value, int): if isinstance(value, int):
ebpf.append(Opcode.ST + bits, self.base_register, 0, ebpf.append(Opcode.ST + bits, self.base_register, 0,
self.addr, value) self.addr(instance), value)
else: else:
with value.calculate(None, self.bits == 64, self.signed) \ with value.calculate(None, self.bits == 64, self.signed) \
as (src, _, _): as (src, _, _):
ebpf.append(Opcode.STX + bits, self.base_register, ebpf.append(Opcode.STX + bits, self.base_register,
src, self.addr, 0) src, self.addr(instance), 0)
class LocalVar(MemoryDesc): class LocalVar(MemoryDesc):
...@@ -663,9 +667,15 @@ class LocalVar(MemoryDesc): ...@@ -663,9 +667,15 @@ class LocalVar(MemoryDesc):
size = int(self.bits // 8) size = int(self.bits // 8)
owner.stack -= size owner.stack -= size
owner.stack &= -size owner.stack &= -size
self.addr = owner.stack self.relative_addr = owner.stack
self.name = name self.name = name
def addr(self, instance):
if isinstance(instance, SubProgram):
return (instance.ebpf.stack & -8) + self.relative_addr
else:
return self.relative_addr
class MemoryMap: class MemoryMap:
def __init__(self, ebpf, bits): def __init__(self, ebpf, bits):
...@@ -945,4 +955,4 @@ for i in range(10): ...@@ -945,4 +955,4 @@ for i in range(10):
class SubProgram: class SubProgram:
pass stack = 0
...@@ -157,6 +157,30 @@ class Tests(TestCase): ...@@ -157,6 +157,30 @@ class Tests(TestCase):
Instruction(opcode=O.REG+O.STX, dst=10, src=0, off=-4, imm=0), Instruction(opcode=O.REG+O.STX, dst=10, src=0, off=-4, imm=0),
Instruction(opcode=O.DW+O.STX, dst=10, src=1, off=-16, imm=0)]) Instruction(opcode=O.DW+O.STX, dst=10, src=1, off=-16, imm=0)])
def test_local_subprog(self):
class Local(EBPF):
a = LocalVar(32, False)
class Sub(SubProgram):
b = LocalVar(32, False)
def program(self):
self.b *= 3
s1 = Sub()
s2 = Sub()
e = Local(ProgType.XDP, "GPL", subprograms=[s1, s2])
e.a = 5
s1.b = 3
e.r3 = s1.b
s2.b = 7
self.assertEqual(e.opcodes, [
Instruction(opcode=O.W+O.ST, dst=10, src=0, off=-4, imm=5),
Instruction(opcode=O.W+O.ST, dst=10, src=0, off=-12, imm=3),
Instruction(opcode=O.W+O.LD, dst=3, src=10, off=-12, imm=0),
Instruction(opcode=O.W+O.ST, dst=10, src=0, off=-12, imm=7)])
def test_jump(self): def test_jump(self):
e = EBPF() e = EBPF()
e.owners = set(range(11)) e.owners = set(range(11))
......
...@@ -4,7 +4,7 @@ from socket import AF_NETLINK, NETLINK_ROUTE, if_nametoindex ...@@ -4,7 +4,7 @@ from socket import AF_NETLINK, NETLINK_ROUTE, if_nametoindex
import socket import socket
from struct import pack, unpack from struct import pack, unpack
from .ebpf import EBPF, Expression, Memory, MemoryDesc, Opcode, Comparison from .ebpf import EBPF, Expression, Memory, Opcode, Comparison
from .bpf import ProgType from .bpf import ProgType
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment