From e53b45cd3207669861b8461474ace91e785acb62 Mon Sep 17 00:00:00 2001
From: Martin Teichmann <martin.teichmann@gmail.com>
Date: Sun, 12 Feb 2023 14:18:02 +0000
Subject: [PATCH] use isinstance with Expression, not int

---
 ebpfcat/ebpf.py | 92 ++++++++++++++++++++++++++-----------------------
 1 file changed, 48 insertions(+), 44 deletions(-)

diff --git a/ebpfcat/ebpf.py b/ebpfcat/ebpf.py
index bf4c9fd..d0a5d05 100644
--- a/ebpfcat/ebpf.py
+++ b/ebpfcat/ebpf.py
@@ -268,8 +268,7 @@ class AssembleError(Exception):
 
 def comparison(uposop, unegop, sposop, snegop):
     def ret(self, value):
-        if self.signed or ((value < 0) if isinstance(value, int)
-                                       else value.signed):
+        if self.signed or issigned(value):
             return SimpleComparison(self.ebpf, self, value, (sposop, snegop))
         else:
             return SimpleComparison(self.ebpf, self, value, (uposop, unegop))
@@ -361,7 +360,7 @@ class SimpleComparison(Comparison):
     def compare(self, negative):
         with self.left.calculate(None, None) as (self.dst, _):
             with ExitStack() as exitStack:
-                if not isinstance(self.right, int):
+                if isinstance(self.right, Expression):
                     self.src, _ = exitStack.enter_context(
                         self.right.calculate(None, None))
                 self.origin = len(self.ebpf.opcodes)
@@ -374,14 +373,14 @@ class SimpleComparison(Comparison):
         if self.opcode == Opcode.JMP:
             inst = Instruction(Opcode.JMP, 0, 0,
                                len(self.ebpf.opcodes) - self.origin - 1, 0)
-        elif isinstance(self.right, int):
-            inst = Instruction(
-                self.opcode, self.dst, 0,
-                len(self.ebpf.opcodes) - self.origin - 1, self.right)
-        else:
+        elif isinstance(self.right, Expression):
             inst = Instruction(
                 self.opcode + Opcode.REG, self.dst, self.src,
                 len(self.ebpf.opcodes) - self.origin - 1, 0)
+        else:
+            inst = Instruction(
+                self.opcode, self.dst, 0,
+                len(self.ebpf.opcodes) - self.origin - 1, self.right)
         self.ebpf.opcodes[self.origin] = inst
         if not retarget:
             self.ebpf.owners, self.owners = \
@@ -423,11 +422,17 @@ class InvertComparison(Comparison):
         self.value.target(retarget)
 
 
+def issigned(value):
+    if isinstance(value, Expression):
+        return value.signed
+    else:
+        return value < 0
+
+
 def binary(opcode):
     def ret(self, value):
         return Binary(self.ebpf, self, value, opcode,
-                      self.signed or ((value < 0) if isinstance(value, int)
-                                                  else value.signed), False)
+                      self.signed or issigned(value), False)
     return ret
 
 def rbinary(opcode):
@@ -561,22 +566,22 @@ class Binary(Expression):
     @contextmanager
     def calculate(self, dst, long, force=False):
         orig_dst = dst
-        if not isinstance(self.right, int) and self.right.contains(dst):
+        if isinstance(self.right, Expression) and self.right.contains(dst):
             dst = None
         with self.ebpf.get_free_register(dst) as dst:
             with self.left.calculate(dst, long, True) as (dst, l_long):
                 if long is None:
                     long = l_long
-            if isinstance(self.right, int):
-                self.ebpf.append(self.operator + Opcode.LONG * long,
-                                 dst, 0, 0, self.right)
-            else:
+            if isinstance(self.right, Expression):
                 with self.right.calculate(None, long) as (src, r_long):
                     self.ebpf.append(
                         self.operator + Opcode.REG
                         + Opcode.LONG * ((r_long or l_long)
                                          if long is None else long),
                         dst, src, 0, 0)
+            else:
+                self.ebpf.append(self.operator + Opcode.LONG * long,
+                                 dst, 0, 0, self.right)
             if orig_dst is None or orig_dst == dst:
                 yield dst, long
                 return
@@ -584,7 +589,7 @@ class Binary(Expression):
         yield orig_dst, long
 
     def contains(self, no):
