From 5c88145829e36681fe33ca183a09f64fcbea1349 Mon Sep 17 00:00:00 2001 From: Martin Teichmann <martin.teichmann@gmail.com> Date: Sun, 12 Feb 2023 13:25:07 +0000 Subject: [PATCH] factor out signedness from calculate this can easier be done earlier --- ebpfcat/ebpf.py | 157 ++++++++++++++++++++++----------------------- ebpfcat/hashmap.py | 6 +- 2 files changed, 77 insertions(+), 86 deletions(-) diff --git a/ebpfcat/ebpf.py b/ebpfcat/ebpf.py index acd1bae..bf4c9fd 100644 --- a/ebpfcat/ebpf.py +++ b/ebpfcat/ebpf.py @@ -268,8 +268,11 @@ class AssembleError(Exception): def comparison(uposop, unegop, sposop, snegop): def ret(self, value): - return SimpleComparison(self.ebpf, self, value, - (uposop, unegop, sposop, snegop)) + if self.signed or ((value < 0) if isinstance(value, int) + else value.signed): + return SimpleComparison(self.ebpf, self, value, (sposop, snegop)) + else: + return SimpleComparison(self.ebpf, self, value, (uposop, unegop)) return ret @@ -356,16 +359,14 @@ class SimpleComparison(Comparison): self.opcode = opcode def compare(self, negative): - with self.left.calculate(None, None, None) as (self.dst, _, lsigned): + with self.left.calculate(None, None) as (self.dst, _): with ExitStack() as exitStack: - if isinstance(self.right, int): - rsigned = (self.right < 0) - else: - self.src, _, rsigned = exitStack.enter_context( - self.right.calculate(None, None, None)) + if not isinstance(self.right, int): + self.src, _ = exitStack.enter_context( + self.right.calculate(None, None)) self.origin = len(self.ebpf.opcodes) self.ebpf.opcodes.append(None) - self.opcode = self.opcode[negative + 2 * (lsigned or rsigned)] + self.opcode = self.opcode[negative] self.owners = self.ebpf.owners.copy() def target(self, retarget=False): @@ -424,12 +425,14 @@ class InvertComparison(Comparison): def binary(opcode): def ret(self, value): - return Binary(self.ebpf, self, value, opcode) + return Binary(self.ebpf, self, value, opcode, + self.signed or ((value < 0) if isinstance(value, int) + else value.signed), False) return ret def rbinary(opcode): def ret(self, value): - return ReverseBinary(self.ebpf, value, self, opcode) + return ReverseBinary(self.ebpf, value, self, opcode, value < 0, False) return ret @@ -444,7 +447,6 @@ class Expression: __ror__ = __or__ = binary(Opcode.OR) __lshift__ = binary(Opcode.LSH) __rlshift__ = rbinary(Opcode.LSH) - __rshift__ = binary(Opcode.RSH) __rrshift__ = rbinary(Opcode.RSH) __mod__ = binary(Opcode.MOD) __rmod__ = rbinary(Opcode.MOD) @@ -455,6 +457,14 @@ class Expression: __lt__ = comparison(Opcode.JLT, Opcode.JGE, Opcode.JSLT, Opcode.JSGE) __le__ = comparison(Opcode.JLE, Opcode.JGT, Opcode.JSLE, Opcode.JSGT) + def __rshift__(self, value): + opcode = Opcode.ARSH if self.signed else Opcode.RSH + return Binary(self.ebpf, self, value, opcode, self.signed, False) + + def __rrshift__(self, value): + opcode = Opcode.ARSH if value < 0 else Opcode.RSH + return ReverseBinary(self.ebpf, value, self, opcode, value < 0, False) + def __and__(self, value): return AndExpression(self.ebpf, self, value) @@ -486,15 +496,13 @@ class Expression: return self.as_comparison.__exit__(exc_type, exc, tb) @contextmanager - def calculate(self, dst, long, signed, force=False): + def calculate(self, dst, long, force=False): """issue the code that calculates the value of this expression this method returns three values: - the number of the register with the result - a boolean indicating whether this is a 64 bit value - - and a booleand indicating whether the result is to be - considered signed. this method is a contextmanager to be used in a `with` statement. At the end of the `with` block the result is @@ -508,19 +516,17 @@ class Expression: or `None` if that does not matter. :param long: True if the result is supposed to be 64 bit. None if it does not matter. - :param signed: True if the result should be considered signed. - None if it does not matter. :param force: if true, `dst` must be respected, otherwise this is optional. """ with self.ebpf.get_free_register(dst) as dst: - with self.get_address(dst, long, signed) as (src, fmt): + with self.get_address(dst, long) as (src, fmt): self.ebpf.append(Opcode.LD + Memory.fmt_to_opcode[fmt], dst, src, 0, 0) - yield dst, long, self.signed + yield dst, long @contextmanager - def get_address(self, dst, long, signed, force=False): + def get_address(self, dst, long, force=False): """get the address of the value of this expression this method returns the address of the result of this expression, @@ -529,7 +535,7 @@ class Expression: the stack. """ with self.ebpf.get_stack(4 + 4 * long) as stack: - with self.calculate(dst, long, signed) as (src, _, _): + with self.calculate(dst, long) as (src, _): self.ebpf.append(Opcode.STX + Opcode.DW * long, 10, src, stack, 0) self.ebpf.append(Opcode.MOV + Opcode.LONG + Opcode.REG, @@ -544,44 +550,38 @@ class Expression: class Binary(Expression): """represent all binary expressions""" - def __init__(self, ebpf, left, right, operator): + def __init__(self, ebpf, left, right, operator, signed, fixed): self.ebpf = ebpf self.left = left self.right = right self.operator = operator + self.signed = signed + self.fixed = fixed @contextmanager - def calculate(self, dst, long, signed, force=False): + def calculate(self, dst, long, force=False): orig_dst = dst if not isinstance(self.right, int) and self.right.contains(dst): dst = None with self.ebpf.get_free_register(dst) as dst: - with self.left.calculate(dst, long, signed, True) \ - as (dst, l_long, l_signed): + with self.left.calculate(dst, long, True) as (dst, l_long): if long is None: long = l_long - signed = signed or l_signed - if self.operator is Opcode.RSH and signed: # >>= - operator = Opcode.ARSH - else: - operator = self.operator if isinstance(self.right, int): - r_signed = self.right < 0 - self.ebpf.append(operator + Opcode.LONG * long, + self.ebpf.append(self.operator + Opcode.LONG * long, dst, 0, 0, self.right) else: - with self.right.calculate(None, long, None) as \ - (src, r_long, r_signed): + with self.right.calculate(None, long) as (src, r_long): self.ebpf.append( - operator + Opcode.REG + self.operator + Opcode.REG + Opcode.LONG * ((r_long or l_long) if long is None else long), dst, src, 0, 0) if orig_dst is None or orig_dst == dst: - yield dst, long, signed or r_signed + yield dst, long return self.ebpf.append(Opcode.MOV + Opcode.REG + Opcode.LONG * long, orig_dst, dst, 0, 0) - yield orig_dst, long, signed or r_signed + yield orig_dst, long def contains(self, no): return self.left.contains(no) or (not isinstance(self.right, int) @@ -589,25 +589,22 @@ class Binary(Expression): class ReverseBinary(Expression): - def __init__(self, ebpf, left, right, operator): + def __init__(self, ebpf, left, right, operator, signed, fixed): self.ebpf = ebpf self.left = left self.right = right self.operator = operator + self.signed = signed + self.fixed = fixed @contextmanager - def calculate(self, dst, long, signed, force=False): + def calculate(self, dst, long, force=False): with self.ebpf.get_free_register(dst) as dst: self.ebpf._load_value(dst, self.left) - if self.operator is Opcode.RSH and self.left < 0: # >>= - operator = Opcode.ARSH - else: - operator = self.operator - - with self.right.calculate(None, long, None) as (src, long, _): - self.ebpf.append(operator + Opcode.LONG * long + Opcode.REG, - dst, src, 0, 0) - yield dst, long, signed + with self.right.calculate(None, long) as (src, long): + self.ebpf.append(self.operator + Opcode.LONG * long + + Opcode.REG, dst, src, 0, 0) + yield dst, long def contains(self, no): return self.right.contains(no) @@ -619,11 +616,10 @@ class Negate(Expression): self.arg = arg @contextmanager - def calculate(self, dst, long, signed, force=False): - with self.arg.calculate(dst, long, signed, force) as \ - (dst, long, signed): + def calculate(self, dst, long, force=False): + with self.arg.calculate(dst, long, force) as (dst, long): self.ebpf.append(Opcode.NEG + Opcode.LONG * long, dst, 0, 0, 0) - yield dst, long, signed + yield dst, long def contains(self, no): return self.arg.contains(no) @@ -635,12 +631,11 @@ class Absolute(Expression): self.arg = arg @contextmanager - def calculate(self, dst, long, signed, force=False): - with self.arg.calculate(dst, long, True, force) as \ - (dst, long, signed): + def calculate(self, dst, long, force=False): + with self.arg.calculate(dst, long, force) as (dst, long): with self.ebpf.r[dst] < 0: self.ebpf.r[dst] = -self.ebpf.r[dst] - yield dst, long, True + yield dst, long def contains(self, no): return self.arg.contains(no) @@ -652,7 +647,7 @@ class Sum(Binary): this is used to optimize memory addressing code. """ def __init__(self, ebpf, left, right): - super().__init__(ebpf, left, right, Opcode.ADD) + super().__init__(ebpf, left, right, Opcode.ADD, right < 0, False) def __add__(self, value): if isinstance(value, int): @@ -672,7 +667,7 @@ class Sum(Binary): class AndExpression(Binary): # there is a special comparison with & instruction def __init__(self, ebpf, left, right): - super().__init__(ebpf, left, right, Opcode.AND) + super().__init__(ebpf, left, right, Opcode.AND, False, False) def __ne__(self, value): if isinstance(value, int) and value == 0: @@ -684,7 +679,7 @@ class AndComparison(SimpleComparison): # there is a special comparison with & instruction # it is the only one which has not inversion def __init__(self, ebpf, left, right): - Binary.__init__(self, ebpf, left, right, Opcode.AND) + Binary.__init__(self, ebpf, left, right, Opcode.AND, False, False) SimpleComparison.__init__(self, ebpf, left, right, Opcode.JSET) self.opcode = (Opcode.JSET, None, Opcode.JSET, None) self.invert = None @@ -747,15 +742,15 @@ class Register(Expression): return super().__sub__(value) @contextmanager - def calculate(self, dst, long, signed, force=False): + def calculate(self, dst, long, force=False): if self.no not in self.ebpf.owners: raise AssembleError("register has no value") if dst != self.no and force: self.ebpf.append(Opcode.MOV + Opcode.REG + Opcode.LONG * self.long, dst, self.no, 0, 0) - yield dst, self.long, self.signed + yield dst, self.long else: - yield self.no, self.long, self.signed + yield self.no, self.long def contains(self, no): return self.no == no @@ -791,7 +786,7 @@ class Memory(Expression): return NotImplemented @contextmanager - def calculate(self, dst, long, signed, force=False): + def calculate(self, dst, long, force=False): with ExitStack() as exitStack: if isinstance(self.address, Sum): dst = exitStack.enter_context(self.ebpf.get_free_register(dst)) @@ -799,19 +794,19 @@ class Memory(Expression): Opcode.LD + self.fmt_to_opcode.get(self.fmt, Opcode.B), dst, self.address.left.no, self.address.right, 0) else: - dst, _, _ = exitStack.enter_context( - super().calculate(dst, long, signed, force)) + dst, _ = exitStack.enter_context( + super().calculate(dst, long, force)) if isinstance(self.fmt, tuple): self.ebpf.r[dst] &= ((1 << self.fmt[1]) - 1) << self.fmt[0] if self.fmt[0] > 0: self.ebpf.r[dst] >>= self.fmt[0] - yield dst, "B", False + yield dst, "B" else: - yield dst, self.fmt in "QqA", self.fmt.islower() + yield dst, self.fmt in "QqA" @contextmanager - def get_address(self, dst, long, signed, force=False): - with self.address.calculate(dst, True, None) as (src, _, _): + def get_address(self, dst, long, force=False): + with self.address.calculate(dst, True) as (src, _): yield src, self.fmt def contains(self, no): @@ -885,9 +880,8 @@ class MemoryDesc: opcode = Opcode.XADD else: opcode = Opcode.STX - with value.calculate(None, isinstance(fmt, str) and fmt in 'qQ', - isinstance(fmt, str) and fmt.islower() - ) as (src, _, _): + with value.calculate(None, isinstance(fmt, str) and fmt in 'qQ' + ) as (src, _): ebpf.append(opcode + bits, self.base_register, src, addr, 0) @@ -926,8 +920,7 @@ class MemoryMap: dst = addr.left.no offset = addr.right else: - dst, _, _ = exitStack.enter_context( - addr.calculate(None, True, None)) + dst, _ = exitStack.enter_context(addr.calculate(None, True)) offset = 0 if isinstance(value, int): self.ebpf.append(Opcode.ST + Memory.fmt_to_opcode[self.fmt], @@ -945,7 +938,7 @@ class MemoryMap: opcode = Opcode.XADD else: opcode = Opcode.STX - with value.calculate(None, None, None) as (src, _, _): + with value.calculate(None, None) as (src, _): self.ebpf.append(opcode + Memory.fmt_to_opcode[self.fmt], dst, src, offset, 0) @@ -972,11 +965,11 @@ class PseudoFd(Expression): self.fd = fd @contextmanager - def calculate(self, dst, long, signed, force=False): + def calculate(self, dst, long, force=False): with self.ebpf.get_free_register(dst) as dst: self.ebpf.append(Opcode.DW, dst, 1, 0, self.fd) self.ebpf.append(Opcode.W, 0, 0, 0, 0) - yield dst, long, signed + yield dst, long class ktime(Expression): @@ -985,13 +978,13 @@ class ktime(Expression): self.ebpf = ebpf @contextmanager - def calculate(self, dst, long, signed, force=False): + def calculate(self, dst, long, force=False): with self.ebpf.get_free_register(dst) as dst: with self.ebpf.save_registers([i for i in range(6) if i != dst]): self.ebpf.call(FuncId.ktime_get_ns) if dst != 0: self.ebpf.r[dst] = self.ebpf.r0 - yield dst, True, False + yield dst, True class prandom(Expression): @@ -1000,13 +993,13 @@ class prandom(Expression): self.ebpf = ebpf @contextmanager - def calculate(self, dst, long, signed, force=False): + def calculate(self, dst, long, force=False): with self.ebpf.get_free_register(dst) as dst: with self.ebpf.save_registers([i for i in range(6) if i != dst]): self.ebpf.call(FuncId.get_prandom_u32) if dst != 0: self.ebpf.r[dst] = self.ebpf.r0 - yield dst, True, False + yield dst, True class RegisterDesc: @@ -1035,7 +1028,7 @@ class RegisterArray: if isinstance(value, int): self.ebpf._load_value(no, value) elif isinstance(value, Expression): - with value.calculate(no, self.long, self.signed, True): + with value.calculate(no, self.long, True): pass else: raise AssembleError("cannot compile") diff --git a/ebpfcat/hashmap.py b/ebpfcat/hashmap.py index 76f8518..aacea52 100644 --- a/ebpfcat/hashmap.py +++ b/ebpfcat/hashmap.py @@ -30,9 +30,7 @@ class HashGlobalVar(Expression): self.signed = fmt.islower() @contextmanager - def get_address(self, dst, long, signed, force=False): - if signed != self.fmt.islower(): - raise AssembleError("HashMap variable has wrong signedness") + def get_address(self, dst, long, force=False): with self.ebpf.save_registers([i for i in range(6) if i != dst]), \ self.ebpf.get_stack(4) as stack: self.ebpf.append(Opcode.ST, 10, 0, stack, self.count) @@ -78,7 +76,7 @@ class HashGlobalVarDesc: pack("q" if self.fmt.islower() else "Q", value), 0) return with ebpf.save_registers([3]): - with value.get_address(3, True, self.fmt.islower(), True): + with value.get_address(3, True, True): with ebpf.save_registers([0, 1, 2, 4, 5]), \ ebpf.get_stack(4) as stack: ebpf.r1 = ebpf.get_fd(ebpf.__dict__[self.name].fd) -- GitLab