From 814cf784a628ddca721bec1dfb61b54298923a7a Mon Sep 17 00:00:00 2001
From: Martin Teichmann <martin.teichmann@xfel.eu>
Date: Mon, 21 Dec 2020 08:51:21 +0000
Subject: [PATCH] support comparisons

---
 ebpf.py      | 101 ++++++++++++++++++++++++++++++++++++++++++++++--
 ebpf_test.py | 107 ++++++++++++++++++++++++++++++++++++++++++++++++++-
 2 files changed, 203 insertions(+), 5 deletions(-)

diff --git a/ebpf.py b/ebpf.py
index 6ecf983..8f3dff9 100644
--- a/ebpf.py
+++ b/ebpf.py
@@ -18,6 +18,52 @@ def augassign(opcode):
     return ret
 
 
+def comparison(uposop, unegop, sposop=None, snegop=None):
+    def ret(self, value):
+        if self.signed and sposop is not None:
+            return Comparison(self.no, value, sposop, snegop)
+        else:
+            return Comparison(self.no, value, uposop, unegop)
+    return ret
+
+
+class Comparison:
+    def __init__(self, dst, src, posop, negop):
+        self.dst = dst
+        self.src = src
+        self.posop = posop
+        self.negop = negop
+
+    def target(self):
+        assert self.ebpf.opcodes[self.origin] is None
+        if isinstance(self.src, int):
+            inst = Instruction(
+                self.opcode, self.dst, 0,
+                len(self.ebpf.opcodes) - self.origin - 1, self.src)
+        elif isinstance(self.src, Register):
+            inst = Instruction(
+                self.opcode + 8, self.dst, self.src.no,
+                len(self.ebpf.opcodes) - self.origin - 1, 0)
+        else:
+            return NotImplemented
+        self.ebpf.opcodes[self.origin] = inst
+
+    def __enter__(self):
+        self.origin = len(self.ebpf.opcodes)
+        self.ebpf.opcodes.append(None)
+        return self
+
+    def __exit__(self, exc_type, exc, tb):
+        self.target()
+
+    def Else(self):
+        op, dst, src, off, imm = self.ebpf.opcodes[self.origin]
+        self.ebpf.opcodes[self.origin] = Instruction(op, dst, src, off+1, imm)
+        self.src = self.dst = 0
+        self.opcode = 5
+        return self
+
+
 class Sum:
     def __init__(self, no, offset):
         self.no = no
@@ -41,10 +87,11 @@ class Sum:
 class Register:
     offset = 0
 
-    def __init__(self, no, ebpf, long):
+    def __init__(self, no, ebpf, long, signed):
         self.no = no
         self.ebpf = ebpf
         self.long = long
+        self.signed = signed
 
     __iadd__ = augassign(4)
     __isub__ = augassign(0x14)
@@ -53,10 +100,19 @@ class Register:
     __ior__ = augassign(0x44)
     __iand__ = augassign(0x54)
     __ilshift__ = augassign(0x64)
-    __irshift__ = augassign(0x74)
     __imod__ = augassign(0x94)
     __ixor__ = augassign(0xa4)
 
+    def __irshift__(self, value):
+        if isinstance(value, int):
+            return Instruction(0x74 + 3 * self.long + 0x50 * self.signed,
+                               self.no, 0, 0, value)
+        elif isinstance(value, Register) and self.long == value.long:
+            return Instruction(0x7c + 3 * self.long + 0x50 * self.signed,
+                               self.no, value.no, 0, 0)
+        else:
+            return NotImplemented
+
     def __add__(self, value):
         if isinstance(value, int) and self.long:
             return Sum(self.no, value)
@@ -71,6 +127,14 @@ class Register:
         else:
             return NotImplemented
 
+    __eq__ = comparison(0x15, 0x55)
+    __gt__ = comparison(0x25, 0xb5, 0x65, 0xd5)
+    __ge__ = comparison(0x35, 0xa5, 0x75, 0xc5)
+    __lt__ = comparison(0xa5, 0x35, 0xc5, 0x75)
+    __le__ = comparison(0xb5, 0x25, 0xd5, 0x65)
+    __ne__ = comparison(0x55, 0x15)
+    __and__ = __rand__ = comparison(0x45, None)
+
 
 class Memory:
     def __init__(self, ebpf, bits):
