diff --git a/ebpfcat/ebpf.py b/ebpfcat/ebpf.py index c21bd3e651f777cb0070913e150f7555fe40f8db..879759f6bd15312360a94fb0190ad7af1974956f 100644 --- a/ebpfcat/ebpf.py +++ b/ebpfcat/ebpf.py @@ -23,6 +23,7 @@ from struct import pack, unpack, calcsize from enum import Enum from . import bpf +from .util import sub Instruction = namedtuple("Instruction", ["opcode", "dst", "src", "off", "imm"]) @@ -229,6 +230,8 @@ class Opcode(Enum): ST = 0x62 STX = 0x63 XADD = 0xc3 + LE = 0xd4 + BE = 0xdc def __mul__(self, value): if value: @@ -269,15 +272,15 @@ class AssembleError(Exception): def comparison(uposop, unegop, sposop, snegop): def ret(self, value): - valuefixed, value = fixedvalue(value) + value = ensure_expression(self.ebpf, value) myself = self - if self.fixed != valuefixed: + if self.fixed != value.fixed: if self.fixed: - value = value * self.FIXED_BASE + value *= self.FIXED_BASE else: - myself = self * self.FIXED_BASE + myself *= self.FIXED_BASE - if self.signed or issigned(value): + if self.signed or value.signed: return SimpleComparison(self.ebpf, myself, value, (sposop, snegop)) else: return SimpleComparison(self.ebpf, myself, value, (uposop, unegop)) @@ -369,7 +372,7 @@ class SimpleComparison(Comparison): def compare(self, negative): with self.left.calculate(None, None) as (self.dst, _): with ExitStack() as exitStack: - if isinstance(self.right, Expression): + if not self.right.small_constant: self.src, _ = exitStack.enter_context( self.right.calculate(None, None)) self.origin = len(self.ebpf.opcodes) @@ -382,14 +385,15 @@ class SimpleComparison(Comparison): if self.opcode == Opcode.JMP: inst = Instruction(Opcode.JMP, 0, 0, len(self.ebpf.opcodes) - self.origin - 1, 0) - elif isinstance(self.right, Expression): + elif self.right.small_constant: inst = Instruction( - self.opcode + Opcode.REG, self.dst, self.src, - len(self.ebpf.opcodes) - self.origin - 1, 0) + self.opcode, self.dst, 0, + len(self.ebpf.opcodes) - self.origin - 1, + int(self.right.value)) else: inst = Instruction( - self.opcode, self.dst, 0, - len(self.ebpf.opcodes) - self.origin - 1, self.right) + self.opcode + Opcode.REG, self.dst, self.src, + len(self.ebpf.opcodes) - self.origin - 1, 0) self.ebpf.opcodes[self.origin] = inst if not retarget: self.ebpf.owners, self.owners = \ @@ -431,137 +435,110 @@ class InvertComparison(Comparison): self.value.target(retarget) -def issigned(value): +def ensure_expression(ebpf, value): if isinstance(value, Expression): - return value.signed + return value else: - return value < 0 - - -def fixedvalue(value): - try: - return False, index(value) - except TypeError: - try: - return True, int(float(value) * Expression.FIXED_BASE) - except TypeError: - return value.fixed, value + return Constant(ebpf, value) class Expression: """the base class for all numerical expressions""" FIXED_BASE = 100000 + small_constant = False def _binary(self, value, opcode): + value = ensure_expression(self.ebpf, value) return Binary(self.ebpf, self, value, opcode, - self.signed or issigned(value), False) + self.signed or value.signed, False) __ror__ = __or__ = lambda self, value: self._binary(value, Opcode.OR) __lshift__ = lambda self, value: self._binary(value, Opcode.LSH) + __rlshift__ = lambda self, value: Constant(self.ebpf, value) << self __rxor__ = __xor__ = lambda self, value: self._binary(value, Opcode.XOR) __gt__ = comparison(Opcode.JGT, Opcode.JLE, Opcode.JSGT, Opcode.JSLE) __ge__ = comparison(Opcode.JGE, Opcode.JLT, Opcode.JSGE, Opcode.JSLT) __lt__ = comparison(Opcode.JLT, Opcode.JGE, Opcode.JSLT, Opcode.JSGE) __le__ = comparison(Opcode.JLE, Opcode.JGT, Opcode.JSLE, Opcode.JSGT) + __ne__ = comparison(Opcode.JNE, Opcode.JEQ, Opcode.JNE, Opcode.JEQ) def _sum(self, value, opcode): - valuefixed, value = fixedvalue(value) + value = ensure_expression(self.ebpf, value) myself = self - if self.fixed != valuefixed: + if self.fixed != value.fixed: if self.fixed: - value = value * self.FIXED_BASE + value *= self.FIXED_BASE else: - myself = self * self.FIXED_BASE + myself *= self.FIXED_BASE return Binary(self.ebpf, myself, value, opcode, - self.signed or issigned(value), self.fixed or valuefixed) - - def _rsum(self, value, opcode): - valuefixed, value = fixedvalue(value) - myself = self - if self.fixed != valuefixed: - if self.fixed: - value = value * self.FIXED_BASE - else: - myself = self * self.FIXED_BASE - - return ReverseBinary( - self.ebpf, value, myself, opcode, - self.signed or issigned(value), self.fixed or valuefixed) + self.signed or value.signed, self.fixed or value.fixed) __radd__ = __add__ = lambda self, value: self._sum(value, Opcode.ADD) __sub__ = lambda self, value: self._sum(value, Opcode.SUB) - __rsub__ = lambda self, value: self._rsum(value, Opcode.SUB) + __rsub__ = lambda self, value: Constant(self.ebpf, value) - self __mod__ = lambda self, value: self._sum(value, Opcode.MOD) - __rmod__ = lambda self, value: self._rsum(value, Opcode.MOD) + __rmod__ = lambda self, value: Constant(self.ebpf, value) % self def __mul__(self, value): - valuefixed, value = fixedvalue(value) + value = ensure_expression(self.ebpf, value) ret = Binary(self.ebpf, self, value, Opcode.MUL, - self.signed or issigned(value), self.fixed or valuefixed) - if self.fixed and valuefixed: - ret = ret / self.FIXED_BASE + self.signed or value.signed, self.fixed or value.fixed) + if self.fixed and value.fixed: + ret /= self.FIXED_BASE return ret __rmul__ = __mul__ def __truediv__(self, value): - valuefixed, value = fixedvalue(value) + value = ensure_expression(self.ebpf, value) myself = self - if not self.fixed and valuefixed: - myself = myself * self.FIXED_BASE ** 2 - elif self.fixed == valuefixed: - myself = myself * self.FIXED_BASE + if not self.fixed and value.fixed: + myself *= self.FIXED_BASE ** 2 + elif self.fixed == value.fixed: + myself *= self.FIXED_BASE return Binary(self.ebpf, myself, value, Opcode.DIV, - self.signed or issigned(value), True) + self.signed or value.signed, True) - def __rtruediv__(self, value): - if self.fixed: - value = int(value * self.FIXED_BASE ** 2) - else: - value = int(value * self.FIXED_BASE) - return ReverseBinary(self.ebpf, value, self, Opcode.DIV, - self.signed or issigned(value), True) + def _reverse(self, op, value): + return op(Constant(self.ebpf, value), self) + + __rtruediv__ = lambda self, value: Constant(self.ebpf, value) / self def __floordiv__(self, value): - valuefixed, value = fixedvalue(value) + value = ensure_expression(self.ebpf, value) myself = self - if not self.fixed and valuefixed: - myself = myself * self.FIXED_BASE - elif self.fixed and not valuefixed: - value = value * self.FIXED_BASE + if not self.fixed and value.fixed: + myself *= self.FIXED_BASE + elif self.fixed and not value.fixed: + value *= self.FIXED_BASE return Binary(self.ebpf, myself, value, Opcode.DIV, - self.signed or issigned(value), False) + self.signed or value.signed, False) def __rfloordiv__(self, value): if self.fixed: - value = int(value * self.FIXED_BASE) + value = Constant(self.ebpf, value) + if not value.fixed: + value *= self.FIXED_BASE else: - value = int(value) - return ReverseBinary(self.ebpf, value, self, Opcode.DIV, - self.signed or issigned(value), False) + value = Constant(self.ebpf, int(value)) + + return Binary(self.ebpf, value, self, Opcode.DIV, + self.signed or value.signed, False) def __rshift__(self, value): opcode = Opcode.ARSH if self.signed else Opcode.RSH - return Binary(self.ebpf, self, value, opcode, self.signed, False) + return Binary(self.ebpf, self, ensure_expression(self.ebpf, value), + opcode, self.signed, False) - def __rrshift__(self, value): - opcode = Opcode.ARSH if value < 0 else Opcode.RSH - return ReverseBinary(self.ebpf, value, self, opcode, value < 0, False) - - def __rlshift__(self, value): - return ReverseBinary(self.ebpf, value, self, Opcode.LSH, - value < 0, False) + __rrshift__ = lambda self, value: Constant(self.ebpf, value) >> self def __and__(self, value): - return AndExpression(self.ebpf, self, value) - - def __ne__(self, value): - return SimpleComparison(self.ebpf, self, value, - (Opcode.JNE, Opcode.JEQ)) + return AndExpression(self.ebpf, self, + ensure_expression(self.ebpf, value)) def __eq__(self, value): return ~(self != value) @@ -569,10 +546,15 @@ class Expression: __rand__ = __and__ def __neg__(self): - return Negate(self.ebpf, self) + return Negate(self) def __abs__(self): - return Absolute(self.ebpf, self) + return Absolute(self) + + def switch_endian(self, fmt): + if isinstance(fmt, str) and len(fmt) > 1: + return SwitchEndian(self, fmt) + return self def __bool__(self): raise AssembleError("Expression only has a value at execution time") @@ -611,7 +593,7 @@ class Expression: """ with self.ebpf.get_free_register(dst) as dst: with self.get_address(dst, long) as (src, fmt): - self.ebpf.append(Opcode.LD + Memory.fmt_to_opcode[fmt], + self.ebpf.append(Opcode.LD + fmt_to_opcode(fmt), dst, src, 0, 0) yield dst, long @@ -657,22 +639,16 @@ class Binary(Expression): with self.left.calculate(dst, long, True) as (dst, l_long): if long is None: long = l_long - if isinstance(self.right, Expression): + if self.right.small_constant: + self.ebpf.append(self.operator + Opcode.LONG * long, + dst, 0, 0, int(self.right.value)) + else: with self.right.calculate(None, long) as (src, r_long): self.ebpf.append( self.operator + Opcode.REG + Opcode.LONG * ((r_long or l_long) if long is None else long), dst, src, 0, 0) - elif -0x80000000 <= self.right < 0x100000000: - self.ebpf.append(self.operator + Opcode.LONG * long, - dst, 0, 0, self.right) - else: - with self.ebpf.get_free_register(None) as src: - self.ebpf._load_value(src, self.right) - self.ebpf.append( - self.operator + Opcode.REG + Opcode.LONG, - dst, src, 0, 0) if orig_dst is None or orig_dst == dst: yield dst, long return @@ -684,60 +660,54 @@ class Binary(Expression): and self.right.contains(no)) -class ReverseBinary(Expression): - def __init__(self, ebpf, left, right, operator, signed, fixed): - self.ebpf = ebpf - self.left = left - self.right = right - self.operator = operator - self.signed = signed - self.fixed = fixed +class Unary(Expression): + def __init__(self, arg): + self.arg = arg + self.ebpf = arg.ebpf + self.signed = arg.signed + self.fixed = arg.fixed @contextmanager def calculate(self, dst, long, force=False): - with self.ebpf.get_free_register(dst) as dst: - self.ebpf._load_value(dst, self.left) - with self.right.calculate(None, long) as (src, long): - self.ebpf.append(self.operator + Opcode.LONG * long - + Opcode.REG, dst, src, 0, 0) + with self.arg.calculate(dst, long, force) as (dst, long): + self.calculate_unary(dst, long) yield dst, long def contains(self, no): - return self.right.contains(no) + return self.arg.contains(no) -class Negate(Expression): - def __init__(self, ebpf, arg): - self.ebpf = ebpf - self.arg = arg +class Negate(Unary): + def __init__(self, arg): + super().__init__(arg) self.signed = True - self.fixed = arg.fixed - @contextmanager - def calculate(self, dst, long, force=False): - with self.arg.calculate(dst, long, force) as (dst, long): - self.ebpf.append(Opcode.NEG + Opcode.LONG * long, dst, 0, 0, 0) - yield dst, long + def calculate_unary(self, dst, long): + self.ebpf.append(Opcode.NEG + Opcode.LONG * long, dst, 0, 0, 0) - def contains(self, no): - return self.arg.contains(no) +class Absolute(Unary): + def __init__(self, arg): + super().__init__(arg) + self.signed = False -class Absolute(Expression): - def __init__(self, ebpf, arg): - self.ebpf = ebpf - self.arg = arg - self.fixed = arg.fixed + def calculate_unary(self, dst, long): + with self.ebpf.sr[dst] < 0: + self.ebpf.sr[dst] = -self.ebpf.sr[dst] - @contextmanager - def calculate(self, dst, long, force=False): - with self.arg.calculate(dst, long, force) as (dst, long): - with self.ebpf.sr[dst] < 0: - self.ebpf.sr[dst] = -self.ebpf.sr[dst] - yield dst, long - def contains(self, no): - return self.arg.contains(no) +class SwitchEndian(Unary): + def __init__(self, arg, fmt): + super().__init__(arg) + self.fmt = fmt + + def calculate_unary(self, dst, long): + endian, size = self.fmt + if endian == "<": + opcode = Opcode.LE + elif endian in ">!": + opcode = Opcode.BE + self.ebpf.append(opcode, dst, 0, 0, calcsize(size) * 8) class Sum(Binary): @@ -746,11 +716,11 @@ class Sum(Binary): this is used to optimize memory addressing code. """ def __init__(self, ebpf, left, right): - super().__init__(ebpf, left, right, Opcode.ADD, right < 0, False) + super().__init__(ebpf, left, right, Opcode.ADD, right.value < 0, False) def __add__(self, value): try: - return Sum(self.ebpf, self.left, self.right + index(value)) + self.right.value += index(value) except TypeError: return super().__add__(value) @@ -758,7 +728,7 @@ class Sum(Binary): def __sub__(self, value): try: - return Sum(self.ebpf, self.left, self.right - index(value)) + self.right.value -= index(value) except TypeError: return super().__add__(value) @@ -816,6 +786,42 @@ class AndComparison(SimpleComparison): self.ebpf.opcodes.append(None) return self +class Constant(Expression): + def __init__(self, ebpf, value): + try: + self.value = index(value) + self.fixed = False + except TypeError: + self.value = float(value) * Expression.FIXED_BASE + self.fixed = True + self.ebpf = ebpf + self.signed = value < 0 + + @property + def small_constant(self): + return -0x80000000 <= self.value < 0x80000000 + + def __imul__(self, value): + self.value *= value + return self + + @contextmanager + def calculate(self, dst, long, force=False): + value = int(self.value) + with self.ebpf.get_free_register(dst) as dst: + if self.small_constant: + self.ebpf.append(Opcode.MOV + Opcode.LONG, dst, 0, 0, value) + else: + self.ebpf.append(Opcode.DW, dst, 0, 0, value & 0xffffffff) + self.ebpf.append(Opcode.W, 0, 0, 0, value >> 32) + yield dst, not (-0x80000000 <= value < 0x100000000) + + def switch_endian(self, fmt): + if not isinstance(fmt, str) or len(fmt) == 1: + return self + return Constant(self.ebpf, *unpack(fmt, pack(fmt[-1], self.value))) + + class Register(Expression): """represent one EBPF register""" offset = 0 @@ -830,7 +836,7 @@ class Register(Expression): def __add__(self, value): if self.long and not self.fixed: try: - return Sum(self.ebpf, self, index(value)) + return Sum(self.ebpf, self, Constant(self.ebpf, index(value))) except TypeError: pass return super().__add__(value) @@ -840,7 +846,7 @@ class Register(Expression): def __sub__(self, value): if self.long and not self.fixed: try: - return Sum(self.ebpf, self, -index(value)) + return Sum(self.ebpf, self, Constant(self.ebpf, -index(value))) except TypeError: pass return super().__sub__(value) @@ -862,15 +868,24 @@ class Register(Expression): class IAdd: """represent an in-place addition""" - def __init__(self, value): - self.value = value + def __init__(self, ebpf, value): + if isinstance(value, Expression): + self.value = value + else: + self.value = Constant(ebpf, value) -class Memory(Expression): - bits_to_opcode = {32: Opcode.W, 16: Opcode.H, 8: Opcode.B, 64: Opcode.DW} +def fmt_to_opcode(fmt): fmt_to_opcode = {'I': Opcode.W, 'H': Opcode.H, 'B': Opcode.B, 'Q': Opcode.DW, 'i': Opcode.W, 'h': Opcode.H, 'b': Opcode.B, 'q': Opcode.DW, 'A': Opcode.W, 'x': Opcode.DW} + if isinstance(fmt, str): + return fmt_to_opcode[fmt[-1]] + else: + return Opcode.B + +class Memory(Expression): + bits_to_opcode = {32: Opcode.W, 16: Opcode.H, 8: Opcode.B, 64: Opcode.DW} def __init__(self, ebpf, fmt, address): self.ebpf = ebpf @@ -878,25 +893,30 @@ class Memory(Expression): self.address = address def __iadd__(self, value): - if self.fmt in "qQiI": - return IAdd(value) + if self.fmt in "qQiIx": + return IAdd(self.ebpf, value) else: return NotImplemented def __isub__(self, value): - if self.fmt in "qQiI": - return IAdd(-value) + if self.fmt in "qQiIx": + return IAdd(self.ebpf, -value) else: return NotImplemented @contextmanager def calculate(self, dst, long, force=False): + if self.has_endian(): + with self.without_endian().switch_endian(self.fmt) \ + .calculate(dst, long, force) as (dst, long): + yield dst, long + return with ExitStack() as exitStack: if isinstance(self.address, Sum): dst = exitStack.enter_context(self.ebpf.get_free_register(dst)) - self.ebpf.append( - Opcode.LD + self.fmt_to_opcode.get(self.fmt, Opcode.B), - dst, self.address.left.no, self.address.right, 0) + opcode = fmt_to_opcode(self.fmt) + self.ebpf.append(Opcode.LD + opcode, dst, self.address.left.no, + self.address.right.value, 0) else: dst, _ = exitStack.enter_context( super().calculate(dst, long, force)) @@ -906,7 +926,7 @@ class Memory(Expression): self.ebpf.r[dst] >>= self.fmt[0] yield dst, "B" else: - yield dst, self.fmt in "QqA" + yield dst, self.fmt[-1] in "QqAx" @contextmanager def get_address(self, dst, long, force=False): @@ -924,6 +944,14 @@ class Memory(Expression): def fixed(self): return isinstance(self.fmt, str) and self.fmt == "x" + def has_endian(self): + return isinstance(self.fmt, str) and len(self.fmt) > 1 + + def without_endian(self): + if self.has_endian(): + return Memory(self.ebpf, self.fmt[-1], self.address) + return self + def __invert__(self): if not isinstance(self.fmt, tuple) or self.fmt[1] != 1: return NotImplemented @@ -936,6 +964,56 @@ class Memory(Expression): return Memory(self.ebpf, "B", self.address) & mask != 0 return super().__ne__(value) + def _set(self, value): + opcode = Opcode.STX + with ExitStack() as exitStack: + if isinstance(self.fmt, tuple): + pos, bits = self.fmt + self.fmt = "B" + if bits == 1: + try: + if value: + value = self | (1 << pos) + else: + value = self & ~(1 << pos) + except AssembleError: + exitStack.enter_context(self.ebpf.wtmp) + with value as Else: + self.ebpf.wtmp = self | (1 << pos) + with Else: + self.ebpf.wtmp = self & ~(1 << pos) + value = self.ebpf.wtmp + else: + mask = ((1 << bits) - 1) << pos + value = (mask & (value << pos) | ~mask & self) + elif isinstance(value, IAdd): + value = value.value + opcode = Opcode.XADD + elif not isinstance(value, Expression): + value = Constant(self.ebpf, value) + if self.fmt == "x" and not value.fixed: + value *= Expression.FIXED_BASE + elif self.fmt != "x" and value.fixed: + value /= Expression.FIXED_BASE + if isinstance(self.address, Sum): + dst = self.address.left.no + offset = self.address.right.value + else: + dst, _ = exitStack.enter_context( + self.address.calculate(None, True)) + offset = 0 + value = value.switch_endian(self.fmt) + if value.small_constant and opcode == Opcode.STX: + self.ebpf.append(Opcode.ST + fmt_to_opcode(self.fmt), dst, 0, + offset, int(value.value)) + return + src, _ = exitStack.enter_context( + value.calculate(None, isinstance(self.fmt, str) + and self.fmt[-1] in 'qQx')) + self.ebpf.append(opcode + fmt_to_opcode(self.fmt), + dst, src, offset, 0) + + class MemoryDesc: """A base class used by descriptors for memory @@ -944,6 +1022,9 @@ class MemoryDesc: defined by the member variable `base_register` in deriving classes. """ + + fixed = False # only selected memory can have fixe value vars + def __get__(self, instance, owner): if instance is None: return self @@ -952,54 +1033,10 @@ class MemoryDesc: instance.ebpf.r[self.base_register] + addr) def __set__(self, instance, value): - ebpf = instance.ebpf fmt, addr = self.fmt_addr(instance) - bits = Memory.fmt_to_opcode.get(fmt, Opcode.B) - if isinstance(fmt, tuple): - before = Memory(ebpf, "B", ebpf.r[self.base_register] + addr) - if fmt[1] == 1: - try: - if value: - value = before | (1 << fmt[0]) - else: - value = before & ~(1 << fmt[0]) - except AssembleError: - with ebpf.wtmp: - with value as Else: - ebpf.wtmp = before | (1 << fmt[0]) - with Else: - ebpf.wtmp = before & ~(1 << fmt[0]) - else: - mask = ((1 << fmt[1]) - 1) << fmt[0] - value = (mask & (value << self.fmt[0]) | ~mask & before) - opcode = Opcode.STX - elif isinstance(value, IAdd): - value = value.value - if not isinstance(value, Expression): - if self.fixed: - value = int(value * self.FIXED_BASE) - with ebpf.get_free_register(None) as src: - ebpf.r[src] = value - ebpf.append(Opcode.XADD + bits, self.base_register, - src, addr, 0) - return - opcode = Opcode.XADD - elif isinstance(value, Expression): - opcode = Opcode.STX - else: - if self.fixed: - value = int(value * Expression.FIXED_BASE) - ebpf.append(Opcode.ST + bits, self.base_register, 0, - addr, value) - return - if self.fmt == "x" and not value.fixed: - value = value * Expression.FIXED_BASE - elif self.fmt != "x" and value.fixed: - value = value / Expression.FIXED_BASE - with value.calculate(None, isinstance(fmt, str) and fmt in 'qQx' - ) as (src, _): - ebpf.append(opcode + bits, self.base_register, src, addr, 0) - + memory = Memory(instance.ebpf, fmt, + instance.ebpf.r[self.base_register] + addr) + memory._set(value) class LocalVar(MemoryDesc): """variables on the stack""" @@ -1034,36 +1071,8 @@ class MemoryMap: self.fmt = fmt def __setitem__(self, addr, value): - with ExitStack() as exitStack: - if isinstance(addr, Sum): - dst = addr.left.no - offset = addr.right - else: - dst, _ = exitStack.enter_context(addr.calculate(None, True)) - offset = 0 - if isinstance(value, IAdd): - value = value.value - if self.fmt == "x": - value = int(value * self.FIXED_BASE) - if not isinstance(value, Expression): - with self.ebpf.get_free_register(None) as src: - self.ebpf.r[src] = value - self.ebpf.append( - Opcode.XADD + Memory.fmt_to_opcode[self.fmt], - dst, src, offset, 0) - return - opcode = Opcode.XADD - elif isinstance(value, Expression): - opcode = Opcode.STX - else: - if self.fmt == "x": - value = int(value * self.FIXED_BASE) - self.ebpf.append(Opcode.ST + Memory.fmt_to_opcode[self.fmt], - dst, 0, offset, value) - return - with value.calculate(None, None) as (src, _): - self.ebpf.append(opcode + Memory.fmt_to_opcode[self.fmt], - dst, src, offset, 0) + memory = Memory(self.ebpf, self.fmt, addr) + memory._set(value) def __getitem__(self, addr): if isinstance(addr, Register): @@ -1151,17 +1160,13 @@ class RegisterArray: def __setitem__(self, no, value): self.ebpf.owners.add(no) - if isinstance(value, Expression): - if self.fixed and not value.fixed: - value = value * Expression.FIXED_BASE - if not self.fixed and value.fixed: - value = value / Expression.FIXED_BASE - with value.calculate(no, self.long, True): - pass - else: - if self.fixed: - value = int(value * Expression.FIXED_BASE) - self.ebpf._load_value(no, value) + value = ensure_expression(self.ebpf, value) + if self.fixed and not value.fixed: + value *= Expression.FIXED_BASE + elif not self.fixed and value.fixed: + value /= Expression.FIXED_BASE + with value.calculate(no, self.long, True): + pass def __getitem__(self, no): return Register(no, self.ebpf, self.long, self.signed, self.fixed) @@ -1262,9 +1267,19 @@ class EBPF: def append(self, opcode, dst, src, off, imm): self.opcodes.append(Instruction(opcode, dst, src, off, imm)) + def append_endian(self, fmt, dst): + if not isinstance(fmt, str) or len(fmt) != 2: + return + endian, size = fmt + if endian == "<": + opcode = Opcode.LE + elif endian in ">!": + opcode = Opcode.BE + self.append(opcode, dst, 0, 0, calcsize(fmt) * 8) + def assemble(self): """return the assembled program""" - self.program() + sub(EBPF, self).program() return b"".join( pack("<BBHI", i.opcode.value, i.dst | i.src << 4, i.off % 0x10000, i.imm % 0x100000000) @@ -1330,13 +1345,6 @@ class EBPF: return raise AssembleError("not enough registers") - def _load_value(self, no, value): - if -0x80000000 <= value < 0x100000000: - self.append(Opcode.MOV + Opcode.LONG, no, 0, 0, value) - else: - self.append(Opcode.DW, no, 0, 0, value & 0xffffffff) - self.append(Opcode.W, 0, 0, 0, value >> 32) - @contextmanager def save_registers(self, registers): oldowners = self.owners.copy() diff --git a/ebpfcat/ebpf.rst b/ebpfcat/ebpf.rst index 570e7c869768deced196013571e067ae9fbd8e57..8184d415aef76e63b2a9e19acd32b69e0a9d8ebc 100644 --- a/ebpfcat/ebpf.rst +++ b/ebpfcat/ebpf.rst @@ -126,6 +126,38 @@ packets:: self.count += 1 self.exit(XDPExitCode.PASS) +as a simplification, if the class attribute ``minimumPacketSize`` is set, +the ``program`` is called within a ``with`` statement like above, and all +the packet variables appear as variables of the object. The class +attribute ``defaultExitCode`` then gives the exit code in case the packet +is too small (by default ``XDPExitCode.PASS``). So the above example becomes:: + + class Program(XDP): + minimumPacketSize = 16 + userspace = HashMap() + count = userspace.globalVar() + + def program(self): + with self.pH[12] == 8: + self.count += 1 + +With the ``PacketVar`` descriptor it is possible to declare certain positions +in the packet as variables. As parameters it takes the position within the +packet, and the data format, following the conventions from the Python +``struct`` package, including the endianness markers ``<>!``. So the above +example simplifies to:: + + class Program(XDP): + minimumPacketSize = 16 + userspace = HashMap() + count = userspace.globalVar() + etherType = PacketVar(12, "!H") # use network byte order + + def program(self): + with self.etherType == 0x800: + self.count += 1 + + Maps ---- diff --git a/ebpfcat/ebpf_test.py b/ebpfcat/ebpf_test.py index c61bb959b4ec94b3d99f0f7bca470aeeaa9bf53b..dbe2cdb8fe64d83df298a824d8fbfdb2dba1ab71 100644 --- a/ebpfcat/ebpf_test.py +++ b/ebpfcat/ebpf_test.py @@ -23,7 +23,7 @@ from .ebpf import ( AssembleError, EBPF, FuncId, Opcode, OpcodeFlags, Opcode as O, LocalVar, SubProgram, ktime) from .hashmap import HashMap -from .xdp import XDP +from .xdp import XDP, PacketVar from .bpf import ProgType, prog_test_run @@ -192,7 +192,17 @@ class Tests(TestCase): e.r1 = e.r2 // e.x3 e.x4 = e.x5 // e.x6 - self.maxDiff = None + e.x1 = 3 / e.r2 + e.x3 = 3.5 / e.r4 + e.x5 = 3 / e.x6 + e.x4 = 4.5 / e.x6 + + e.x1 = 3 // e.r2 + e.x3 = 3.5 // e.r4 + e.x5 = 3 // e.x6 + e.x4 = 4.5 // e.x6 + + self.assertEqual(e.opcodes, [ Instruction(opcode=O.REG+O.MOV+O.LONG, dst=1, src=2, off=0, imm=0), Instruction(opcode=O.ADD+O.LONG, dst=1, src=0, off=0, imm=3), @@ -269,7 +279,30 @@ class Tests(TestCase): Instruction(opcode=O.DIV+O.LONG+O.REG, dst=1, src=3, off=0, imm=0), Instruction(opcode=O.LONG+O.REG+O.MOV, dst=4, src=5, off=0, imm=0), Instruction(opcode=O.DIV+O.LONG+O.REG, dst=4, src=6, off=0, imm=0), - Instruction(opcode=O.MUL+O.LONG, dst=4, src=0, off=0, imm=100000), + Instruction(opcode=O.LONG+O.MUL, dst=4, src=0, off=0, imm=100000), + + Instruction(opcode=O.LONG+O.MOV, dst=1, src=0, off=0, imm=300000), + Instruction(opcode=O.DIV+O.REG+O.LONG, dst=1, src=2, off=0, imm=0), + Instruction(opcode=O.LONG+O.MOV, dst=3, src=0, off=0, imm=350000), + Instruction(opcode=O.DIV+O.REG+O.LONG, dst=3, src=4, off=0, imm=0), + Instruction(opcode=O.DW, dst=5, src=0, off=0, imm=4230196224), + Instruction(opcode=O.W, dst=0, src=0, off=0, imm=6), + Instruction(opcode=O.DIV+O.REG+O.LONG, dst=5, src=6, off=0, imm=0), + Instruction(opcode=O.DW, dst=4, src=0, off=0, imm=2050327040), + Instruction(opcode=O.W, dst=0, src=0, off=0, imm=10), + Instruction(opcode=O.DIV+O.REG+O.LONG, dst=4, src=6, off=0, imm=0), + Instruction(opcode=O.LONG+O.MOV, dst=1, src=0, off=0, imm=3), + Instruction(opcode=O.DIV+O.REG+O.LONG, dst=1, src=2, off=0, imm=0), + Instruction(opcode=O.LONG+O.MUL, dst=1, src=0, off=0, imm=100000), + Instruction(opcode=O.MOV+O.LONG, dst=3, src=0, off=0, imm=3), + Instruction(opcode=O.REG+O.LONG+O.DIV, dst=3, src=4, off=0, imm=0), + Instruction(opcode=O.LONG+O.MUL, dst=3, src=0, off=0, imm=100000), + Instruction(opcode=O.LONG+O.MOV, dst=5, src=0, off=0, imm=300000), + Instruction(opcode=O.DIV+O.REG+O.LONG, dst=5, src=6, off=0, imm=0), + Instruction(opcode=O.LONG+O.MUL, dst=5, src=0, off=0, imm=100000), + Instruction(opcode=O.LONG+O.MOV, dst=4, src=0, off=0, imm=450000), + Instruction(opcode=O.DIV+O.REG+O.LONG, dst=4, src=6, off=0, imm=0), + Instruction(opcode=O.LONG+O.MUL, dst=4, src=0, off=0, imm=100000), ]) def test_local(self): @@ -317,6 +350,8 @@ class Tests(TestCase): with e.b: e.a = 0 + e.a = e.b + self.assertEqual(e.opcodes, [ Instruction(opcode=O.LD+O.B, dst=0, src=10, off=-1, imm=0), Instruction(opcode=O.JSET, dst=0, src=0, off=1, imm=32), @@ -344,6 +379,14 @@ class Tests(TestCase): Instruction(opcode=O.JMP, dst=0, src=0, off=3, imm=0), Instruction(opcode=O.LD+O.B, dst=0, src=10, off=-1, imm=0), Instruction(opcode=O.AND, dst=0, src=0, off=0, imm=-33), + Instruction(opcode=O.STX+O.B, dst=10, src=0, off=-1, imm=0), + Instruction(opcode=O.LD+O.B, dst=2, src=10, off=-2, imm=0), + Instruction(opcode=O.JSET, dst=2, src=0, off=3, imm=120), + Instruction(opcode=O.LD+O.B, dst=0, src=10, off=-1, imm=0), + Instruction(opcode=O.AND, dst=0, src=0, off=0, imm=-33), + Instruction(opcode=O.JMP, dst=0, src=0, off=2, imm=0), + Instruction(opcode=O.LD+O.B, dst=0, src=10, off=-1, imm=0), + Instruction(opcode=O.OR, dst=0, src=0, off=0, imm=32), Instruction(opcode=O.B+O.STX, dst=10, src=0, off=-1, imm=0)]) def test_local_subprog(self): @@ -374,6 +417,7 @@ class Tests(TestCase): a = LocalVar('I') b = LocalVar('q') c = LocalVar('h') + d = LocalVar('x') e = Local(ProgType.XDP, "GPL") e.a += 3 @@ -386,6 +430,9 @@ class Tests(TestCase): e.c += 3 e.mB[e.r1] += e.r1 + e.d -= 5 + e.d += e.r1 + self.assertEqual(e.opcodes, [ Instruction(opcode=O.LONG+O.MOV, dst=0, src=0, off=0, imm=3), Instruction(opcode=O.XADD+O.W, dst=10, src=0, off=-4, imm=0), @@ -400,7 +447,13 @@ class Tests(TestCase): Instruction(opcode=O.STX+O.REG, dst=10, src=0, off=-18, imm=0), Instruction(opcode=O.B+O.LD, dst=0, src=1, off=0, imm=0), Instruction(opcode=O.ADD+O.REG, dst=0, src=1, off=0, imm=0), - Instruction(opcode=O.STX+O.B, dst=1, src=0, off=0, imm=0)]) + Instruction(opcode=O.STX+O.B, dst=1, src=0, off=0, imm=0), + Instruction(opcode=O.LONG+O.MOV, dst=0, src=0, off=0, imm=-500000), + Instruction(opcode=O.XADD+O.DW, dst=10, src=0, off=-32, imm=0), + Instruction(opcode=O.REG+O.LONG+O.MOV, dst=0, src=1, off=0, imm=0), + Instruction(opcode=O.MUL+O.LONG, dst=0, src=0, off=0, imm=100000), + Instruction(opcode=O.XADD+O.DW, dst=10, src=0, off=-32, imm=0), + ]) def test_jump(self): @@ -632,7 +685,9 @@ class Tests(TestCase): Instruction(opcode=O.DW, dst=0, src=0, off=0, imm=878082192), Instruction(opcode=O.W, dst=0, src=0, off=0, imm=18), Instruction(opcode=O.LONG+O.REG+O.ADD, dst=3, src=0, off=0, imm=0), - Instruction(opcode=O.LONG+O.MOV, dst=3, src=0, off=0, imm=2415919104), + Instruction(opcode=O.DW, dst=3, src=0, off=0, imm=2415919104), + Instruction(opcode=O.W, dst=0, src=0, off=0, imm=0), + ]) def test_simple_binary(self): @@ -809,12 +864,12 @@ class Tests(TestCase): Instruction(opcode=39, dst=0, src=0, off=0, imm=2), Instruction(opcode=31, dst=3, src=0, off=0, imm=0), Instruction(opcode=191, dst=0, src=3, off=0, imm=0), - Instruction(opcode=O.MUL+O.LONG, dst=0, src=0, off=0, imm=2), + Instruction(opcode=O.MUL, dst=0, src=0, off=0, imm=2), Instruction(opcode=107, dst=10, src=0, off=-10, imm=0), Instruction(opcode=191, dst=0, src=10, off=0, imm=0), Instruction(opcode=15, dst=0, src=3, off=0, imm=0), Instruction(opcode=191, dst=2, src=3, off=0, imm=0), - Instruction(opcode=O.MUL+O.LONG, dst=2, src=0, off=0, imm=2), + Instruction(opcode=O.MUL, dst=2, src=0, off=0, imm=2), Instruction(opcode=107, dst=0, src=2, off=0, imm=0), Instruction(opcode=191, dst=5, src=10, off=0, imm=0), @@ -899,6 +954,81 @@ class Tests(TestCase): Instruction(opcode=O.JMP, dst=0, src=0, off=1, imm=0), Instruction(opcode=O.MOV+O.LONG, dst=3, src=0, off=0, imm=77)]) + def test_endian(self): + class P(XDP): + minimumPacketSize = 100 + + ph = PacketVar(20, "<H") + pi = PacketVar(28, ">i") + pq = PacketVar(36, "!q") + + pp = PacketVar(100, "Q") + + def program(self): + self.ph = 3 + self.pi = 5 + self.pq = 7 + + self.ph += 3 + self.pi += 5 + self.pq = self.ph + + e = P(license="GPL") + e.assemble() + self.assertEqual(e.opcodes, [ + Instruction(opcode=O.W+O.LD, dst=9, src=1, off=0, imm=0), + Instruction(opcode=O.W+O.LD, dst=0, src=1, off=4, imm=0), + Instruction(opcode=O.W+O.LD, dst=2, src=1, off=0, imm=0), + Instruction(opcode=O.LONG+O.ADD, dst=2, src=0, off=0, imm=100), + Instruction(opcode=O.JLE+O.REG, dst=0, src=2, off=19, imm=0), + Instruction(opcode=O.ST+O.REG, dst=9, src=0, off=20, imm=3), + Instruction(opcode=O.W+O.ST, dst=9, src=0, off=28, imm=83886080), + Instruction(opcode=O.DW, dst=0, src=0, off=0, imm=0), + Instruction(opcode=O.W, dst=0, src=0, off=0, imm=117440512), + Instruction(opcode=O.DW+O.STX, dst=9, src=0, off=36, imm=0), + Instruction(opcode=O.LD+O.REG, dst=0, src=9, off=20, imm=0), + Instruction(opcode=O.LE, dst=0, src=0, off=0, imm=16), + Instruction(opcode=O.ADD, dst=0, src=0, off=0, imm=3), + Instruction(opcode=O.LE, dst=0, src=0, off=0, imm=16), + Instruction(opcode=O.REG+O.STX, dst=9, src=0, off=20, imm=0), + Instruction(opcode=O.W+O.LD, dst=0, src=9, off=28, imm=0), + Instruction(opcode=O.BE, dst=0, src=0, off=0, imm=32), + Instruction(opcode=O.ADD, dst=0, src=0, off=0, imm=5), + Instruction(opcode=O.BE, dst=0, src=0, off=0, imm=32), + Instruction(opcode=O.W+O.STX, dst=9, src=0, off=28, imm=0), + Instruction(opcode=O.LD+O.REG, dst=0, src=9, off=20, imm=0), + Instruction(opcode=O.LE, dst=0, src=0, off=0, imm=16), + Instruction(opcode=O.BE, dst=0, src=0, off=0, imm=64), + Instruction(opcode=O.DW+O.STX, dst=9, src=0, off=36, imm=0), + Instruction(opcode=O.LONG+O.MOV, dst=0, src=0, off=0, imm=2), + Instruction(opcode=O.EXIT, dst=0, src=0, off=0, imm=0), + ]) + + + def test_xdp_minsize(self): + class P(XDP): + minimumPacketSize = 100 + + pv = PacketVar(20, "H") + + def program(self): + self.pv = self.pH[22] + + p = P(license="GPL") + p.assemble() + self.assertEqual(p.opcodes, [ + Instruction(opcode=O.W+O.LD, dst=9, src=1, off=0, imm=0), + Instruction(opcode=O.W+O.LD, dst=0, src=1, off=4, imm=0), + Instruction(opcode=O.W+O.LD, dst=2, src=1, off=0, imm=0), + Instruction(opcode=O.LONG+O.ADD, dst=2, src=0, off=0, imm=100), + Instruction(opcode=O.JLE+O.REG, dst=0, src=2, off=2, imm=0), + Instruction(opcode=O.REG+O.LD, dst=0, src=9, off=22, imm=0), + Instruction(opcode=O.REG+O.STX, dst=9, src=0, off=20, imm=0), + Instruction(opcode=O.LONG+O.MOV, dst=0, src=0, off=0, imm=2), + Instruction(opcode=O.EXIT, dst=0, src=0, off=0, imm=0), + ]) + + class KernelTests(TestCase): def test_hashmap(self): class Global(EBPF): diff --git a/ebpfcat/ebpfcat.py b/ebpfcat/ebpfcat.py index 5bb9e438ebf8d325500edbcaf4d67694ff42ce87..77a9abb6575f08b32ad1b85e5ae114b9e0da8815 100644 --- a/ebpfcat/ebpfcat.py +++ b/ebpfcat/ebpfcat.py @@ -24,7 +24,7 @@ from time import time from .arraymap import ArrayMap, ArrayGlobalVarDesc from .ethercat import ECCmd, EtherCat, Packet, Terminal from .ebpf import FuncId, MemoryDesc, SubProgram, prandom -from .xdp import XDP, XDPExitCode +from .xdp import XDP, XDPExitCode, PacketVar as XDPPacketVar from .bpf import ( ProgType, MapType, create_map, delete_elem, update_elem, prog_test_run, lookup_elem) @@ -255,42 +255,44 @@ class EBPFTerminal(Terminal): class EtherXDP(XDP): license = "GPL" + minimumPacketSize = 30 variables = ArrayMap() dropcounter = variables.globalVar("I") counters = variables.globalVar("64I") rate = 0 + DATA0 = 26 - def program(self): - ETHERTYPE = 12 - CMD0 = 16 - ADDR0 = 18 + ethertype = XDPPacketVar(12, "!H") + addr0 = XDPPacketVar(18, "I") + cmd0 = XDPPacketVar(16, "B") + data0 = XDPPacketVar(DATA0, "H") + def program(self): with prandom(self.ebpf) & 0xffff < self.rate: self.dropcounter += 1 self.ebpf.exit(XDPExitCode.DROP) - with self.packetSize > 30 as p, p.pH[ETHERTYPE] == 0xA488, \ - p.pB[CMD0] == 0: - self.r3 = p.pI[ADDR0] # use r3 for tail_call + with self.ethertype == 0x88A4, self.cmd0 == 0: + self.r3 = self.addr0 # use r3 for tail_call with self.counters.get_address(None, False, False) as (dst, _), \ self.r3 < FastEtherCat.MAX_PROGS: self.r[dst] += 4 * self.r3 self.r4 = self.mH[self.r[dst]] # we lost a packet - with p.pH[self.DATA0] == self.r4 as Else: + with self.data0 == self.r4 as Else: self.mI[self.r[dst]] += 1 + (self.r4 & 1) # normal case: two packets on the wire - with Else, ((p.pH[self.DATA0] + 1 & 0xffff) == self.r4) \ - | (p.pH[self.DATA0] == 0) as Else: + with Else, ((self.data0 + 1 & 0xffff) == self.r4) \ + | (self.data0 == 0) as Else: self.mI[self.r[dst]] += 1 with self.r4 & 1: # last one was active - p.pH[self.DATA0] = self.mH[self.r[dst]] + self.data0 = self.mH[self.r[dst]] self.exit(XDPExitCode.TX) with Else: self.exit(XDPExitCode.PASS) - p.pH[self.DATA0] = self.mH[self.r[dst]] + self.data0 = self.mH[self.r[dst]] self.r2 = self.get_fd(self.programs) self.call(FuncId.tail_call) self.exit(XDPExitCode.PASS) diff --git a/ebpfcat/util.py b/ebpfcat/util.py new file mode 100644 index 0000000000000000000000000000000000000000..840e9ec52d6ffaff6fbc60df6a3ce7bc4ac98816 --- /dev/null +++ b/ebpfcat/util.py @@ -0,0 +1,43 @@ +from itertools import chain + + +class sub: + def __init__(self, cls, base, default=False): + self.cls = cls + self.base = base + + def __getattr__(self, name): + mro = self.base.__class__.__mro__[::-1] + i = mro.index(self.cls) + for cls in chain(mro[i + 1 :], mro[:i + 1]): + func = cls.__dict__.get(name) + if func is not None: + return func.__get__(self.base, cls) + raise AttributeError(f"'sub' object has no attribute '{name}'") + + +if __name__ == "__main__": + class A: + def g(self): + print("A.f") + + class B(A): + def f(self): + print("B.f") + + class C(A): + def f(self): + print("C.f") + + class D(C, B): + def f(self): + print("D.f") + + + b = D() + print(D.__mro__) + sub(A, b).f() + sub(B, b).f() + sub(C, b).f() + sub(D, b).f() + diff --git a/ebpfcat/xdp.py b/ebpfcat/xdp.py index dee0e1e7ee1b3187f1c06405b71b20c5e85e360e..0d2b518ae20a81f787d9e469c9666095b31797a8 100644 --- a/ebpfcat/xdp.py +++ b/ebpfcat/xdp.py @@ -24,8 +24,9 @@ from socket import AF_NETLINK, NETLINK_ROUTE, if_nametoindex import socket from struct import pack, unpack -from .ebpf import EBPF +from .ebpf import EBPF, MemoryDesc from .bpf import ProgType +from .util import sub class XDPExitCode(Enum): @@ -151,13 +152,39 @@ class PacketSize: return self > value - 1 +class PacketVar(MemoryDesc): + base_register = 9 + + def __init__(self, address, fmt): + self.address = address + self.fmt = fmt + + def fmt_addr(self, instance): + return self.fmt, self.address + + class XDP(EBPF): """the base class for XDP programs""" + minimumPacketSize = None + defaultExitCode = XDPExitCode.PASS + def __init__(self, **kwargs): super().__init__(prog_type=ProgType.XDP, **kwargs) self.packetSize = PacketSize(self) + def program(self): + if self.minimumPacketSize is None: + sub(XDP, self).program() + else: + with self.packetSize > self.minimumPacketSize as packet: + self.pB = packet.pB + self.pH = packet.pH + self.pI = packet.pI + self.pQ = packet.pQ + sub(XDP, self).program() + self.exit(self.defaultExitCode) + async def _netlink(self, ifindex, fd, flags): future = Future() transport, proto = await get_event_loop().create_datagram_endpoint(