diff --git a/ebpf.py b/ebpf.py index 29d432a81b8094963fa7349da0df507ba646c0d2..caae0c9ff866fef7da786987e6fd4d8826ccd0db 100644 --- a/ebpf.py +++ b/ebpf.py @@ -502,18 +502,12 @@ class MemoryDesc: return Memory(self.ebpf, self.bits, addr) -class GlobalVar(Expression): - def __init__(self, count, hashMap, signed): +class HashGlobalVar(Expression): + def __init__(self, ebpf, count, signed): + self.ebpf = ebpf self.count = count - self.hashMap = hashMap self.signed = signed - def __get__(self, instance, owner): - return self - - def __set_name__(self, owner, name): - self.name = name - @contextmanager def calculate(self, dst, long, signed, force): with self.ebpf.save_registers(dst), self.ebpf.get_stack(4) as stack: @@ -527,17 +521,36 @@ class GlobalVar(Expression): self.ebpf.append(Opcode.LD, dst, 0, 0, 0) yield dst, False, self.signed + + +class HashGlobalVarDesc: + def __init__(self, count, signed): + self.count = count + self.signed = signed + + def __get__(self, instance, owner): + if instance is None: + return self + ret = instance.__dict__.get(self.name, None) + if ret is None: + ret = HashGlobalVar(instance, self.count, self.signed) + instance.__dict__[self.name] = ret + return ret + + def __set_name__(self, owner, name): + self.name = name + def __set__(self, ebpf, value): with ebpf.get_stack(8) as stack: with value.calculate(None, False, self.signed) as (src, _, _): ebpf.append(Opcode.STX, 10, src, stack + 4, 0) ebpf.append(Opcode.ST, 10, 0, stack, self.count) with ebpf.save_registers(None): - ebpf.r1 = self.ebpf.get_fd(self.fd) - ebpf.r2 = self.ebpf.r10 + stack - ebpf.r3 = self.ebpf.r10 + (stack + 4) + ebpf.r1 = ebpf.get_fd(ebpf.__dict__[self.name].fd) + ebpf.r2 = ebpf.r10 + stack + ebpf.r3 = ebpf.r10 + (stack + 4) ebpf.r4 = 3 - self.ebpf.call(2) + ebpf.call(2) class HashMap: @@ -548,7 +561,7 @@ class HashMap: def globalVar(self, signed=False): self.count += 1 - ret = GlobalVar(self.count, self, signed) + ret = HashGlobalVarDesc(self.count, signed) self.vars.append(ret) return ret @@ -558,9 +571,8 @@ class HashMap: def _init(self, ebpf): fd = create_map(1, 1, 4, self.count) for v in self.vars: - var = getattr(ebpf, v.name) - var.ebpf = ebpf - var.fd = fd + getattr(ebpf, v.name).fd = fd + class PseudoFd(Expression): def __init__(self, ebpf, fd): @@ -689,9 +701,9 @@ class EBPF: with ExitStack() as exitStack: for i in range(5): if i in oldowners and i != dst: - tmp = self.exitStack.enter_context( - self.get_free_register(None)) - self.append(Opcode.MOV+Opcode.LONG+Opcode.REG, tmp, i, 0, 0) + tmp = exitStack.enter_context(self.get_free_register(None)) + self.append(Opcode.MOV+Opcode.LONG+Opcode.REG, + tmp, i, 0, 0) save.append((tmp, i)) yield for tmp, i in save: