From c3241b5e1609085844d6bd8f77c02470c373d67f Mon Sep 17 00:00:00 2001
From: Martin Teichmann <martin.teichmann@xfel.eu>
Date: Mon, 1 Mar 2021 11:39:47 +0000
Subject: [PATCH] improve handling of longs

---
 ebpfcat/ebpf.py      | 19 +++++++++++--------
 ebpfcat/ebpf_test.py | 25 ++++++++++++++++++++-----
 2 files changed, 31 insertions(+), 13 deletions(-)

diff --git a/ebpfcat/ebpf.py b/ebpfcat/ebpf.py
index 99b91ff..1fb6969 100644
--- a/ebpfcat/ebpf.py
+++ b/ebpfcat/ebpf.py
@@ -476,20 +476,23 @@ class Binary(Expression):
             dst = None
         with self.ebpf.get_free_register(dst) as dst:
             with self.left.calculate(dst, long, signed, True) \
-                    as (dst, long, signed):
+                    as (dst, l_long, signed):
                 pass
             if self.operator is Opcode.RSH and signed:  # >>=
                 operator = Opcode.ARSH
             else:
                 operator = self.operator
             if isinstance(self.right, int):
-                self.ebpf.append(operator + (Opcode.LONG if long is None
-                                             else Opcode.LONG * long),
+                self.ebpf.append(operator + Opcode.LONG * long,
                                  dst, 0, 0, self.right)
             else:
-                with self.right.calculate(None, long, None) as (src, long, _):
-                    self.ebpf.append(operator + Opcode.LONG*long + Opcode.REG,
-                                     dst, src, 0, 0)
+                with self.right.calculate(None, long, None) as \
+                        (src, r_long, _):
+                    self.ebpf.append(
+                        operator + Opcode.REG
+                        + Opcode.LONG * ((r_long or l_long)
+                                         if long is None else long),
+                        dst, src, 0, 0)
             if orig_dst is None or orig_dst == dst:
                 yield dst, long, signed
                 return
@@ -649,7 +652,7 @@ class Memory(Expression):
 
     @contextmanager
     def get_address(self, dst, long, signed, force=False):
-        with self.address.calculate(dst, None, None) as (src, _, _):
+        with self.address.calculate(dst, True, None) as (src, _, _):
             yield src, self.bits
 
     def contains(self, no):
@@ -720,7 +723,7 @@ class MemoryMap:
                 offset = addr.right
             else:
                 dst, _, _ = exitStack.enter_context(
-                        addr.calculate(None, None, None))
+                        addr.calculate(None, True, None))
                 offset = 0
             if isinstance(value, int):
                 self.ebpf.append(Opcode.ST + self.bits, dst, 0, offset, value)
diff --git a/ebpfcat/ebpf_test.py b/ebpfcat/ebpf_test.py
index bf65e6b..f85dbb8 100644
--- a/ebpfcat/ebpf_test.py
+++ b/ebpfcat/ebpf_test.py
@@ -343,7 +343,7 @@ class Tests(TestCase):
 
         self.assertEqual(e.opcodes, [
             Instruction(opcode=191, dst=0, src=1, off=0, imm=0),
-            Instruction(opcode=15, dst=0, src=3, off=0, imm=0),
+            Instruction(opcode=O.ADD+O.REG+O.LONG, dst=0, src=3, off=0, imm=0),
             Instruction(opcode=181, dst=0, src=0, off=2, imm=3),
             Instruction(opcode=183, dst=0, src=0, off=0, imm=5),
             Instruction(opcode=5, dst=0, src=0, off=1, imm=0),
@@ -400,6 +400,21 @@ class Tests(TestCase):
             Instruction(opcode=191, dst=0, src=1, off=0, imm=0),
             Instruction(opcode=95, dst=0, src=2, off=0, imm=0)])
 
+    def test_mixed_binary(self):
+        e = EBPF()
+        e.owners = {0, 1, 2, 3}
+        e.w1 = e.r2 + e.w3
+        e.r1 = e.w2 + e.w3
+        e.w1 = e.w2 + e.w3
+        self.assertEqual(e.opcodes, [
+            Instruction(opcode=O.MOV+O.LONG+O.REG, dst=1, src=2, off=0, imm=0),
+            Instruction(opcode=O.REG+O.ADD, dst=1, src=3, off=0, imm=0),
+            Instruction(opcode=O.MOV+O.REG, dst=1, src=2, off=0, imm=0),
+            Instruction(opcode=O.LONG+O.REG+O.ADD, dst=1, src=3, off=0, imm=0),
+            Instruction(opcode=O.MOV+O.REG, dst=1, src=2, off=0, imm=0),
+            Instruction(opcode=O.REG+O.ADD, dst=1, src=3, off=0, imm=0)])
+
+
     def test_reverse_binary(self):
         e = EBPF()
         e.owners = {0, 1, 2, 3}
@@ -497,16 +512,16 @@ class Tests(TestCase):
             Instruction(opcode=39, dst=0, src=0, off=0, imm=2),
             Instruction(opcode=31, dst=3, src=0, off=0, imm=0),
             Instruction(opcode=191, dst=0, src=3, off=0, imm=0),
-            Instruction(opcode=39, dst=0, src=0, off=0, imm=2),
+            Instruction(opcode=O.MUL, dst=0, src=0, off=0, imm=2),
             Instruction(opcode=107, dst=10, src=0, off=-10, imm=0),
             Instruction(opcode=191, dst=0, src=10, off=0, imm=0),
             Instruction(opcode=15, dst=0, src=3, off=0, imm=0),
             Instruction(opcode=191, dst=2, src=3, off=0, imm=0),
-            Instruction(opcode=39, dst=2, src=0, off=0, imm=2),
+            Instruction(opcode=O.MUL, dst=2, src=0, off=0, imm=2),
             Instruction(opcode=107, dst=0, src=2, off=0, imm=0),
 
             Instruction(opcode=191, dst=5, src=10, off=0, imm=0),
-            Instruction(opcode=15, dst=5, src=3, off=0, imm=0),
+            Instruction(opcode=O.ADD+O.REG+O.LONG, dst=5, src=3, off=0, imm=0),
             Instruction(opcode=105, dst=5, src=5, off=0, imm=0),
 
             Instruction(opcode=191, dst=0, src=1, off=0, imm=0),
@@ -572,7 +587,7 @@ class Tests(TestCase):
             Instruction(opcode=O.LD+O.W, dst=9, src=1, off=0, imm=0),
             Instruction(opcode=O.LD+O.W, dst=0, src=1, off=4, imm=0),
             Instruction(opcode=O.LD+O.W, dst=2, src=1, off=0, imm=0),
-            Instruction(opcode=O.ADD+O.LONG, dst=2, src=0, off=0, imm=100),
+            Instruction(opcode=O.ADD, dst=2, src=0, off=0, imm=100),
             Instruction(opcode=O.REG+O.JLE, dst=0, src=2, off=2, imm=0),
             Instruction(opcode=O.REG+O.LD, dst=3, src=9, off=22, imm=0),
             Instruction(opcode=O.JMP, dst=0, src=0, off=1, imm=0),
-- 
GitLab