From 183a37fb84c6b7790e312e93976a059c41bdd316 Mon Sep 17 00:00:00 2001
From: Martin Teichmann <martin.teichmann@xfel.eu>
Date: Mon, 28 Dec 2020 14:04:16 +0000
Subject: [PATCH] amidst making & and | work for If

super complicated, pausing.
---
 ebpf.py      | 108 +++++++++++++++++++++++++++++++++++----------------
 ebpf_test.py |   7 ++--
 2 files changed, 78 insertions(+), 37 deletions(-)

diff --git a/ebpf.py b/ebpf.py
index bd219c2..cca2f1f 100644
--- a/ebpf.py
+++ b/ebpf.py
@@ -32,35 +32,33 @@ def comparison(uposop, unegop, sposop=None, snegop=None):
 
 
 class Comparison:
-    def target(self):
-        assert self.ebpf.opcodes[self.origin] is None
-        if isinstance(self.right, int):
-            inst = Instruction(
-                self.opcode, self.dst, 0,
-                len(self.ebpf.opcodes) - self.origin - 1, self.right)
-        else:
-            inst = Instruction(
-                self.opcode + 8, self.dst, self.src,
-                len(self.ebpf.opcodes) - self.origin - 1, 0)
-        self.ebpf.opcodes[self.origin] = inst
-        self.ebpf.owners, self.owners = \
-                self.ebpf.owners & self.owners, self.ebpf.owners
+    def __init__(self, ebpf):
+        self.ebpf = ebpf
+        self.invert = None
+        self.else_origin = None
 
     def __enter__(self):
         return self
 
     def __exit__(self, exc_type, exc, tb):
-        self.target()
+        if self.else_origin is None:
+            self.target()
+            return
+        assert self.ebpf.opcodes[self.else_origin] is None
+        self.ebpf.opcodes[self.else_origin] = Instruction(5, 0, 0, len(self.ebpf.opcodes) - self.else_origin - 1, 0)
+        self.ebpf.owners, self.owners = \
+                self.ebpf.owners & self.owners, self.ebpf.owners
+
         if self.invert is not None:
             olen = len(self.ebpf.opcodes)
             assert self.ebpf.opcodes[self.invert].opcode == 5
             self.ebpf.opcodes[self.invert:self.invert] = \
-                    self.ebpf.opcodes[self.origin+1:]
+                    self.ebpf.opcodes[self.else_origin+1:]
             del self.ebpf.opcodes[olen-1:]
             op, dst, src, off, imm = self.ebpf.opcodes[self.invert - 1]
             self.ebpf.opcodes[self.invert - 1] = \
                     Instruction(op, dst, src,
-                                len(self.ebpf.opcodes) - self.origin + 1, imm)
+                                len(self.ebpf.opcodes) - self.else_origin + 1, imm)
 
     def Else(self):
         op, dst, src, off, imm = self.ebpf.opcodes[self.origin]
@@ -69,20 +67,34 @@ class Comparison:
         else:
             self.ebpf.opcodes[self.origin] = \
                     Instruction(op, dst, src, off+1, imm)
-        self.origin = len(self.ebpf.opcodes)
+        self.else_origin = len(self.ebpf.opcodes)
+        self.ebpf.opcodes.append(None)
+        return self
+
+    def invert_result(self):
+        origin = len(self.ebpf.opcodes)
         self.ebpf.opcodes.append(None)
+        self.target()
+        self.origin = origin
         self.right = self.dst = 0
         self.opcode = 5
-        return self
+
+    def __and__(self, value):
+        return AndOrComparison(self.ebpf, self, value, True)
+
+    def __or__(self, value):
+        return AndOrComparison(self.ebpf, self, value, False)
+
+    def __invert__(self):
+        return InvertComparison(self.ebpf, self)
 
 
 class SimpleComparison(Comparison):
     def __init__(self, ebpf, left, right, opcode):
-        self.ebpf = ebpf
+        super().__init__(ebpf)
         self.left = left
         self.right = right
         self.opcode = opcode
-        self.invert = None
 
     def compare(self, negative):
         self.dst, _, lsigned, lfree = self.left.calculate(None, None, None)
@@ -102,14 +114,52 @@ class SimpleComparison(Comparison):
             self.ebpf.owners.discard(self.src)
         self.owners = self.ebpf.owners.copy()
 
+    def target(self):
+        assert self.ebpf.opcodes[self.origin] is None
+        if isinstance(self.right, int):
+            inst = Instruction(
+                self.opcode, self.dst, 0,
+                len(self.ebpf.opcodes) - self.origin - 1, self.right)
+        else:
+            inst = Instruction(
+                self.opcode + 8, self.dst, self.src,
+                len(self.ebpf.opcodes) - self.origin - 1, 0)
+        self.ebpf.opcodes[self.origin] = inst
+        self.ebpf.owners, self.owners = \
+                self.ebpf.owners & self.owners, self.ebpf.owners
+
+
 class AndOrComparison(Comparison):
+    def __init__(self, ebpf, left, right, is_and):
+        super().__init__(ebpf)
+        self.left = left
+        self.right = right
+        self.is_and = is_and
+        self.targetted = False
+
     def compare(self, negative):
-        self.left.compare(negative)
-        self.right.compare(negative)
+        self.left.compare(self.is_and != negative)
+        self.right.compare(self.is_and != negative)
+        if self.is_and != negative:
+            self.invert_result()
+            self.owners = self.ebpf.owners.copy()
 
     def target(self):
-        self.left.target()
-        self.right.target()
+        if self.targetted:
+            super().target()
+        else:
+            self.left.target()
+            self.right.target()
+            self.targetted = True
+
+
+class InvertComparison(Comparison):
+    def __init__(self, ebpf, value):
+        self.ebpf = ebpf
+        self.value = value
+
+    def compare(self, negative):
+        self.value.compare(not negative)
 
 
 def binary(opcode, symetric=False):
@@ -192,9 +242,6 @@ class Sum(Binary):
 
 
 class AndExpression(Binary, SimpleComparison):
-    __and__ = __rand__ = comparison(0x45, None)
-    __rand__ = __and__ = binary(0x54, True)
-
     def __init__(self, ebpf, left, right):
         Binary.__init__(self, ebpf, left, right, 0x54)
         SimpleComparison.__init__(self, ebpf, left, right, 0x45)
@@ -203,12 +250,7 @@ class AndExpression(Binary, SimpleComparison):
     def compare(self, negative):
         super().compare(False)
         if negative:
-            origin = len(self.ebpf.opcodes)
-            self.ebpf.opcodes.append(None)
-            self.target()
-            self.origin = origin
-            self.right = self.dst = 0
-            self.opcode = 5
+            self.invert_result()
 
 class Register(Expression):
     offset = 0
diff --git a/ebpf_test.py b/ebpf_test.py
index a50dbc6..243f9b5 100644
--- a/ebpf_test.py
+++ b/ebpf_test.py
@@ -394,11 +394,10 @@ class Tests(TestCase):
 class KernelTests(TestCase):
     def test_minimal(self):
         e = EBPF(ProgType.XDP, "GPL")
-        with e.If(e.r1 & 1111111) as cond:
-            e.r0 = 2
-            e.r1 = 4
+        with e.If((e.r1 == 0x111111) & (e.r10 == 0x22222)) as cond:
+            e.r0 = 333333
         with cond.Else():
-            e.r0 = 3
+            e.r0 = 444444
         e.exit()
         print(e.load(log_level=1)[1])
         self.fail()
-- 
GitLab