diff --git a/ebpf.py b/ebpf.py index 07eee42bbe255905ca5508cde5ace063b2a6f429..2c39e95d00db89d2de2f8d680d9ed46e77c7df39 100644 --- a/ebpf.py +++ b/ebpf.py @@ -465,10 +465,11 @@ class Register(Expression): class Memory(Expression): - def __init__(self, ebpf, bits, address): + def __init__(self, ebpf, bits, address, signed=False): self.ebpf = ebpf self.bits = bits self.address = address + self.signed = signed def calculate(self, dst, long, signed, force=False): if not long and self.bits == Opcode.DW: @@ -486,12 +487,44 @@ class Memory(Expression): self.ebpf.append(Opcode.LD + self.bits, dst, src, 0, 0) if rfree: self.ebpf.owners.discard(src) - return dst, long, signed, free + return dst, long, self.signed, free def contains(self, no): return self.address.contains(no) +class LocalVar: + bits_to_opcode = {32: Opcode.W, 16: Opcode.H, 8: Opcode.B, 64: Opcode.DW} + + def __init__(self, bits=32, signed=False): + self.bits = bits + self.signed = signed + + def __set_name__(self, owner, name): + size = int(self.bits // 8) + owner.stack -= size + owner.stack &= -size + self.addr = owner.stack + self.name = name + + def __get__(self, instance, owner): + if instance is None: + return self + else: + return Memory(instance, self.bits_to_opcode[self.bits], + instance.r10 + self.addr, self.signed) + + def __set__(self, instance, value): + bits = self.bits_to_opcode[self.bits] + if isinstance(value, int): + instance.append(Opcode.ST + bits, 10, 0, self.addr, value) + else: + src, _, _, free = value.calculate(None, self.bits == 64, + self.signed) + instance.append(Opcode.STX + bits, 10, src, self.addr, 0) + if free: + instance.owners.discard(src) + class MemoryDesc: def __init__(self, ebpf, bits): self.ebpf = ebpf @@ -562,6 +595,8 @@ class RegisterDesc: class EBPF: + stack = 0 + def __init__(self, prog_type=0, license="", kern_version=0): self.opcodes = [] self.prog_type = prog_type diff --git a/ebpf_test.py b/ebpf_test.py index e876e979e5909dd1b2a58ff97760579197aec205..5653185dfb4fe23859a3daca3664ec4f409e7022 100644 --- a/ebpf_test.py +++ b/ebpf_test.py @@ -1,7 +1,8 @@ from unittest import TestCase, main from . import ebpf -from .ebpf import AssembleError, EBPF, Opcode, OpcodeFlags, Opcode as O +from .ebpf import ( + AssembleError, EBPF, Opcode, OpcodeFlags, Opcode as O, LocalVar) from .bpf import ProgType @@ -131,6 +132,24 @@ class Tests(TestCase): Instruction(opcode=97, dst=4, src=8, off=7, imm=0), Instruction(opcode=121, dst=5, src=3, off=-7, imm=0)]) + def test_local(self): + class Local(EBPF): + a = LocalVar(8, True) + b = LocalVar(16, False) + c = LocalVar(32, True) + d = LocalVar(64, False) + + e = Local(ProgType.XDP, "GPL") + e.a = 5 + e.b = e.c >> 3 + e.d = e.r1 + + self.assertEqual(e.opcodes, [ + Instruction(opcode=O.B+O.ST, dst=10, src=0, off=-1, imm=5), + Instruction(opcode=O.W+O.LD, dst=0, src=10, off=-8, imm=0), + Instruction(opcode=O.ARSH, dst=0, src=0, off=0, imm=3), + Instruction(opcode=O.REG+O.STX, dst=10, src=0, off=-4, imm=0), + Instruction(opcode=O.DW+O.STX, dst=10, src=1, off=-16, imm=0)]) def test_jump(self): e = EBPF()