From e16f4697c35b36371e7e660df09f24644a669c95 Mon Sep 17 00:00:00 2001
From: Martin Teichmann <martin.teichmann@xfel.eu>
Date: Mon, 28 Dec 2020 20:09:13 +0000
Subject: [PATCH] add reverse binary and negation

---
 ebpf.py      | 79 ++++++++++++++++++++++++++++++++++++++++++++--------
 ebpf_test.py | 29 ++++++++++++++++++-
 2 files changed, 96 insertions(+), 12 deletions(-)

diff --git a/ebpf.py b/ebpf.py
index ed0b5cc..4307feb 100644
--- a/ebpf.py
+++ b/ebpf.py
@@ -242,22 +242,32 @@ class InvertComparison(Comparison):
         self.value.compare(not negative)
 
 
-def binary(opcode, symetric=False):
+def binary(opcode):
     def ret(self, value):
         return Binary(self.ebpf, self, value, opcode)
     return ret
 
+def rbinary(opcode):
+    def ret(self, value):
+        return ReverseBinary(self.ebpf, value, self, opcode)
+    return ret
+
 
 class Expression:
-    __radd__ = __add__ = binary(Opcode.ADD, True)
+    __radd__ = __add__ = binary(Opcode.ADD)
     __sub__ = binary(Opcode.SUB)
-    __rmul__ = __mul__ = binary(Opcode.MUL, True)
+    __rsub__ = rbinary(Opcode.SUB)
+    __rmul__ = __mul__ = binary(Opcode.MUL)
     __truediv__ = binary(Opcode.DIV)
-    __ror__ = __or__ = binary(Opcode.OR, True)
+    __rtruediv__ = rbinary(Opcode.DIV)
+    __ror__ = __or__ = binary(Opcode.OR)
     __lshift__ = binary(Opcode.LSH)
+    __rlshift__ = rbinary(Opcode.LSH)
     __rshift__ = binary(Opcode.RSH)
+    __rrshift__ = rbinary(Opcode.RSH)
     __mod__ = binary(Opcode.MOD)
-    __rxor__ = __xor__ = binary(Opcode.XOR, True)
+    __rmod__ = rbinary(Opcode.MOD)
+    __rxor__ = __xor__ = binary(Opcode.XOR)
 
     __eq__ = comparison(Opcode.JEQ, Opcode.JNE)
     __gt__ = comparison(Opcode.JGT, Opcode.JLE, Opcode.JSGT, Opcode.JSLE)
@@ -271,6 +281,9 @@ class Expression:
 
     __rand__ = __and__
 
+    def __neg__(self):
+        return Negate(self.ebpf, self)
+
 
 class Binary(Expression):
     def __init__(self, ebpf, left, right, operator):
@@ -316,6 +329,49 @@ class Binary(Expression):
                                           and self.right.contains(no))
 
 
+class ReverseBinary(Expression):
+    def __init__(self, ebpf, left, right, operator):
+        self.ebpf = ebpf
+        self.left = left
+        self.right = right
+        self.operator = operator
+
+    def calculate(self, dst, long, signed, force=False):
+        if dst is None:
+            dst = self.ebpf.get_free_register()
+            self.ebpf.owners.add(dst)
+            free = True
+        else:
+            free = False
+        self.ebpf._load_value(dst, self.left)
+        if self.operator is Opcode.RSH and self.left < 0:  # >>=
+            operator = Opcode.ARSH
+        else:
+            operator = self.operator
+
+        src, long, _, rfree = self.right.calculate(None, long, None)
+        self.ebpf.append(operator + Opcode.LONG * long + Opcode.REG,
+                         dst, src, 0, 0)
+        return dst, long, signed, free
+
+    def contains(self, no):
+        return self.right.contains(no)
+
+
+class Negate(Expression):
+    def __init__(self, ebpf, arg):
+        self.ebpf = ebpf
+        self.arg = arg
+
+    def calculate(self, dst, long, signed, force=False):
+        dst, long, signed, free = self.arg.calculate(dst, long, signed, force)
+        self.ebpf.append(Opcode.NEG + Opcode.LONG * long, dst, 0, 0, 0)
+        return dst, long, signed, free
+
+    def contains(self, no):
+        return self.arg.contains(no)
+
+
 class Sum(Binary):
     def __init__(self, ebpf, left, right):
         super().__init__(ebpf, left, right, Opcode.ADD)
