From 5c88145829e36681fe33ca183a09f64fcbea1349 Mon Sep 17 00:00:00 2001
From: Martin Teichmann <martin.teichmann@gmail.com>
Date: Sun, 12 Feb 2023 13:25:07 +0000
Subject: [PATCH] factor out signedness from calculate

this can more easily be done earlier, when the expression is constructed
---
 ebpfcat/ebpf.py    | 157 ++++++++++++++++++++++-----------------------
 ebpfcat/hashmap.py |   6 +-
 2 files changed, 77 insertions(+), 86 deletions(-)

diff --git a/ebpfcat/ebpf.py b/ebpfcat/ebpf.py
index acd1bae..bf4c9fd 100644
--- a/ebpfcat/ebpf.py
+++ b/ebpfcat/ebpf.py
@@ -268,8 +268,11 @@ class AssembleError(Exception):
 
 def comparison(uposop, unegop, sposop, snegop):
     def ret(self, value):
-        return SimpleComparison(self.ebpf, self, value,
-                                (uposop, unegop, sposop, snegop))
+        if self.signed or ((value < 0) if isinstance(value, int)
+                                       else value.signed):
+            return SimpleComparison(self.ebpf, self, value, (sposop, snegop))
+        else:
+            return SimpleComparison(self.ebpf, self, value, (uposop, unegop))
     return ret
 
 
@@ -356,16 +359,14 @@ class SimpleComparison(Comparison):
         self.opcode = opcode
 
     def compare(self, negative):
-        with self.left.calculate(None, None, None) as (self.dst, _, lsigned):
+        with self.left.calculate(None, None) as (self.dst, _):
             with ExitStack() as exitStack:
-                if isinstance(self.right, int):
-                    rsigned = (self.right < 0)
-                else:
-                    self.src, _, rsigned = exitStack.enter_context(
-                            self.right.calculate(None, None, None))
+                if not isinstance(self.right, int):
+                    self.src, _ = exitStack.enter_context(
+                        self.right.calculate(None, None))
                 self.origin = len(self.ebpf.opcodes)
                 self.ebpf.opcodes.append(None)
-                self.opcode = self.opcode[negative + 2 * (lsigned or rsigned)]
+                self.opcode = self.opcode[negative]
         self.owners = self.ebpf.owners.copy()
 
     def target(self, retarget=False):
@@ -424,12 +425,14 @@ class InvertComparison(Comparison):
 
 def binary(opcode):
     def ret(self, value):
-        return Binary(self.ebpf, self, value, opcode)
+        return Binary(self.ebpf, self, value, opcode,
+                      self.signed or ((value < 0) if isinstance(value, int)
+                                                  else value.signed), False)
     return ret
 
 def rbinary(opcode):
     def ret(self, value):
-        return ReverseBinary(self.ebpf, value, self, opcode)
+        return ReverseBinary(self.ebpf, value, self, opcode, value < 0, False)
     return ret
 
 
@@ -444,7 +447,6 @@ class Expression:
     __ror__ = __or__ = binary(Opcode.OR)
     __lshift__ = binary(Opcode.LSH)
     __rlshift__ = rbinary(Opcode.LSH)
-    __rshift__ = binary(Opcode.RSH)
     __rrshift__ = rbinary(Opcode.RSH)
     __mod__ = binary(Opcode.MOD)
     __rmod__ = rbinary(Opcode.MOD)
@@ -455,6 +457,14 @@ class Expression:
     __lt__ = comparison(Opcode.JLT, Opcode.JGE, Opcode.JSLT, Opcode.JSGE)
     __le__ = comparison(Opcode.JLE, Opcode.JGT, Opcode.JSLE, Opcode.JSGT)
 
+    def __rshift__(self, value):
+        opcode = Opcode.ARSH if self.signed else Opcode.RSH
+        return Binary(self.ebpf, self, value, opcode, self.signed, False)
+
+    def __rrshift__(self, value):
+        opcode = Opcode.ARSH if value < 0 else Opcode.RSH
+        return ReverseBinary(self.ebpf, value, self, opcode, value < 0, False)
+
     def __and__(self, value):
         return AndExpression(self.ebpf, self, value)
 
@@ -486,15 +496,13 @@ class Expression:
         return self.as_comparison.__exit__(exc_type, exc, tb)
 
     @contextmanager
-    def calculate(self, dst, long, signed, force=False):
+    def calculate(self, dst, long, force=False):
         """issue the code that calculates the value of this expression
 
-        this method returns three values:
+        this method returns two values:
 
         - the number of the register with the result
         - a boolean indicating whether this is a 64 bit value
-        - and a booleand indicating whether the result is to be
-          considered signed.
 
         this method is a contextmanager to be used in a `with`
         statement. At the end of the `with` block the result is
@@ -508,19 +516,17 @@ class Expression:
            or `None` if that does not matter.
         :param long: True if the result is supposed to be 64 bit. None
            if it does not matter.
-        :param signed: True if the result should be considered signed.
-           None if it does not matter.
         :param force: if true, `dst` must be respected, otherwise this
            is optional.
         """
         with self.ebpf.get_free_register(dst) as dst:
