diff --git a/ebpfcat/arraymap.py b/ebpfcat/arraymap.py index b20934c2de3e6f797ebb18ab6f03e00a88e73850..05fadec4644753e0306480a241fc69a3a12b7334 100644 --- a/ebpfcat/arraymap.py +++ b/ebpfcat/arraymap.py @@ -19,7 +19,7 @@ from itertools import chain from mmap import mmap from struct import pack_into, unpack_from, calcsize -from .ebpf import FuncId, Map, MemoryDesc, Opcode, SubProgram +from .ebpf import Expression, FuncId, Map, MemoryDesc, Opcode, SubProgram from .bpf import create_map, lookup_elem, MapType, MapFlags, update_elem @@ -29,6 +29,7 @@ class ArrayGlobalVarDesc(MemoryDesc): def __init__(self, map, fmt): self.map = map self.fmt = fmt + self.fixed = fmt == "f" def fmt_addr(self, ebpf): return self.fmt, ebpf.__dict__[self.name] @@ -42,7 +43,10 @@ class ArrayGlobalVarDesc(MemoryDesc): if instance.ebpf.loaded: fmt, addr = self.fmt_addr(instance) data = instance.ebpf.__dict__[self.map.name].data - ret = unpack_from(fmt, data, addr) + if fmt == "f": + return unpack_from("q", data, addr)[0] / Expression.FIXED_BASE + else: + ret = unpack_from(fmt, data, addr) if len(ret) == 1: return ret[0] else: @@ -53,6 +57,9 @@ class ArrayGlobalVarDesc(MemoryDesc): def __set__(self, instance, value): if instance.ebpf.loaded: fmt, addr = self.fmt_addr(instance) + if fmt == "f": + fmt = "q" + value = int(value * Expression.FIXED_BASE) if not isinstance(value, tuple): value = value, pack_into(fmt, instance.ebpf.__dict__[self.map.name].data, @@ -80,7 +87,8 @@ class ArrayMap(Map): for prog in chain([ebpf], ebpf.subprograms): for k, v in prog.__class__.__dict__.items(): if isinstance(v, ArrayGlobalVarDesc): - collection.append((calcsize(v.fmt), prog, k)) + collection.append((8 if v.fmt == "f" else calcsize(v.fmt), + prog, k)) collection.sort(key=lambda t: t[0], reverse=True) position = 0 for size, prog, name in collection: diff --git a/ebpfcat/ebpf.py b/ebpfcat/ebpf.py index 7e5457d20c21583e40b709d74f51179087c33618..0750830d9b90e069099c75fdd956a368bc8555df 100644 --- a/ebpfcat/ebpf.py +++ b/ebpfcat/ebpf.py @@ -18,6 +18,7 @@ from abc import ABC, abstractmethod from collections import namedtuple from contextlib import contextmanager, ExitStack +from operator import index from struct import pack, unpack, calcsize from enum import Enum @@ -268,8 +269,18 @@ class AssembleError(Exception): def comparison(uposop, unegop, sposop, snegop): def ret(self, value): - return SimpleComparison(self.ebpf, self, value, - (uposop, unegop, sposop, snegop)) + valuefixed, value = fixedvalue(value) + myself = self + if self.fixed != valuefixed: + if self.fixed: + value = value * self.FIXED_BASE + else: + myself = self * self.FIXED_BASE + + if self.signed or issigned(value): + return SimpleComparison(self.ebpf, myself, value, (sposop, snegop)) + else: + return SimpleComparison(self.ebpf, myself, value, (uposop, unegop)) return ret @@ -356,16 +367,14 @@ class SimpleComparison(Comparison): self.opcode = opcode def compare(self, negative): - with self.left.calculate(None, None, None) as (self.dst, _, lsigned): + with self.left.calculate(None, None) as (self.dst, _): with ExitStack() as exitStack: - if isinstance(self.right, int): - rsigned = (self.right < 0) - else: - self.src, _, rsigned = exitStack.enter_context( - self.right.calculate(None, None, None)) + if isinstance(self.right, Expression): + self.src, _ = exitStack.enter_context( + self.right.calculate(None, None)) self.origin = len(self.ebpf.opcodes) self.ebpf.opcodes.append(None) - self.opcode = self.opcode[negative + 2 * (lsigned or rsigned)] + self.opcode = self.opcode[negative] self.owners = self.ebpf.owners.copy() def target(self, retarget=False): @@ -373,14 +382,14 @@ class SimpleComparison(Comparison): if self.opcode == Opcode.JMP: inst = Instruction(Opcode.JMP, 0, 0, len(self.ebpf.opcodes) - self.origin - 1, 0) - elif isinstance(self.right, int): - inst = Instruction( - self.opcode, self.dst, 0, - len(self.ebpf.opcodes) - self.origin - 1, self.right) - else: + elif isinstance(self.right, Expression): inst = Instruction( self.opcode + Opcode.REG, self.dst, self.src, len(self.ebpf.opcodes) - self.origin - 1, 0) + else: + inst = Instruction( + self.opcode, self.dst, 0, + len(self.ebpf.opcodes) - self.origin - 1, self.right) self.ebpf.opcodes[self.origin] = inst if not retarget: self.ebpf.owners, self.owners = \ @@ -422,46 +431,137 @@ class InvertComparison(Comparison): self.value.target(retarget) -def binary(opcode): - def ret(self, value): - return Binary(self.ebpf, self, value, opcode) - return ret +def issigned(value): + if isinstance(value, Expression): + return value.signed + else: + return value < 0 -def rbinary(opcode): - def ret(self, value): - return ReverseBinary(self.ebpf, value, self, opcode) - return ret + +def fixedvalue(value): + try: + return False, index(value) + except TypeError: + try: + return True, int(float(value) * Expression.FIXED_BASE) + except TypeError: + return value.fixed, value class Expression: """the base class for all numerical expressions""" - __radd__ = __add__ = binary(Opcode.ADD) - __sub__ = binary(Opcode.SUB) - __rsub__ = rbinary(Opcode.SUB) - __rmul__ = __mul__ = binary(Opcode.MUL) - __truediv__ = binary(Opcode.DIV) - __rtruediv__ = rbinary(Opcode.DIV) - __ror__ = __or__ = binary(Opcode.OR) - __lshift__ = binary(Opcode.LSH) - __rlshift__ = rbinary(Opcode.LSH) - __rshift__ = binary(Opcode.RSH) - __rrshift__ = rbinary(Opcode.RSH) - __mod__ = binary(Opcode.MOD) - __rmod__ = rbinary(Opcode.MOD) - __rxor__ = __xor__ = binary(Opcode.XOR) + + FIXED_BASE = 100000 + + def _binary(self, value, opcode): + return Binary(self.ebpf, self, value, opcode, + self.signed or issigned(value), False) + + __ror__ = __or__ = lambda self, value: self._binary(value, Opcode.OR) + __lshift__ = lambda self, value: self._binary(value, Opcode.LSH) + __rxor__ = __xor__ = lambda self, value: self._binary(value, Opcode.XOR) __gt__ = comparison(Opcode.JGT, Opcode.JLE, Opcode.JSGT, Opcode.JSLE) __ge__ = comparison(Opcode.JGE, Opcode.JLT, Opcode.JSGE, Opcode.JSLT) __lt__ = comparison(Opcode.JLT, Opcode.JGE, Opcode.JSLT, Opcode.JSGE) __le__ = comparison(Opcode.JLE, Opcode.JGT, Opcode.JSLE, Opcode.JSGT) + def _sum(self, value, opcode): + valuefixed, value = fixedvalue(value) + myself = self + if self.fixed != valuefixed: + if self.fixed: + value = value * self.FIXED_BASE + else: + myself = self * self.FIXED_BASE + + return Binary(self.ebpf, myself, value, opcode, + self.signed or issigned(value), self.fixed or valuefixed) + + def _rsum(self, value, opcode): + valuefixed, value = fixedvalue(value) + myself = self + if self.fixed != valuefixed: + if self.fixed: + value = value * self.FIXED_BASE + else: + myself = self * self.FIXED_BASE + + return ReverseBinary( + self.ebpf, value, myself, opcode, + self.signed or issigned(value), self.fixed or valuefixed) + + __radd__ = __add__ = lambda self, value: self._sum(value, Opcode.ADD) + __sub__ = lambda self, value: self._sum(value, Opcode.SUB) + __rsub__ = lambda self, value: self._rsum(value, Opcode.SUB) + __mod__ = lambda self, value: self._sum(value, Opcode.MOD) + __rmod__ = lambda self, value: self._rsum(value, Opcode.MOD) + + def __mul__(self, value): + valuefixed, value = fixedvalue(value) + ret = Binary(self.ebpf, self, value, Opcode.MUL, + self.signed or issigned(value), self.fixed or valuefixed) + if self.fixed and valuefixed: + ret = ret / self.FIXED_BASE + return ret + __rmul__ = __mul__ + + def __truediv__(self, value): + valuefixed, value = fixedvalue(value) + myself = self + if not self.fixed and valuefixed: + myself = myself * self.FIXED_BASE ** 2 + elif self.fixed == valuefixed: + myself = myself * self.FIXED_BASE + + return Binary(self.ebpf, myself, value, Opcode.DIV, + self.signed or issigned(value), True) + + def __rtruediv__(self, value): + if self.fixed: + value = int(value * self.FIXED_BASE ** 2) + else: + value = int(value * self.FIXED_BASE) + return ReverseBinary(self.ebpf, value, self, Opcode.DIV, + self.signed or issigned(value), True) + + def __floordiv__(self, value): + valuefixed, value = fixedvalue(value) + myself = self + if not self.fixed and valuefixed: + myself = myself * self.FIXED_BASE + elif self.fixed and not valuefixed: + value = value * self.FIXED_BASE + + return Binary(self.ebpf, myself, value, Opcode.DIV, + self.signed or issigned(value), False) + + def __rfloordiv__(self, value): + if self.fixed: + value = int(value * self.FIXED_BASE) + else: + value = int(value) + return ReverseBinary(self.ebpf, value, self, Opcode.DIV, + self.signed or issigned(value), False) + + def __rshift__(self, value): + opcode = Opcode.ARSH if self.signed else Opcode.RSH + return Binary(self.ebpf, self, value, opcode, self.signed, False) + + def __rrshift__(self, value): + opcode = Opcode.ARSH if value < 0 else Opcode.RSH + return ReverseBinary(self.ebpf, value, self, opcode, value < 0, False) + + def __rlshift__(self, value): + return ReverseBinary(self.ebpf, value, self, Opcode.LSH, + value < 0, False) + def __and__(self, value): return AndExpression(self.ebpf, self, value) def __ne__(self, value): - return SimpleComparison( - self.ebpf, self, value, - (Opcode.JNE, Opcode.JEQ, Opcode.JNE, Opcode.JEQ)) + return SimpleComparison(self.ebpf, self, value, + (Opcode.JNE, Opcode.JEQ)) def __eq__(self, value): return ~(self != value) @@ -486,15 +586,13 @@ class Expression: return self.as_comparison.__exit__(exc_type, exc, tb) @contextmanager - def calculate(self, dst, long, signed, force=False): + def calculate(self, dst, long, force=False): """issue the code that calculates the value of this expression this method returns three values: - the number of the register with the result - a boolean indicating whether this is a 64 bit value - - and a booleand indicating whether the result is to be - considered signed. this method is a contextmanager to be used in a `with` statement. At the end of the `with` block the result is @@ -508,19 +606,17 @@ class Expression: or `None` if that does not matter. :param long: True if the result is supposed to be 64 bit. None if it does not matter. - :param signed: True if the result should be considered signed. - None if it does not matter. :param force: if true, `dst` must be respected, otherwise this is optional. """ with self.ebpf.get_free_register(dst) as dst: - with self.get_address(dst, long, signed) as (src, fmt): + with self.get_address(dst, long) as (src, fmt): self.ebpf.append(Opcode.LD + Memory.fmt_to_opcode[fmt], dst, src, 0, 0) - yield dst, long, self.signed + yield dst, long @contextmanager - def get_address(self, dst, long, signed, force=False): + def get_address(self, dst, long, force=False): """get the address of the value of this expression this method returns the address of the result of this expression, @@ -529,7 +625,7 @@ class Expression: the stack. """ with self.ebpf.get_stack(4 + 4 * long) as stack: - with self.calculate(dst, long, signed) as (src, _, _): + with self.calculate(dst, long) as (src, _): self.ebpf.append(Opcode.STX + Opcode.DW * long, 10, src, stack, 0) self.ebpf.append(Opcode.MOV + Opcode.LONG + Opcode.REG, @@ -544,70 +640,67 @@ class Expression: class Binary(Expression): """represent all binary expressions""" - def __init__(self, ebpf, left, right, operator): + def __init__(self, ebpf, left, right, operator, signed, fixed): self.ebpf = ebpf self.left = left self.right = right self.operator = operator + self.signed = signed + self.fixed = fixed @contextmanager - def calculate(self, dst, long, signed, force=False): + def calculate(self, dst, long, force=False): orig_dst = dst - if not isinstance(self.right, int) and self.right.contains(dst): + if isinstance(self.right, Expression) and self.right.contains(dst): dst = None with self.ebpf.get_free_register(dst) as dst: - with self.left.calculate(dst, long, signed, True) \ - as (dst, l_long, l_signed): + with self.left.calculate(dst, long, True) as (dst, l_long): if long is None: long = l_long - signed = signed or l_signed - if self.operator is Opcode.RSH and signed: # >>= - operator = Opcode.ARSH - else: - operator = self.operator - if isinstance(self.right, int): - r_signed = self.right < 0 - self.ebpf.append(operator + Opcode.LONG * long, - dst, 0, 0, self.right) - else: - with self.right.calculate(None, long, None) as \ - (src, r_long, r_signed): + if isinstance(self.right, Expression): + with self.right.calculate(None, long) as (src, r_long): self.ebpf.append( - operator + Opcode.REG + self.operator + Opcode.REG + Opcode.LONG * ((r_long or l_long) if long is None else long), dst, src, 0, 0) + elif -0x80000000 <= self.right < 0x100000000: + self.ebpf.append(self.operator + Opcode.LONG * long, + dst, 0, 0, self.right) + else: + with self.ebpf.get_free_register(None) as src: + self.ebpf._load_value(src, self.right) + self.ebpf.append( + self.operator + Opcode.REG + Opcode.LONG, + dst, src, 0, 0) if orig_dst is None or orig_dst == dst: - yield dst, long, signed or r_signed + yield dst, long return self.ebpf.append(Opcode.MOV + Opcode.REG + Opcode.LONG * long, orig_dst, dst, 0, 0) - yield orig_dst, long, signed or r_signed + yield orig_dst, long def contains(self, no): - return self.left.contains(no) or (not isinstance(self.right, int) + return self.left.contains(no) or (isinstance(self.right, Expression) and self.right.contains(no)) class ReverseBinary(Expression): - def __init__(self, ebpf, left, right, operator): + def __init__(self, ebpf, left, right, operator, signed, fixed): self.ebpf = ebpf self.left = left self.right = right self.operator = operator + self.signed = signed + self.fixed = fixed @contextmanager - def calculate(self, dst, long, signed, force=False): + def calculate(self, dst, long, force=False): with self.ebpf.get_free_register(dst) as dst: self.ebpf._load_value(dst, self.left) - if self.operator is Opcode.RSH and self.left < 0: # >>= - operator = Opcode.ARSH - else: - operator = self.operator - - with self.right.calculate(None, long, None) as (src, long, _): - self.ebpf.append(operator + Opcode.LONG * long + Opcode.REG, - dst, src, 0, 0) - yield dst, long, signed + with self.right.calculate(None, long) as (src, long): + self.ebpf.append(self.operator + Opcode.LONG * long + + Opcode.REG, dst, src, 0, 0) + yield dst, long def contains(self, no): return self.right.contains(no) @@ -617,13 +710,14 @@ class Negate(Expression): def __init__(self, ebpf, arg): self.ebpf = ebpf self.arg = arg + self.signed = True + self.fixed = arg.fixed @contextmanager - def calculate(self, dst, long, signed, force=False): - with self.arg.calculate(dst, long, signed, force) as \ - (dst, long, signed): + def calculate(self, dst, long, force=False): + with self.arg.calculate(dst, long, force) as (dst, long): self.ebpf.append(Opcode.NEG + Opcode.LONG * long, dst, 0, 0, 0) - yield dst, long, signed + yield dst, long def contains(self, no): return self.arg.contains(no) @@ -633,14 +727,14 @@ class Absolute(Expression): def __init__(self, ebpf, arg): self.ebpf = ebpf self.arg = arg + self.fixed = arg.fixed @contextmanager - def calculate(self, dst, long, signed, force=False): - with self.arg.calculate(dst, long, True, force) as \ - (dst, long, signed): - with self.ebpf.r[dst] < 0: - self.ebpf.r[dst] = -self.ebpf.r[dst] - yield dst, long, True + def calculate(self, dst, long, force=False): + with self.arg.calculate(dst, long, force) as (dst, long): + with self.ebpf.sr[dst] < 0: + self.ebpf.sr[dst] = -self.ebpf.sr[dst] + yield dst, long def contains(self, no): return self.arg.contains(no) @@ -652,27 +746,27 @@ class Sum(Binary): this is used to optimize memory addressing code. """ def __init__(self, ebpf, left, right): - super().__init__(ebpf, left, right, Opcode.ADD) + super().__init__(ebpf, left, right, Opcode.ADD, right < 0, False) def __add__(self, value): - if isinstance(value, int): - return Sum(self.ebpf, self.left, self.right + value) - else: + try: + return Sum(self.ebpf, self.left, self.right + index(value)) + except TypeError: return super().__add__(value) __radd__ = __add__ def __sub__(self, value): - if isinstance(value, int): - return Sum(self.ebpf, self.left, self.right - value) - else: - return super().__sub__(value) + try: + return Sum(self.ebpf, self.left, self.right - index(value)) + except TypeError: + return super().__add__(value) class AndExpression(Binary): # there is a special comparison with & instruction def __init__(self, ebpf, left, right): - super().__init__(ebpf, left, right, Opcode.AND) + super().__init__(ebpf, left, right, Opcode.AND, False, False) def __ne__(self, value): if isinstance(value, int) and value == 0: @@ -684,7 +778,7 @@ class AndComparison(SimpleComparison): # there is a special comparison with & instruction # it is the only one which has not inversion def __init__(self, ebpf, left, right): - Binary.__init__(self, ebpf, left, right, Opcode.AND) + Binary.__init__(self, ebpf, left, right, Opcode.AND, False, False) SimpleComparison.__init__(self, ebpf, left, right, Opcode.JSET) self.opcode = (Opcode.JSET, None, Opcode.JSET, None) self.invert = None @@ -726,36 +820,41 @@ class Register(Expression): """represent one EBPF register""" offset = 0 - def __init__(self, no, ebpf, long, signed): + def __init__(self, no, ebpf, long, signed, fixed=False): self.no = no self.ebpf = ebpf self.long = long self.signed = signed + self.fixed = fixed def __add__(self, value): - if isinstance(value, int) and self.long: - return Sum(self.ebpf, self, value) - else: - return super().__add__(value) + if self.long and not self.fixed: + try: + return Sum(self.ebpf, self, index(value)) + except TypeError: + pass + return super().__add__(value) __radd__ = __add__ def __sub__(self, value): - if isinstance(value, int) and self.long: - return Sum(self.ebpf, self, -value) - else: - return super().__sub__(value) + if self.long and not self.fixed: + try: + return Sum(self.ebpf, self, -index(value)) + except TypeError: + pass + return super().__sub__(value) @contextmanager - def calculate(self, dst, long, signed, force=False): + def calculate(self, dst, long, force=False): if self.no not in self.ebpf.owners: raise AssembleError("register has no value") if dst != self.no and force: self.ebpf.append(Opcode.MOV + Opcode.REG + Opcode.LONG * self.long, dst, self.no, 0, 0) - yield dst, self.long, self.signed + yield dst, self.long else: - yield self.no, self.long, self.signed + yield self.no, self.long def contains(self, no): return self.no == no @@ -771,7 +870,7 @@ class Memory(Expression): bits_to_opcode = {32: Opcode.W, 16: Opcode.H, 8: Opcode.B, 64: Opcode.DW} fmt_to_opcode = {'I': Opcode.W, 'H': Opcode.H, 'B': Opcode.B, 'Q': Opcode.DW, 'i': Opcode.W, 'h': Opcode.H, 'b': Opcode.B, 'q': Opcode.DW, - 'A': Opcode.W} + 'A': Opcode.W, 'f': Opcode.DW} def __init__(self, ebpf, fmt, address): self.ebpf = ebpf @@ -791,7 +890,7 @@ class Memory(Expression): return NotImplemented @contextmanager - def calculate(self, dst, long, signed, force=False): + def calculate(self, dst, long, force=False): with ExitStack() as exitStack: if isinstance(self.address, Sum): dst = exitStack.enter_context(self.ebpf.get_free_register(dst)) @@ -799,19 +898,19 @@ class Memory(Expression): Opcode.LD + self.fmt_to_opcode.get(self.fmt, Opcode.B), dst, self.address.left.no, self.address.right, 0) else: - dst, _, _ = exitStack.enter_context( - super().calculate(dst, long, signed, force)) + dst, _ = exitStack.enter_context( + super().calculate(dst, long, force)) if isinstance(self.fmt, tuple): self.ebpf.r[dst] &= ((1 << self.fmt[1]) - 1) << self.fmt[0] if self.fmt[0] > 0: self.ebpf.r[dst] >>= self.fmt[0] - yield dst, "B", False + yield dst, "B" else: - yield dst, self.fmt in "QqA", self.fmt.islower() + yield dst, self.fmt in "QqA" @contextmanager - def get_address(self, dst, long, signed, force=False): - with self.address.calculate(dst, True, None) as (src, _, _): + def get_address(self, dst, long, force=False): + with self.address.calculate(dst, True) as (src, _): yield src, self.fmt def contains(self, no): @@ -821,6 +920,10 @@ class Memory(Expression): def signed(self): return isinstance(self.fmt, str) and self.fmt.islower() + @property + def fixed(self): + return isinstance(self.fmt, str) and self.fmt == "f" + def __invert__(self): if not isinstance(self.fmt, tuple) or self.fmt[1] != 1: return NotImplemented @@ -870,24 +973,31 @@ class MemoryDesc: mask = ((1 << fmt[1]) - 1) << fmt[0] value = (mask & (value << self.fmt[0]) | ~mask & before) opcode = Opcode.STX - elif isinstance(value, int): - ebpf.append(Opcode.ST + bits, self.base_register, 0, - addr, value) - return elif isinstance(value, IAdd): value = value.value - if isinstance(value, int): + if not isinstance(value, Expression): + if self.fixed: + value = int(value * self.FIXED_BASE) with ebpf.get_free_register(None) as src: ebpf.r[src] = value ebpf.append(Opcode.XADD + bits, self.base_register, src, addr, 0) return opcode = Opcode.XADD - else: + elif isinstance(value, Expression): opcode = Opcode.STX - with value.calculate(None, isinstance(fmt, str) and fmt in 'qQ', - isinstance(fmt, str) and fmt.islower() - ) as (src, _, _): + else: + if self.fixed: + value = int(value * Expression.FIXED_BASE) + ebpf.append(Opcode.ST + bits, self.base_register, 0, + addr, value) + return + if self.fmt == "f" and not value.fixed: + value = value * Expression.FIXED_BASE + elif self.fmt != "f" and value.fixed: + value = value / Expression.FIXED_BASE + with value.calculate(None, isinstance(fmt, str) and fmt in 'qQf' + ) as (src, _): ebpf.append(opcode + bits, self.base_register, src, addr, 0) @@ -897,6 +1007,7 @@ class LocalVar(MemoryDesc): def __init__(self, fmt='I'): self.fmt = fmt + self.fixed = fmt == "f" def __set_name__(self, owner, name): if isinstance(self.fmt, str): @@ -926,16 +1037,13 @@ class MemoryMap: dst = addr.left.no offset = addr.right else: - dst, _, _ = exitStack.enter_context( - addr.calculate(None, True, None)) + dst, _ = exitStack.enter_context(addr.calculate(None, True)) offset = 0 - if isinstance(value, int): - self.ebpf.append(Opcode.ST + Memory.fmt_to_opcode[self.fmt], - dst, 0, offset, value) - return - elif isinstance(value, IAdd): + if isinstance(value, IAdd): value = value.value - if isinstance(value, int): + if self.fmt == "f": + value = int(value * self.FIXED_BASE) + if not isinstance(value, Expression): with self.ebpf.get_free_register(None) as src: self.ebpf.r[src] = value self.ebpf.append( @@ -943,9 +1051,15 @@ class MemoryMap: dst, src, offset, 0) return opcode = Opcode.XADD - else: + elif isinstance(value, Expression): opcode = Opcode.STX - with value.calculate(None, None, None) as (src, _, _): + else: + if self.fmt == "f": + value = int(value * self.FIXED_BASE) + self.ebpf.append(Opcode.ST + Memory.fmt_to_opcode[self.fmt], + dst, 0, offset, value) + return + with value.calculate(None, None) as (src, _): self.ebpf.append(opcode + Memory.fmt_to_opcode[self.fmt], dst, src, offset, 0) @@ -970,28 +1084,30 @@ class PseudoFd(Expression): def __init__(self, ebpf, fd): self.ebpf = ebpf self.fd = fd + self.fixed = False @contextmanager - def calculate(self, dst, long, signed, force=False): + def calculate(self, dst, long, force=False): with self.ebpf.get_free_register(dst) as dst: self.ebpf.append(Opcode.DW, dst, 1, 0, self.fd) self.ebpf.append(Opcode.W, 0, 0, 0, 0) - yield dst, long, signed + yield dst, long class ktime(Expression): """a function that returns the current ktime in ns""" def __init__(self, ebpf): self.ebpf = ebpf + self.fixed = False @contextmanager - def calculate(self, dst, long, signed, force=False): + def calculate(self, dst, long, force=False): with self.ebpf.get_free_register(dst) as dst: with self.ebpf.save_registers([i for i in range(6) if i != dst]): self.ebpf.call(FuncId.ktime_get_ns) if dst != 0: self.ebpf.r[dst] = self.ebpf.r0 - yield dst, True, False + yield dst, True class prandom(Expression): @@ -1000,13 +1116,13 @@ class prandom(Expression): self.ebpf = ebpf @contextmanager - def calculate(self, dst, long, signed, force=False): + def calculate(self, dst, long, force=False): with self.ebpf.get_free_register(dst) as dst: with self.ebpf.save_registers([i for i in range(6) if i != dst]): self.ebpf.call(FuncId.get_prandom_u32) if dst != 0: self.ebpf.r[dst] = self.ebpf.r0 - yield dst, True, False + yield dst, True class RegisterDesc: @@ -1025,29 +1141,34 @@ class RegisterDesc: class RegisterArray: - def __init__(self, ebpf, long, signed): + def __init__(self, ebpf, long, signed, fixed=False): self.ebpf = ebpf self.long = long self.signed = signed + self.fixed = fixed def __setitem__(self, no, value): self.ebpf.owners.add(no) - if isinstance(value, int): - self.ebpf._load_value(no, value) - elif isinstance(value, Expression): - with value.calculate(no, self.long, self.signed, True): + if isinstance(value, Expression): + if self.fixed and not value.fixed: + value = value * Expression.FIXED_BASE + if not self.fixed and value.fixed: + value = value / Expression.FIXED_BASE + with value.calculate(no, self.long, True): pass else: - raise AssembleError("cannot compile") + if self.fixed: + value = int(value * Expression.FIXED_BASE) + self.ebpf._load_value(no, value) def __getitem__(self, no): - return Register(no, self.ebpf, self.long, self.signed) + return Register(no, self.ebpf, self.long, self.signed, self.fixed) class Temporary(Register): - def __init__(self, ebpf, long, signed): - super().__init__(None, ebpf, long, signed) + def __init__(self, ebpf, long, signed, fixed): + super().__init__(None, ebpf, long, signed, fixed) self.nos = [] self.gfrs = [] @@ -1074,7 +1195,7 @@ class TemporaryDesc(RegisterDesc): ret = instance.__dict__.get(self.name, None) if ret is None: ret = instance.__dict__[self.name] = \ - Temporary(instance, arr.long, arr.signed) + Temporary(instance, arr.long, arr.signed, arr.fixed) return ret def __set__(self, instance, value): @@ -1111,11 +1232,17 @@ class EBPF: self.mI = MemoryMap(self, "I") self.mA = MemoryMap(self, "A") # actually I, but treat as Q self.mQ = MemoryMap(self, "Q") + self.mb = MemoryMap(self, "b") + self.mh = MemoryMap(self, "h") + self.mi = MemoryMap(self, "i") + self.mq = MemoryMap(self, "q") + self.mf = MemoryMap(self, "f") self.r = RegisterArray(self, True, False) self.sr = RegisterArray(self, True, True) self.w = RegisterArray(self, False, False) self.sw = RegisterArray(self, False, True) + self.f = RegisterArray(self, True, True, True) self.owners = {1, 10} @@ -1202,7 +1329,7 @@ class EBPF: raise AssembleError("not enough registers") def _load_value(self, no, value): - if -0x80000000 <= value < 0x80000000: + if -0x80000000 <= value < 0x100000000: self.append(Opcode.MOV + Opcode.LONG, no, 0, 0, value) else: self.append(Opcode.DW, no, 0, 0, value & 0xffffffff) @@ -1240,6 +1367,7 @@ class EBPF: stmp = TemporaryDesc(None, "sr") wtmp = TemporaryDesc(None, "w") swtmp = TemporaryDesc(None, "sw") + ftmp = TemporaryDesc(None, "f") for i in range(11): @@ -1254,6 +1382,9 @@ for i in range(10): for i in range(10): setattr(EBPF, f"sw{i}", RegisterDesc(i, "sw")) +for i in range(10): + setattr(EBPF, f"f{i}", RegisterDesc(i, "f")) + class SubProgram: stack = 0 diff --git a/ebpfcat/ebpf_test.py b/ebpfcat/ebpf_test.py index 953eff8cfa822db87940e45e88d41f1b4abbb582..a2b949675dbc52db617da430909f0b90e17bd6a0 100644 --- a/ebpfcat/ebpf_test.py +++ b/ebpfcat/ebpf_test.py @@ -85,8 +85,8 @@ class Tests(TestCase): e.r4 -= e.r7 e.r4 *= 3 e.r4 *= e.r7 - e.r4 /= 3 - e.r4 /= e.r7 + e.r4 //= 3 + e.r4 //= e.r7 e.r4 |= 3 e.r4 |= e.r7 e.r4 &= 3 @@ -141,6 +141,8 @@ class Tests(TestCase): e.r3 = e.mH[e.r3 + 2] e.r4 = e.mI[7 + e.r8] e.r5 = e.mQ[e.r3 - 7] + e.r5 = e.mb[e.r3] >> 2 + e.r5 = e.mB[e.r3] >> 2 self.assertEqual(e.opcodes, [Instruction(opcode=114, dst=5, src=0, off=0, imm=7), Instruction(opcode=106, dst=3, src=0, off=2, imm=3), @@ -153,7 +155,122 @@ class Tests(TestCase): Instruction(opcode=113, dst=2, src=5, off=0, imm=0), Instruction(opcode=105, dst=3, src=3, off=2, imm=0), Instruction(opcode=97, dst=4, src=8, off=7, imm=0), - Instruction(opcode=121, dst=5, src=3, off=-7, imm=0)]) + Instruction(opcode=121, dst=5, src=3, off=-7, imm=0), + Instruction(opcode=O.B+O.LD, dst=5, src=3, off=0, imm=0), + Instruction(opcode=O.LONG+O.ARSH, dst=5, src=0, off=0, imm=2), + Instruction(opcode=O.B+O.LD, dst=5, src=3, off=0, imm=0), + Instruction(opcode=O.LONG+O.RSH, dst=5, src=0, off=0, imm=2), + ]) + + def test_fixed(self): + e = EBPF() + e.owners = {0, 1, 2, 3, 4, 5, 6} + e.f1 = e.r2 + 3 + e.f3 = e.r4 + 3.5 + e.f5 = e.f6 + 3 + e.r1 = e.r2 + e.f3 + e.f4 = e.f5 + e.f6 + e.r1 = 2 - e.f2 + e.r3 = 3.4 - e.r4 + e.r5 = e.f6 % 4 + + e.f1 = e.r2 * 3 + e.f3 = e.r4 * 3.5 + e.f5 = e.f6 * 3 + e.r1 = e.r2 * e.f3 + e.f4 = e.f5 * e.f6 + + e.f1 = e.r2 / 3 + e.f3 = e.r4 / 3.5 + e.f5 = e.f6 / 3 + e.r1 = e.r2 / e.f3 + e.f4 = e.f5 / e.f6 + + e.f1 = e.r2 // 3 + e.f3 = e.r4 // 3.5 + e.f5 = e.f6 // 3 + e.r1 = e.r2 // e.f3 + e.f4 = e.f5 // e.f6 + + self.maxDiff = None + self.assertEqual(e.opcodes, [ + Instruction(opcode=O.REG+O.MOV+O.LONG, dst=1, src=2, off=0, imm=0), + Instruction(opcode=O.ADD+O.LONG, dst=1, src=0, off=0, imm=3), + Instruction(opcode=O.MUL+O.LONG, dst=1, src=0, off=0, imm=100000), + Instruction(opcode=O.REG+O.MOV+O.LONG, dst=3, src=4, off=0, imm=0), + Instruction(opcode=O.MUL+O.LONG, dst=3, src=0, off=0, imm=100000), + Instruction(opcode=O.ADD+O.LONG, dst=3, src=0, off=0, imm=350000), + Instruction(opcode=O.REG+O.MOV+O.LONG, dst=5, src=6, off=0, imm=0), + Instruction(opcode=O.ADD+O.LONG, dst=5, src=0, off=0, imm=300000), + Instruction(opcode=O.REG+O.MOV+O.LONG, dst=1, src=2, off=0, imm=0), + Instruction(opcode=O.LONG+O.MUL, dst=1, src=0, off=0, imm=100000), + Instruction(opcode=O.REG+O.ADD+O.LONG, dst=1, src=3, off=0, imm=0), + Instruction(opcode=O.DIV+O.LONG, dst=1, src=0, off=0, imm=100000), + Instruction(opcode=O.REG+O.MOV+O.LONG, dst=4, src=5, off=0, imm=0), + Instruction(opcode=O.REG+O.ADD+O.LONG, dst=4, src=6, off=0, imm=0), + Instruction(opcode=O.MOV+O.LONG, dst=1, src=0, off=0, imm=200000), + Instruction(opcode=O.REG+O.SUB+O.LONG, dst=1, src=2, off=0, imm=0), + Instruction(opcode=O.LONG+O.DIV, dst=1, src=0, off=0, imm=100000), + Instruction(opcode=O.MOV+O.LONG, dst=3, src=0, off=0, imm=340000), + Instruction(opcode=O.REG+O.MOV+O.LONG, dst=7, src=4, off=0, imm=0), + Instruction(opcode=O.LONG+O.MUL, dst=7, src=0, off=0, imm=100000), + Instruction(opcode=O.REG+O.SUB+O.LONG, dst=3, src=7, off=0, imm=0), + Instruction(opcode=O.DIV+O.LONG, dst=3, src=0, off=0, imm=100000), + Instruction(opcode=O.REG+O.LONG+O.MOV, dst=5, src=6, off=0, imm=0), + Instruction(opcode=O.LONG+O.MOD, dst=5, src=0, off=0, imm=400000), + Instruction(opcode=O.LONG+O.DIV, dst=5, src=0, off=0, imm=100000), + + Instruction(opcode=O.REG+O.MOV+O.LONG, dst=1, src=2, off=0, imm=0), + Instruction(opcode=O.LONG+O.MUL, dst=1, src=0, off=0, imm=3), + Instruction(opcode=O.LONG+O.MUL, dst=1, src=0, off=0, imm=100000), + Instruction(opcode=O.REG+O.MOV+O.LONG, dst=3, src=4, off=0, imm=0), + Instruction(opcode=O.LONG+O.MUL, dst=3, src=0, off=0, imm=350000), + Instruction(opcode=O.REG+O.MOV+O.LONG, dst=5, src=6, off=0, imm=0), + Instruction(opcode=O.LONG+O.MUL, dst=5, src=0, off=0, imm=3), + Instruction(opcode=O.REG+O.MOV+O.LONG, dst=1, src=2, off=0, imm=0), + Instruction(opcode=O.REG+O.LONG+O.MUL, dst=1, src=3, off=0, imm=0), + Instruction(opcode=O.DIV+O.LONG, dst=1, src=0, off=0, imm=100000), + Instruction(opcode=O.REG+O.MOV+O.LONG, dst=4, src=5, off=0, imm=0), + Instruction(opcode=O.REG+O.LONG+O.MUL, dst=4, src=6, off=0, imm=0), + Instruction(opcode=O.DIV+O.LONG, dst=4, src=0, off=0, imm=100000), + + Instruction(opcode=O.LONG+O.REG+O.MOV, dst=1, src=2, off=0, imm=0), + Instruction(opcode=O.MUL+O.LONG, dst=1, src=0, off=0, imm=100000), + Instruction(opcode=O.DIV+O.LONG, dst=1, src=0, off=0, imm=3), + Instruction(opcode=O.LONG+O.REG+O.MOV, dst=3, src=4, off=0, imm=0), + Instruction(opcode=O.DW, dst=7, src=0, off=0, imm=1410065408), + Instruction(opcode=O.W, dst=0, src=0, off=0, imm=2), + Instruction(opcode=O.MUL+O.REG+O.LONG, dst=3, src=7, off=0, imm=0), + Instruction(opcode=O.DIV+O.LONG, dst=3, src=0, off=0, imm=350000), + Instruction(opcode=O.LONG+O.REG+O.MOV, dst=5, src=6, off=0, imm=0), + Instruction(opcode=O.DIV+O.LONG, dst=5, src=0, off=0, imm=3), + Instruction(opcode=O.LONG+O.REG+O.MOV, dst=1, src=2, off=0, imm=0), + Instruction(opcode=O.DW, dst=7, src=0, off=0, imm=1410065408), + Instruction(opcode=O.W, dst=0, src=0, off=0, imm=2), + Instruction(opcode=O.REG+O.LONG+O.MUL, dst=1, src=7, off=0, imm=0), + Instruction(opcode=O.DIV+O.LONG+O.REG, dst=1, src=3, off=0, imm=0), + Instruction(opcode=O.DIV+O.LONG, dst=1, src=0, off=0, imm=100000), + Instruction(opcode=O.LONG+O.REG+O.MOV, dst=4, src=5, off=0, imm=0), + Instruction(opcode=O.MUL+O.LONG, dst=4, src=0, off=0, imm=100000), + Instruction(opcode=O.DIV+O.LONG+O.REG, dst=4, src=6, off=0, imm=0), + + Instruction(opcode=O.LONG+O.REG+O.MOV, dst=1, src=2, off=0, imm=0), + Instruction(opcode=O.DIV+O.LONG, dst=1, src=0, off=0, imm=3), + Instruction(opcode=O.MUL+O.LONG, dst=1, src=0, off=0, imm=100000), + Instruction(opcode=O.LONG+O.REG+O.MOV, dst=3, src=4, off=0, imm=0), + Instruction(opcode=O.MUL+O.LONG, dst=3, src=0, off=0, imm=100000), + Instruction(opcode=O.DIV+O.LONG, dst=3, src=0, off=0, imm=350000), + Instruction(opcode=O.MUL+O.LONG, dst=3, src=0, off=0, imm=100000), + Instruction(opcode=O.LONG+O.REG+O.MOV, dst=5, src=6, off=0, imm=0), + Instruction(opcode=O.DIV+O.LONG, dst=5, src=0, off=0, imm=300000), + Instruction(opcode=O.MUL+O.LONG, dst=5, src=0, off=0, imm=100000), + Instruction(opcode=O.LONG+O.REG+O.MOV, dst=1, src=2, off=0, imm=0), + Instruction(opcode=O.MUL+O.LONG, dst=1, src=0, off=0, imm=100000), + Instruction(opcode=O.DIV+O.LONG+O.REG, dst=1, src=3, off=0, imm=0), + Instruction(opcode=O.LONG+O.REG+O.MOV, dst=4, src=5, off=0, imm=0), + Instruction(opcode=O.DIV+O.LONG+O.REG, dst=4, src=6, off=0, imm=0), + Instruction(opcode=O.MUL+O.LONG, dst=4, src=0, off=0, imm=100000), + ]) def test_local(self): class Local(EBPF): @@ -161,18 +278,26 @@ class Tests(TestCase): b = LocalVar('H') c = LocalVar('i') d = LocalVar('Q') + lf = LocalVar('f') e = Local(ProgType.XDP, "GPL") e.a = 5 e.b = e.c >> 3 e.d = e.r1 + e.lf = 7 + e.b = e.f1 self.assertEqual(e.opcodes, [ Instruction(opcode=O.B+O.ST, dst=10, src=0, off=-1, imm=5), Instruction(opcode=O.W+O.LD, dst=0, src=10, off=-8, imm=0), Instruction(opcode=O.ARSH, dst=0, src=0, off=0, imm=3), Instruction(opcode=O.REG+O.STX, dst=10, src=0, off=-4, imm=0), - Instruction(opcode=O.DW+O.STX, dst=10, src=1, off=-16, imm=0)]) + Instruction(opcode=O.DW+O.STX, dst=10, src=1, off=-16, imm=0), + Instruction(opcode=O.DW+O.ST, dst=10, src=0, off=-20, imm=700000), + Instruction(opcode=O.LONG+O.REG+O.MOV, dst=0, src=1, off=0, imm=0), + Instruction(opcode=O.DIV, dst=0, src=0, off=0, imm=100000), + Instruction(opcode=O.REG+O.STX, dst=10, src=0, off=-4, imm=0), + ]) def test_local_bits(self): class Local(EBPF): @@ -368,7 +493,7 @@ class Tests(TestCase): def test_with(self): e = EBPF() - e.owners = set(range(11)) + e.owners = set(range(9)) with e.r2 > 3 as Else: e.r2 = 5 with Else: @@ -379,8 +504,16 @@ class Tests(TestCase): e.r5 = 7 with Else: e.r7 = 8 - self.assertEqual(e.opcodes, - [Instruction(opcode=0xb5, dst=2, src=0, off=2, imm=3), + with e.f4 > 3: + pass + with 3 > e.f4: + pass + with e.r4 > 3.5: + pass + with e.f4 > e.f2: + pass + self.assertEqual(e.opcodes, [ + Instruction(opcode=0xb5, dst=2, src=0, off=2, imm=3), Instruction(opcode=0xb7, dst=2, src=0, off=0, imm=5), Instruction(opcode=0x5, dst=0, src=0, off=1, imm=0), Instruction(opcode=O.MOV+O.LONG, dst=6, src=0, off=0, imm=7), @@ -389,7 +522,14 @@ class Tests(TestCase): Instruction(opcode=O.JLE, dst=4, src=0, off=2, imm=3), Instruction(opcode=O.MOV+O.LONG, dst=5, src=0, off=0, imm=7), Instruction(opcode=O.JMP, dst=0, src=0, off=1, imm=0), - Instruction(opcode=O.MOV+O.LONG, dst=7, src=0, off=0, imm=8)]) + Instruction(opcode=O.MOV+O.LONG, dst=7, src=0, off=0, imm=8), + Instruction(opcode=O.JSLE, dst=4, src=0, off=0, imm=300000), + Instruction(opcode=O.JSGE, dst=4, src=0, off=0, imm=300000), + Instruction(opcode=O.REG+O.MOV+O.LONG, dst=9, src=4, off=0, imm=0), + Instruction(opcode=O.MUL+O.LONG, dst=9, src=0, off=0, imm=100000), + Instruction(opcode=O.JLE, dst=9, src=0, off=0, imm=350000), + Instruction(opcode=O.REG+O.JSLE, dst=4, src=2, off=0, imm=0), + ]) def test_with_inversion(self): e = EBPF() @@ -440,7 +580,6 @@ class Tests(TestCase): with Else: e.r3 = 7 e.r4 = 3 - self.maxDiff = None self.assertEqual(e.opcodes, [ Instruction(opcode=O.JGT, dst=2, src=0, off=1, imm=3), Instruction(opcode=O.JLE, dst=3, src=0, off=1, imm=2), @@ -481,11 +620,20 @@ class Tests(TestCase): e = EBPF() e.r3 = 0x1234567890 e.r4 = e.get_fd(7) + e.r3 = e.r4 + 0x1234567890 + e.r3 = 0x90000000 + self.assertEqual(e.opcodes, [ Instruction(opcode=24, dst=3, src=0, off=0, imm=878082192), Instruction(opcode=0, dst=0, src=0, off=0, imm=18), Instruction(opcode=24, dst=4, src=1, off=0, imm=7), - Instruction(opcode=0, dst=0, src=0, off=0, imm=0)]) + Instruction(opcode=0, dst=0, src=0, off=0, imm=0), + Instruction(opcode=O.REG+O.LONG+O.MOV, dst=3, src=4, off=0, imm=0), + Instruction(opcode=O.DW, dst=0, src=0, off=0, imm=878082192), + Instruction(opcode=O.W, dst=0, src=0, off=0, imm=18), + Instruction(opcode=O.LONG+O.REG+O.ADD, dst=3, src=0, off=0, imm=0), + Instruction(opcode=O.LONG+O.MOV, dst=3, src=0, off=0, imm=2415919104), + ]) def test_simple_binary(self): e = EBPF() @@ -555,7 +703,7 @@ class Tests(TestCase): def test_reverse_binary(self): e = EBPF() e.owners = {0, 1, 2, 3} - e.r3 = 7 / (e.r2 + 2) + e.r3 = 7 // (e.r2 + 2) e.r3 = 7 << e.r2 e.r3 = 7 % (e.r2 + 3) e.r3 = 7 >> e.r2 @@ -587,10 +735,15 @@ class Tests(TestCase): def test_absolute(self): e = EBPF() e.r7 = abs(e.r1) + e.f3 = abs(e.f1) self.assertEqual(e.opcodes, [ Instruction(opcode=O.LONG+O.REG+O.MOV, dst=7, src=1, off=0, imm=0), - Instruction(opcode=O.JGE, dst=7, src=0, off=1, imm=0), - Instruction(opcode=O.LONG+O.NEG, dst=7, src=0, off=0, imm=0)]) + Instruction(opcode=O.JSGE, dst=7, src=0, off=1, imm=0), + Instruction(opcode=O.LONG+O.NEG, dst=7, src=0, off=0, imm=0), + Instruction(opcode=O.REG+O.MOV+O.LONG, dst=3, src=1, off=0, imm=0), + Instruction(opcode=O.JSGE, dst=3, src=0, off=1, imm=0), + Instruction(opcode=O.NEG+O.LONG, dst=3, src=0, off=0, imm=0), + ]) def test_jump_data(self): e = EBPF() @@ -642,7 +795,6 @@ class Tests(TestCase): e.r8 = e.r1 def test_binary_alloc(self): - self.maxDiff = None e = EBPF() e.r3 = e.r1 - (2 * e.r10) e.mH[e.r10 - 10] = 2 * e.r3 @@ -698,6 +850,10 @@ class Tests(TestCase): e.r7 = e.tmp e.tmp = 2 e.r3 = e.tmp + with e.ftmp: + e.ftmp = 3 + e.r3 = e.ftmp + e.ftmp = e.r3 * 3.5 self.assertEqual(e.opcodes, [ Instruction(opcode=O.MOV+O.LONG, dst=0, src=0, off=0, imm=7), Instruction(opcode=O.MOV+O.LONG, dst=2, src=0, off=0, imm=3), @@ -705,7 +861,12 @@ class Tests(TestCase): Instruction(opcode=O.MOV+O.LONG, dst=4, src=0, off=0, imm=5), Instruction(opcode=O.MOV+O.LONG+O.REG, dst=7, src=4, off=0, imm=0), Instruction(opcode=O.MOV+O.LONG, dst=2, src=0, off=0, imm=2), - Instruction(opcode=O.MOV+O.LONG+O.REG, dst=3, src=2, off=0, imm=0) + Instruction(opcode=O.MOV+O.LONG+O.REG, dst=3, src=2, off=0, imm=0), + Instruction(opcode=O.MOV+O.LONG, dst=2, src=0, off=0, imm=300000), + Instruction(opcode=O.MOV+O.REG+O.LONG, dst=3, src=2, off=0, imm=0), + Instruction(opcode=O.DIV+O.LONG, dst=3, src=0, off=0, imm=100000), + Instruction(opcode=O.MOV+O.REG+O.LONG, dst=2, src=3, off=0, imm=0), + Instruction(opcode=O.LONG+O.MUL, dst=2, src=0, off=0, imm=350000), ]) def test_ktime(self): @@ -761,15 +922,18 @@ class KernelTests(TestCase): class Global(EBPF): map = ArrayMap() ar = map.globalVar() - aw = map.globalVar() + aw = map.globalVar("h") class Sub(SubProgram): br = Global.map.globalVar() - bw = Global.map.globalVar() + bw = Global.map.globalVar("h") + bf = Global.map.globalVar("f") def program(self): + self.bw = 4 self.br -= -33 self.bw = self.br + 3 + self.bf = self.br / 3.5 + self.bf s1 = Sub() s2 = Sub() @@ -787,9 +951,11 @@ class KernelTests(TestCase): self.assertEqual(e.aw, 11) self.assertEqual(s1.br, 33) self.assertEqual(s1.bw, 36) + self.assertEqual(s2.bf, 9.42857) s1.br = 3 s2.br *= 5 e.ar = 1111 + s2.bf = 1.3 self.assertEqual(e.ar, 1111) self.assertEqual(e.aw, 11) self.assertEqual(s1.br, 3) @@ -803,6 +969,7 @@ class KernelTests(TestCase): self.assertEqual(s1.bw, 39) self.assertEqual(s2.br, 198) self.assertEqual(s2.bw, 201) + self.assertEqual(s2.bf, 57.87142) def test_minimal(self): class Local(EBPF): diff --git a/ebpfcat/hashmap.py b/ebpfcat/hashmap.py index 76f851872bef21b8a386461ebfe62941d01c200f..8e5e61c86ab792634dff95218c3a11c8ef34e3b6 100644 --- a/ebpfcat/hashmap.py +++ b/ebpfcat/hashmap.py @@ -28,11 +28,10 @@ class HashGlobalVar(Expression): self.count = count self.fmt = fmt self.signed = fmt.islower() + self.fixed = fmt == "f" @contextmanager - def get_address(self, dst, long, signed, force=False): - if signed != self.fmt.islower(): - raise AssembleError("HashMap variable has wrong signedness") + def get_address(self, dst, long, force=False): with self.ebpf.save_registers([i for i in range(6) if i != dst]), \ self.ebpf.get_stack(4) as stack: self.ebpf.append(Opcode.ST, 10, 0, stack, self.count) @@ -78,7 +77,7 @@ class HashGlobalVarDesc: pack("q" if self.fmt.islower() else "Q", value), 0) return with ebpf.save_registers([3]): - with value.get_address(3, True, self.fmt.islower(), True): + with value.get_address(3, True, True): with ebpf.save_registers([0, 1, 2, 4, 5]), \ ebpf.get_stack(4) as stack: ebpf.r1 = ebpf.get_fd(ebpf.__dict__[self.name].fd)