@@ -496,12 +552,7 @@ class RegisterDesc:
     def __set__(self, instance, value):
         instance.owners.add(self.no)
         if isinstance(value, int):
-            if -0x80000000 <= value < 0x80000000:
-                instance.append(Opcode.MOV + Opcode.LONG * self.long,
-                                self.no, 0, 0, value)
-            else:
-                instance.append(Opcode.DW, self.no, 0, 0, value & 0xffffffff)
-                instance.append(Opcode.W, 0, 0, 0, value >> 32)
+            instance._load_value(self.no, value)
         elif isinstance(value, Expression):
             value.calculate(self.no, self.long, self.signed, True)
         elif isinstance(value, Instruction):
@@ -571,6 +622,12 @@ class EBPF:
                 return i
         raise AssembleError("not enough registers")
 
+    def _load_value(self, no, value):
+        if -0x80000000 <= value < 0x80000000:
+            self.append(Opcode.MOV + Opcode.LONG, no, 0, 0, value)
+        else:
+            self.append(Opcode.DW, no, 0, 0, value & 0xffffffff)
+            self.append(Opcode.W, 0, 0, 0, value >> 32)
 
 for i in range(11):
     setattr(EBPF, f"r{i}", RegisterDesc(i, True))
diff --git a/ebpf_test.py b/ebpf_test.py
index 015e4fb..58c51d1 100644
--- a/ebpf_test.py
+++ b/ebpf_test.py
@@ -48,7 +48,7 @@ class Tests(TestCase):
         e.w2 += 3
         e.w5 += e.w6
         self.assertEqual(e.opcodes, 
-            [Instruction(0xb4, 3, 0, 0, 7),
+            [Instruction(O.MOV+O.LONG, 3, 0, 0, 7),
              Instruction(0xbc, 4, 1, 0, 0),
              Instruction(opcode=4, dst=2, src=0, off=0, imm=3),
              Instruction(opcode=0xc, dst=5, src=6, off=0, imm=0)])
@@ -323,6 +323,33 @@ class Tests(TestCase):
             Instruction(opcode=191, dst=0, src=1, off=0, imm=0),
             Instruction(opcode=95, dst=0, src=2, off=0, imm=0)])
 
+    def test_reverse_binary(self):
+        e = EBPF()
+        e.owners = {0, 1, 2, 3}
+        e.r3 = 7 / e.r2
+        e.r3 = 7 << e.r2
+        e.r3 = 7 % e.r2
+        e.r3 = 7 >> e.r2
+        e.r3 = -7 >> e.r2
+        self.assertEqual(e.opcodes, [
+            Instruction(opcode=O.MOV+O.LONG, dst=3, src=0, off=0, imm=7),
+            Instruction(opcode=O.REG+O.LONG+O.DIV, dst=3, src=2, off=0, imm=0),
+            Instruction(opcode=O.MOV+O.LONG, dst=3, src=0, off=0, imm=7),
+            Instruction(opcode=O.LSH+O.REG+O.LONG, dst=3, src=2, off=0, imm=0),
+            Instruction(opcode=O.MOV+O.LONG, dst=3, src=0, off=0, imm=7),
+            Instruction(opcode=O.REG+O.MOD+O.LONG, dst=3, src=2, off=0, imm=0),
+            Instruction(opcode=O.MOV+O.LONG, dst=3, src=0, off=0, imm=7),
+            Instruction(opcode=O.REG+O.RSH+O.LONG, dst=3, src=2, off=0, imm=0),
+            Instruction(opcode=O.MOV+O.LONG, dst=3, src=0, off=0, imm=-7),
+            Instruction(opcode=O.REG+O.LONG+O.ARSH, dst=3, src=2, off=0, imm=0)
+            ])
+
+    def test_reverse_binary(self):
+        e = EBPF()
+        e.r7 = -e.r1
+        self.assertEqual(e.opcodes, [
+            Instruction(opcode=O.LONG+O.REG+O.MOV, dst=7, src=1, off=0, imm=0),
+            Instruction(opcode=O.LONG+O.NEG, dst=7, src=0, off=0, imm=0)])
 
     def test_jump_data(self):
         e = EBPF()
-- 
GitLab