-            with self.get_address(dst, long, signed) as (src, fmt):
+            with self.get_address(dst, long) as (src, fmt):
                 self.ebpf.append(Opcode.LD + Memory.fmt_to_opcode[fmt],
                                  dst, src, 0, 0)
-                yield dst, long, self.signed
+                yield dst, long
 
     @contextmanager
-    def get_address(self, dst, long, signed, force=False):
+    def get_address(self, dst, long, force=False):
         """get the address of the value of this expression
 
         this method returns the address of the result of this expression,
@@ -529,7 +535,7 @@ class Expression:
         the stack.
         """
         with self.ebpf.get_stack(4 + 4 * long) as stack:
-            with self.calculate(dst, long, signed) as (src, _, _):
+            with self.calculate(dst, long) as (src, _):
                 self.ebpf.append(Opcode.STX + Opcode.DW * long,
                                  10, src, stack, 0)
                 self.ebpf.append(Opcode.MOV + Opcode.LONG + Opcode.REG,
@@ -544,44 +550,38 @@ class Expression:
 
 class Binary(Expression):
     """represent all binary expressions"""
-    def __init__(self, ebpf, left, right, operator):
+    def __init__(self, ebpf, left, right, operator, signed, fixed):
         self.ebpf = ebpf
         self.left = left
         self.right = right
         self.operator = operator
+        self.signed = signed
+        self.fixed = fixed
 
     @contextmanager
-    def calculate(self, dst, long, signed, force=False):
+    def calculate(self, dst, long, force=False):
         orig_dst = dst
         if not isinstance(self.right, int) and self.right.contains(dst):
             dst = None
         with self.ebpf.get_free_register(dst) as dst:
-            with self.left.calculate(dst, long, signed, True) \
-                    as (dst, l_long, l_signed):
+            with self.left.calculate(dst, long, True) as (dst, l_long):
                 if long is None:
                     long = l_long
-                signed = signed or l_signed
-            if self.operator is Opcode.RSH and signed:  # >>=
-                operator = Opcode.ARSH
-            else:
-                operator = self.operator
             if isinstance(self.right, int):
-                r_signed = self.right < 0
-                self.ebpf.append(operator + Opcode.LONG * long,
+                self.ebpf.append(self.operator + Opcode.LONG * long,
                                  dst, 0, 0, self.right)
             else:
-                with self.right.calculate(None, long, None) as \
-                        (src, r_long, r_signed):
+                with self.right.calculate(None, long) as (src, r_long):
                     self.ebpf.append(
-                        operator + Opcode.REG
+                        self.operator + Opcode.REG
                         + Opcode.LONG * ((r_long or l_long)
                                          if long is None else long),
                         dst, src, 0, 0)
             if orig_dst is None or orig_dst == dst:
-                yield dst, long, signed or r_signed
+                yield dst, long
                 return
         self.ebpf.append(Opcode.MOV + Opcode.REG + Opcode.LONG * long, orig_dst, dst, 0, 0)
-        yield orig_dst, long, signed or r_signed
+        yield orig_dst, long
 
     def contains(self, no):
         return self.left.contains(no) or (not isinstance(self.right, int)
@@ -589,25 +589,22 @@ class Binary(Expression):
 
 
 class ReverseBinary(Expression):
-    def __init__(self, ebpf, left, right, operator):
+    def __init__(self, ebpf, left, right, operator, signed, fixed):
         self.ebpf = ebpf
         self.left = left
         self.right = right
         self.operator = operator
+        self.signed = signed
+        self.fixed = fixed
 
     @contextmanager
-    def calculate(self, dst, long, signed, force=False):
+    def calculate(self, dst, long, force=False):
         with self.ebpf.get_free_register(dst) as dst:
             self.ebpf._load_value(dst, self.left)
-            if self.operator is Opcode.RSH and self.left < 0:  # >>=
-                operator = Opcode.ARSH
-            else:
-                operator = self.operator
-
-            with self.right.calculate(None, long, None) as (src, long, _):
-                self.ebpf.append(operator + Opcode.LONG * long + Opcode.REG,
-                                 dst, src, 0, 0)
-            yield dst, long, signed
+            with self.right.calculate(None, long) as (src, long):
+                self.ebpf.append(self.operator + Opcode.LONG * long
+                                 + Opcode.REG, dst, src, 0, 0)
+            yield dst, long
 
     def contains(self, no):
         return self.right.contains(no)
@@ -619,11 +616,10 @@ class Negate(Expression):
         self.arg = arg
 
     @contextmanager
-    def calculate(self, dst, long, signed, force=False):
-        with self.arg.calculate(dst, long, signed, force) as \
-                (dst, long, signed):
+    def calculate(self, dst, long, force=False):
+        with self.arg.calculate(dst, long, force) as (dst, long):
             self.ebpf.append(Opcode.NEG + Opcode.LONG * long, dst, 0, 0, 0)
-            yield dst, long, signed
+            yield dst, long
 
     def contains(self, no):
         return self.arg.contains(no)
@@ -635,12 +631,11 @@ class Absolute(Expression):
         self.arg = arg
 
     @contextmanager
-    def calculate(self, dst, long, signed, force=False):
-        with self.arg.calculate(dst, long, True, force) as \
-                (dst, long, signed):
+    def calculate(self, dst, long, force=False):
+        with self.arg.calculate(dst, long, force) as (dst, long):
             with self.ebpf.r[dst] < 0:
                 self.ebpf.r[dst] = -self.ebpf.r[dst]
-            yield dst, long, True
+            yield dst, long
 
     def contains(self, no):
         return self.arg.contains(no)
@@ -652,7 +647,7 @@ class Sum(Binary):
     this is used to optimize memory addressing code.
     """
     def __init__(self, ebpf, left, right):
-        super().__init__(ebpf, left, right, Opcode.ADD)
+        super().__init__(ebpf, left, right, Opcode.ADD, right < 0, False)
 
     def __add__(self, value):
         if isinstance(value, int):
@@ -672,7 +667,7 @@ class Sum(Binary):
 class AndExpression(Binary):
     # there is a special comparison with & instruction
     def __init__(self, ebpf, left, right):
-        super().__init__(ebpf, left, right, Opcode.AND)
+        super().__init__(ebpf, left, right, Opcode.AND, False, False)
 
     def __ne__(self, value):
         if isinstance(value, int) and value == 0:
@@ -684,7 +679,7 @@ class AndComparison(SimpleComparison):
     # there is a special comparison with & instruction
     # it is the only one which has not inversion
     def __init__(self, ebpf, left, right):
-        Binary.__init__(self, ebpf, left, right, Opcode.AND)
+        Binary.__init__(self, ebpf, left, right, Opcode.AND, False, False)
         SimpleComparison.__init__(self, ebpf, left, right, Opcode.JSET)
         self.opcode = (Opcode.JSET, None, Opcode.JSET, None)
         self.invert = None
@@ -747,15 +742,15 @@ class Register(Expression):
             return super().__sub__(value)
 
     @contextmanager
-    def calculate(self, dst, long, signed, force=False):
+    def calculate(self, dst, long, force=False):
         if self.no not in self.ebpf.owners:
             raise AssembleError("register has no value")
         if dst != self.no and force:
             self.ebpf.append(Opcode.MOV + Opcode.REG + Opcode.LONG * self.long,
                              dst, self.no, 0, 0)
-            yield dst, self.long, self.signed
+            yield dst, self.long
         else:
-            yield self.no, self.long, self.signed
+            yield self.no, self.long
 
     def contains(self, no):
         return self.no == no
@@ -791,7 +786,7 @@ class Memory(Expression):
             return NotImplemented
 
     @contextmanager
-    def calculate(self, dst, long, signed, force=False):
+    def calculate(self, dst, long, force=False):
         with ExitStack() as exitStack:
             if isinstance(self.address, Sum):
                 dst = exitStack.enter_context(self.ebpf.get_free_register(dst))
@@ -799,19 +794,19 @@ class Memory(Expression):
                     Opcode.LD + self.fmt_to_opcode.get(self.fmt, Opcode.B),
                     dst, self.address.left.no, self.address.right, 0)
             else:
-                dst, _, _ = exitStack.enter_context(
-                    super().calculate(dst, long, signed, force))
+                dst, _ = exitStack.enter_context(
+                    super().calculate(dst, long, force))
             if isinstance(self.fmt, tuple):
                 self.ebpf.r[dst] &= ((1 << self.fmt[1]) - 1) << self.fmt[0]
                 if self.fmt[0] > 0:
                     self.ebpf.r[dst] >>= self.fmt[0]
-                yield dst, "B", False
+                yield dst, "B"
             else:
-                yield dst, self.fmt in "QqA", self.fmt.islower()
+                yield dst, self.fmt in "QqA"
 
     @contextmanager
-    def get_address(self, dst, long, signed, force=False):
-        with self.address.calculate(dst, True, None) as (src, _, _):
+    def get_address(self, dst, long, force=False):
+        with self.address.calculate(dst, True) as (src, _):
             yield src, self.fmt
 
     def contains(self, no):
@@ -885,9 +880,8 @@ class MemoryDesc:
             opcode = Opcode.XADD
         else:
             opcode = Opcode.STX
-        with value.calculate(None, isinstance(fmt, str) and fmt in 'qQ',
-                             isinstance(fmt, str) and fmt.islower()
-                            ) as (src, _, _):
+        with value.calculate(None, isinstance(fmt, str) and fmt in 'qQ'
+                            ) as (src, _):
             ebpf.append(opcode + bits, self.base_register, src, addr, 0)
 
 
@@ -926,8 +920,7 @@ class MemoryMap:
                 dst = addr.left.no
                 offset = addr.right
             else:
-                dst, _, _ = exitStack.enter_context(
-                        addr.calculate(None, True, None))
+                dst, _ = exitStack.enter_context(addr.calculate(None, True))
                 offset = 0
             if isinstance(value, int):
                 self.ebpf.append(Opcode.ST + Memory.fmt_to_opcode[self.fmt],
@@ -945,7 +938,7 @@ class MemoryMap:
                 opcode = Opcode.XADD
             else:
                 opcode = Opcode.STX
-            with value.calculate(None, None, None) as (src, _, _):
+            with value.calculate(None, None) as (src, _):
                 self.ebpf.append(opcode + Memory.fmt_to_opcode[self.fmt],
                                  dst, src, offset, 0)
 
@@ -972,11 +965,11 @@ class PseudoFd(Expression):
         self.fd = fd
 
     @contextmanager
-    def calculate(self, dst, long, signed, force=False):
+    def calculate(self, dst, long, force=False):
         with self.ebpf.get_free_register(dst) as dst:
             self.ebpf.append(Opcode.DW, dst, 1, 0, self.fd)
             self.ebpf.append(Opcode.W, 0, 0, 0, 0)
-            yield dst, long, signed
+            yield dst, long
 
 
 class ktime(Expression):
@@ -985,13 +978,13 @@ class ktime(Expression):
         self.ebpf = ebpf
 
     @contextmanager
-    def calculate(self, dst, long, signed, force=False):
+    def calculate(self, dst, long, force=False):
         with self.ebpf.get_free_register(dst) as dst:
             with self.ebpf.save_registers([i for i in range(6) if i != dst]):
                 self.ebpf.call(FuncId.ktime_get_ns)
                 if dst != 0:
                     self.ebpf.r[dst] = self.ebpf.r0
-            yield dst, True, False
+            yield dst, True
 
 
 class prandom(Expression):
@@ -1000,13 +993,13 @@ class prandom(Expression):
         self.ebpf = ebpf
 
     @contextmanager
-    def calculate(self, dst, long, signed, force=False):
+    def calculate(self, dst, long, force=False):
         with self.ebpf.get_free_register(dst) as dst:
             with self.ebpf.save_registers([i for i in range(6) if i != dst]):
                 self.ebpf.call(FuncId.get_prandom_u32)
                 if dst != 0:
                     self.ebpf.r[dst] = self.ebpf.r0
-            yield dst, True, False
+            yield dst, True
 
 
 class RegisterDesc:
@@ -1035,7 +1028,7 @@ class RegisterArray:
         if isinstance(value, int):
             self.ebpf._load_value(no, value)
         elif isinstance(value, Expression):
-            with value.calculate(no, self.long, self.signed, True):
+            with value.calculate(no, self.long, True):
                 pass
         else:
             raise AssembleError("cannot compile")
diff --git a/ebpfcat/hashmap.py b/ebpfcat/hashmap.py
index 76f8518..aacea52 100644
--- a/ebpfcat/hashmap.py
+++ b/ebpfcat/hashmap.py
@@ -30,9 +30,7 @@ class HashGlobalVar(Expression):
         self.signed = fmt.islower()
 
     @contextmanager
-    def get_address(self, dst, long, signed, force=False):
-        if signed != self.fmt.islower():
-            raise AssembleError("HashMap variable has wrong signedness")
+    def get_address(self, dst, long, force=False):
         with self.ebpf.save_registers([i for i in range(6) if i != dst]), \
                 self.ebpf.get_stack(4) as stack:
             self.ebpf.append(Opcode.ST, 10, 0, stack, self.count)
@@ -78,7 +76,7 @@ class HashGlobalVarDesc:
                         pack("q" if self.fmt.islower() else "Q", value), 0)
             return
         with ebpf.save_registers([3]):
-            with value.get_address(3, True, self.fmt.islower(), True):
+            with value.get_address(3, True, True):
                 with ebpf.save_registers([0, 1, 2, 4, 5]), \
                         ebpf.get_stack(4) as stack:
                     ebpf.r1 = ebpf.get_fd(ebpf.__dict__[self.name].fd)
-- 
GitLab