-        return self.left.contains(no) or (not isinstance(self.right, int)
+        return self.left.contains(no) or (isinstance(self.right, Expression)
                                           and self.right.contains(no))
 
 
@@ -614,6 +619,7 @@ class Negate(Expression):
     def __init__(self, ebpf, arg):
         self.ebpf = ebpf
         self.arg = arg
+        self.signed = True
 
     @contextmanager
     def calculate(self, dst, long, force=False):
@@ -650,18 +656,18 @@ class Sum(Binary):
         super().__init__(ebpf, left, right, Opcode.ADD, right < 0, False)
 
     def __add__(self, value):
-        if isinstance(value, int):
-            return Sum(self.ebpf, self.left, self.right + value)
-        else:
+        if isinstance(value, Expression):
             return super().__add__(value)
+        else:
+            return Sum(self.ebpf, self.left, self.right + value)
 
     __radd__ = __add__
 
     def __sub__(self, value):
-        if isinstance(value, int):
-            return Sum(self.ebpf, self.left, self.right - value)
-        else:
+        if isinstance(value, Expression):
             return super().__sub__(value)
+        else:
+            return Sum(self.ebpf, self.left, self.right - value)
 
 
 class AndExpression(Binary):
@@ -728,18 +734,18 @@ class Register(Expression):
         self.signed = signed
 
     def __add__(self, value):
-        if isinstance(value, int) and self.long:
-            return Sum(self.ebpf, self, value)
-        else:
+        if isinstance(value, Expression) or not self.long:
             return super().__add__(value)
+        else:
+            return Sum(self.ebpf, self, value)
 
     __radd__ = __add__
 
     def __sub__(self, value):
-        if isinstance(value, int) and self.long:
-            return Sum(self.ebpf, self, -value)
-        else:
+        if isinstance(value, Expression) or not self.long:
             return super().__sub__(value)
+        else:
+            return Sum(self.ebpf, self, -value)
 
     @contextmanager
     def calculate(self, dst, long, force=False):
@@ -865,21 +871,21 @@ class MemoryDesc:
                 mask = ((1 << fmt[1]) - 1) << fmt[0]
                 value = (mask & (value << self.fmt[0]) | ~mask & before)
             opcode = Opcode.STX
-        elif isinstance(value, int):
-            ebpf.append(Opcode.ST + bits, self.base_register, 0,
-                        addr, value)
-            return
         elif isinstance(value, IAdd):
             value = value.value
-            if isinstance(value, int):
+            if not isinstance(value, Expression):
                 with ebpf.get_free_register(None) as src:
                     ebpf.r[src] = value
                     ebpf.append(Opcode.XADD + bits, self.base_register,
                                 src, addr, 0)
                 return
             opcode = Opcode.XADD
-        else:
+        elif isinstance(value, Expression):
             opcode = Opcode.STX
+        else:
+            ebpf.append(Opcode.ST + bits, self.base_register, 0,
+                        addr, value)
+            return
         with value.calculate(None, isinstance(fmt, str) and fmt in 'qQ'
                             ) as (src, _):
             ebpf.append(opcode + bits, self.base_register, src, addr, 0)
@@ -922,11 +928,7 @@ class MemoryMap:
             else:
                 dst, _ = exitStack.enter_context(addr.calculate(None, True))
                 offset = 0
-            if isinstance(value, int):
-                self.ebpf.append(Opcode.ST + Memory.fmt_to_opcode[self.fmt],
-                                 dst, 0, offset, value)
-                return
-            elif isinstance(value, IAdd):
+            if isinstance(value, IAdd):
                 value = value.value
                 if isinstance(value, int):
                     with self.ebpf.get_free_register(None) as src:
@@ -936,8 +938,12 @@ class MemoryMap:
                             dst, src, offset, 0)
                     return
                 opcode = Opcode.XADD
-            else:
+            elif isinstance(value, Expression):
                 opcode = Opcode.STX
+            else:
+                self.ebpf.append(Opcode.ST + Memory.fmt_to_opcode[self.fmt],
+                                 dst, 0, offset, value)
+                return
             with value.calculate(None, None) as (src, _):
                 self.ebpf.append(opcode + Memory.fmt_to_opcode[self.fmt],
                                  dst, src, offset, 0)
@@ -1025,13 +1031,11 @@ class RegisterArray:
 
     def __setitem__(self, no, value):
         self.ebpf.owners.add(no)
-        if isinstance(value, int):
-            self.ebpf._load_value(no, value)
-        elif isinstance(value, Expression):
+        if isinstance(value, Expression):
             with value.calculate(no, self.long, True):
                 pass
         else:
-            raise AssembleError("cannot compile")
+            self.ebpf._load_value(no, value)
 
     def __getitem__(self, no):
         return Register(no, self.ebpf, self.long, self.signed)
-- 
GitLab