@@ -93,15 +157,16 @@ class Memory:
 
 
 class RegisterDesc:
-    def __init__(self, no, long):
+    def __init__(self, no, long, signed=False):
         self.no = no
         self.long = long
+        self.signed = signed
 
     def __get__(self, instance, owner=None):
         if instance is None:
             return self
         else:
-            return Register(self.no, instance, self.long)
+            return Register(self.no, instance, self.long, self.signed)
 
     def __set__(self, instance, value):
         if isinstance(value, int):
@@ -141,6 +206,31 @@ class EBPF:
         return prog_load(self.prog_type, self.assemble(), self.license,
                          log_level, log_size, self.kern_version)
 
+    def jumpIf(self, comp):
+        comp.origin = len(self.opcodes)
+        comp.ebpf = self
+        comp.opcode = comp.posop
+        self.opcodes.append(None)
+        return comp
+
+    def jump(self):
+        comp = Comparison(0, 0, None, None)
+        comp.origin = len(self.opcodes)
+        comp.ebpf = self
+        comp.opcode = 5
+        self.opcodes.append(None)
+        return comp
+
+    def If(self, comp):
+        comp.opcode = comp.negop
+        comp.ebpf = self
+        return comp
+
+    def isZero(self, comp):
+        comp.opcode = comp.negop
+        comp.ebpf = self
+        return comp
+
     def exit(self):
         self.append(0x95,0, 0, 0, 0)
 
@@ -148,5 +238,8 @@ class EBPF:
 for i in range(10):
     setattr(EBPF, f"r{i}", RegisterDesc(i, True))
 
+for i in range(10):
+    setattr(EBPF, f"sr{i}", RegisterDesc(i, True, True))
+
 for i in range(10):
     setattr(EBPF, f"s{i}", RegisterDesc(i, False))
diff --git a/ebpf_test.py b/ebpf_test.py
index be9a6e7..9f236af 100644
--- a/ebpf_test.py
+++ b/ebpf_test.py
@@ -107,10 +107,115 @@ class Tests(TestCase):
              Instruction(opcode=97, dst=4, src=8, off=7, imm=0),
              Instruction(opcode=121, dst=5, src=3, off=-7, imm=0)])
 
+
+    def test_jump(self):
+        e = EBPF()
+        target = e.jump()
+        e.r0 = 1
+        target.target()
+        t1 = e.jumpIf(e.r5 > 3)
+        t2 = e.jumpIf(e.r1 > e.r2)
+        t3 = e.jumpIf(e.r7 >= 2)
+        t4 = e.jumpIf(e.r4 >= e.r3)
+        e.r0 = 1
+        t1.target()
+        t2.target()
+        t3.target()
+        t4.target()
+        t1 = e.jumpIf(e.r5 < 3)
+        t2 = e.jumpIf(e.r1 < e.r2)
+        t3 = e.jumpIf(e.r7 <= 2)
+        t4 = e.jumpIf(e.r4 <= e.r3)
+        e.r0 = 1
+        t1.target()
+        t2.target()
+        t3.target()
+        t4.target()
+        t1 = e.jumpIf(e.sr5 > 3)
+        t2 = e.jumpIf(e.sr1 > e.sr2)
+        t3 = e.jumpIf(e.sr7 >= 2)
+        t4 = e.jumpIf(e.sr4 >= e.sr3)
+        e.r0 = 1
+        t1.target()
+        t2.target()
+        t3.target()
+        t4.target()
+        t1 = e.jumpIf(e.sr5 < 3)
+        t2 = e.jumpIf(e.sr1 < e.sr2)
+        t3 = e.jumpIf(e.sr7 <= 2)
+        t4 = e.jumpIf(e.sr4 <= e.sr3)
+        e.r0 = 1
+        t1.target()
+        t2.target()
+        t3.target()
+        t4.target()
+        t1 = e.jumpIf(e.sr5 == 3)
+        t2 = e.jumpIf(e.sr1 == e.sr2)
+        t3 = e.jumpIf(e.sr7 != 2)
+        t4 = e.jumpIf(e.sr4 != e.sr3)
+        e.r0 = 1
+        t1.target()
+        t2.target()
+        t3.target()
+        t4.target()
+        t1 = e.jumpIf(e.sr5 & 3)
+        t2 = e.jumpIf(e.sr1 & e.sr2)
+        e.r0 = 1
+        t1.target()
+        t2.target()
+        self.assertEqual(e.opcodes,
+            [Instruction(opcode=5, dst=0, src=0, off=1, imm=0),
+             Instruction(opcode=0xb7, dst=0, src=0, off=0, imm=1),
+             Instruction(opcode=0x25, dst=5, src=0, off=4, imm=3),
+             Instruction(opcode=0x2d, dst=1, src=2, off=3, imm=0),
+             Instruction(opcode=0x35, dst=7, src=0, off=2, imm=2),
+             Instruction(opcode=0x3d, dst=4, src=3, off=1, imm=0),
+             Instruction(opcode=0xb7, dst=0, src=0, off=0, imm=1),
+             Instruction(opcode=0xa5, dst=5, src=0, off=4, imm=3),
+             Instruction(opcode=0xad, dst=1, src=2, off=3, imm=0),
+             Instruction(opcode=0xb5, dst=7, src=0, off=2, imm=2),
+             Instruction(opcode=0xbd, dst=4, src=3, off=1, imm=0),
+             Instruction(opcode=0xb7, dst=0, src=0, off=0, imm=1),
+             Instruction(opcode=0x65, dst=5, src=0, off=4, imm=3),
+             Instruction(opcode=0x6d, dst=1, src=2, off=3, imm=0),
+             Instruction(opcode=0x75, dst=7, src=0, off=2, imm=2),
+             Instruction(opcode=0x7d, dst=4, src=3, off=1, imm=0),
+             Instruction(opcode=0xb7, dst=0, src=0, off=0, imm=1),
+             Instruction(opcode=0xc5, dst=5, src=0, off=4, imm=3),
+             Instruction(opcode=0xcd, dst=1, src=2, off=3, imm=0),
+             Instruction(opcode=0xd5, dst=7, src=0, off=2, imm=2),
+             Instruction(opcode=0xdd, dst=4, src=3, off=1, imm=0),
+             Instruction(opcode=0xb7, dst=0, src=0, off=0, imm=1),
+             Instruction(opcode=0x15, dst=5, src=0, off=4, imm=3),
+             Instruction(opcode=0x1d, dst=1, src=2, off=3, imm=0),
+             Instruction(opcode=0x55, dst=7, src=0, off=2, imm=2),
+             Instruction(opcode=0x5d, dst=4, src=3, off=1, imm=0),
+             Instruction(opcode=0xb7, dst=0, src=0, off=0, imm=1),
+             Instruction(opcode=0x45, dst=5, src=0, off=2, imm=3),
+             Instruction(opcode=0x4d, dst=1, src=2, off=1, imm=0),
+             Instruction(opcode=0xb7, dst=0, src=0, off=0, imm=1)])
+
+    def test_with(self):
+        e = EBPF()
+        with e.If(e.r2 > 3) as cond:
+            e.r2 = 5
+        with cond.Else():
+            e.r6 = 7
+        self.assertEqual(e.opcodes,
+            [Instruction(opcode=0xb5, dst=2, src=0, off=2, imm=3),
+             Instruction(opcode=0xb7, dst=2, src=0, off=0, imm=5),
+             Instruction(opcode=0x5, dst=0, src=0, off=1, imm=0),
+             Instruction(opcode=0xb7, dst=6, src=0, off=0, imm=7)])
+
+
 class KernelTests(TestCase):
     def test_minimal(self):
         e = EBPF(ProgType.XDP, "GPL")
-        e.r3 = e.m16[e.r5 - 7]
+        e.r6 = 7
+        target = e.jumpIf(e.r1 > 3)
+        e.r1 = 3
+        target.target()
+        e.r0 = 0
         e.exit()
         self.assertEqual(e.load(log_level=1), "")
 
-- 
GitLab