From 08c865580b4eaa3cafea9fb2e010ab0f42d22f97 Mon Sep 17 00:00:00 2001
From: Martin Teichmann <martin.teichmann@gmail.com>
Date: Tue, 7 Feb 2023 23:22:21 +0000
Subject: [PATCH] move bit access to ebpf proper

this is not actually ethercat specific
---
 ebpfcat/ebpf.py      | 105 ++++++++++++++++++++++++++++++++++---------
 ebpfcat/ebpf_test.py |  47 +++++++++++++++++++
 ebpfcat/ebpfcat.py   |  26 +----------
 3 files changed, 133 insertions(+), 45 deletions(-)

diff --git a/ebpfcat/ebpf.py b/ebpfcat/ebpf.py
index 7144f58..15dea9d 100644
--- a/ebpfcat/ebpf.py
+++ b/ebpfcat/ebpf.py
@@ -266,10 +266,7 @@ class AssembleError(Exception):
     pass
 
 
-def comparison(uposop, unegop, sposop=None, snegop=None):
-    if sposop is None:
-        sposop = uposop
-        snegop = unegop
+def comparison(uposop, unegop, sposop, snegop):
     def ret(self, value):
         return SimpleComparison(self.ebpf, self, value,
                                 (uposop, unegop, sposop, snegop))
@@ -385,7 +382,6 @@ class AndOrComparison(Comparison):
         self.left = left
         self.right = right
         self.is_and = is_and
-        self.targetted = False
 
     def compare(self, negative):
         self.negative = negative
@@ -404,11 +400,15 @@ class AndOrComparison(Comparison):
 
 class InvertComparison(Comparison):
     def __init__(self, ebpf, value):
-        self.ebpf = ebpf
+        super().__init__(ebpf)
         self.value = value
 
     def compare(self, negative):
         self.value.compare(not negative)
+        self.owners = self.value.owners
+
+    def target(self, retarget=False):
+        self.value.target(retarget)
 
 
 def binary(opcode):
@@ -439,16 +439,22 @@ class Expression:
     __rmod__ = rbinary(Opcode.MOD)
     __rxor__ = __xor__ = binary(Opcode.XOR)
 
-    __eq__ = comparison(Opcode.JEQ, Opcode.JNE)
     __gt__ = comparison(Opcode.JGT, Opcode.JLE, Opcode.JSGT, Opcode.JSLE)
     __ge__ = comparison(Opcode.JGE, Opcode.JLT, Opcode.JSGE, Opcode.JSLT)
     __lt__ = comparison(Opcode.JLT, Opcode.JGE, Opcode.JSLT, Opcode.JSGE)
     __le__ = comparison(Opcode.JLE, Opcode.JGT, Opcode.JSLE, Opcode.JSGT)
-    __ne__ = comparison(Opcode.JNE, Opcode.JEQ)
 
     def __and__(self, value):
         return AndExpression(self.ebpf, self, value)
 
+    def __ne__(self, value):
+        return SimpleComparison(
+            self.ebpf, self, value,
+            (Opcode.JNE, Opcode.JEQ, Opcode.JNE, Opcode.JEQ))
+
+    def __eq__(self, value):
+        return ~(self != value)
+
     __rand__ = __and__
 
     def __neg__(self):
@@ -632,8 +638,20 @@ class Sum(Binary):
             return super().__sub__(value)
 
 
-class AndExpression(SimpleComparison, Binary):
-    """The & operator may also be used as a comparison"""
+class AndExpression(Binary):
+    # there is a special comparison with & instruction
+    def __init__(self, ebpf, left, right):
+        super().__init__(ebpf, left, right, Opcode.AND)
+
+    def __ne__(self, value):
+        if isinstance(value, int) and value == 0:
+            return AndComparison(self.ebpf, self.left, self.right)
+        return super().__ne__(value)
+
+
+class AndComparison(SimpleComparison):
+    # there is a special comparison with & instruction
+    # it is the only one which has not inversion
     def __init__(self, ebpf, left, right):
         Binary.__init__(self, ebpf, left, right, Opcode.AND)
         SimpleComparison.__init__(self, ebpf, left, right, Opcode.JSET)
@@ -743,13 +761,21 @@ class Memory(Expression):
 
     @contextmanager
     def calculate(self, dst, long, signed, force=False):
-        if isinstance(self.address, Sum):
-            with self.ebpf.get_free_register(dst) as dst:
-                self.ebpf.append(Opcode.LD + self.fmt_to_opcode[self.fmt], dst,
-                                 self.address.left.no, self.address.right, 0)
-                yield dst, self.fmt in "QqA", self.fmt.islower()
-        else:
-            with super().calculate(dst, long, signed, force) as (dst, _, _):
+        with ExitStack() as exitStack:
+            if isinstance(self.address, Sum):
+                dst = exitStack.enter_context(self.ebpf.get_free_register(dst))
+                self.ebpf.append(
+                    Opcode.LD + self.fmt_to_opcode.get(self.fmt, Opcode.B),
+                    dst, self.address.left.no, self.address.right, 0)
+            else:
+                dst, _, _ = exitStack.enter_context(
+                    super().calculate(dst, long, signed, force))
+            if isinstance(self.fmt, tuple):
+                self.ebpf.r[dst] &= ((1 << self.fmt[1]) - 1) << self.fmt[0]
+                if self.fmt[0] > 0:
+                    self.ebpf.r[dst] >>= self.fmt[0]
+                yield dst, "B", False
+            else:
                 yield dst, self.fmt in "QqA", self.fmt.islower()
 
     @contextmanager
@@ -764,6 +790,18 @@ class Memory(Expression):
     def signed(self):
         return isinstance(self.fmt, str) and self.fmt.islower()
 
+    def __invert__(self):
+        if not isinstance(self.fmt, tuple) or self.fmt[1] != 1:
+            return NotImplemented
+        return self == 0
+
+    def __ne__(self, value):
+        if isinstance(self.fmt, tuple) and isinstance(value, int) \
+                and value == 0:
+            mask = ((1 << self.fmt[1]) - 1) << self.fmt[0]
+            return Memory(self.ebpf, "B", self.address) & mask != 0
+        return super().__ne__(value)
+
 
 class MemoryDesc:
     """A base class used by descriptors for memory
@@ -782,8 +820,26 @@ class MemoryDesc:
     def __set__(self, instance, value):
         ebpf = instance.ebpf
         fmt, addr = self.fmt_addr(instance)
-        bits = Memory.fmt_to_opcode[fmt]
-        if isinstance(value, int):
+        bits = Memory.fmt_to_opcode.get(fmt, Opcode.B)
+        if isinstance(fmt, tuple):
+            before = Memory(ebpf, "B", ebpf.r[self.base_register] + addr)
+            if fmt[1] == 1:
+                try:
+                    if value:
+                        value = before | (1 << fmt[0])
+                    else:
+                        value = before & ~(1 << fmt[0])
+                except AssembleError:
+                    with ebpf.wtmp:
+                        with value as cond:
+                            ebpf.wtmp = before | (1 << fmt[0])
+                        with cond.Else():
+                            ebpf.wtmp = before & ~(1 << fmt[0])
+            else:
+                mask = ((1 << fmt[1]) - 1) << fmt[0]
+                value = (mask & (value << self.fmt[0]) | ~mask & before)
+            opcode = Opcode.STX
+        elif isinstance(value, int):
             ebpf.append(Opcode.ST + bits, self.base_register, 0,
                         addr, value)
             return
@@ -798,7 +854,9 @@ class MemoryDesc:
             opcode = Opcode.XADD
         else:
             opcode = Opcode.STX
-        with value.calculate(None, fmt in 'qQ', fmt.islower()) as (src, _, _):
+        with value.calculate(None, isinstance(fmt, str) and fmt in 'qQ',
+                             isinstance(fmt, str) and fmt.islower()
+                            ) as (src, _, _):
             ebpf.append(opcode + bits, self.base_register, src, addr, 0)
 
 
@@ -810,7 +868,10 @@ class LocalVar(MemoryDesc):
         self.fmt = fmt
 
     def __set_name__(self, owner, name):
-        size = calcsize(self.fmt)
+        if isinstance(self.fmt, str):
+            size = calcsize(self.fmt)
+        else:  # this is to support bit addressing, mostly for testing
+            size = 1
         owner.stack -= size
         owner.stack &= -size
         self.relative_addr = owner.stack
@@ -1064,6 +1125,8 @@ class EBPF:
 
     def jumpIf(self, comp):
         """jump if `comp` is true to a later defined `target`"""
+        if isinstance(comp, Expression):
+            comp = comp != 0
         comp.compare(False)
         return comp
 
diff --git a/ebpfcat/ebpf_test.py b/ebpfcat/ebpf_test.py
index c6ca17d..4646af4 100644
--- a/ebpfcat/ebpf_test.py
+++ b/ebpfcat/ebpf_test.py
@@ -174,6 +174,53 @@ class Tests(TestCase):
             Instruction(opcode=O.REG+O.STX, dst=10, src=0, off=-4, imm=0),
             Instruction(opcode=O.DW+O.STX, dst=10, src=1, off=-16, imm=0)])
 
+    def test_local_bits(self):
+        class Local(EBPF):
+            a = LocalVar((5, 1))
+            b = LocalVar((3, 4))
+
+        e = Local(ProgType.XDP, "GPL")
+
+        with e.a:
+            e.a = 1
+
+        e.b = e.a
+
+        with ~e.a:
+            e.b = 3
+
+        with e.b:
+            e.a = 0
+
+        self.assertEqual(e.opcodes, [
+           Instruction(opcode=O.LD+O.B, dst=0, src=10, off=-1, imm=0),
+            Instruction(opcode=O.JSET, dst=0, src=0, off=1, imm=32),
+            Instruction(opcode=O.JMP, dst=0, src=0, off=3, imm=0),
+            Instruction(opcode=O.LD+O.B, dst=0, src=10, off=-1, imm=0),
+            Instruction(opcode=O.OR, dst=0, src=0, off=0, imm=32),
+            Instruction(opcode=O.B+O.STX, dst=10, src=0, off=-1, imm=0),
+            Instruction(opcode=O.LD+O.B, dst=0, src=10, off=-1, imm=0),
+            Instruction(opcode=O.AND+O.LONG, dst=0, src=0, off=0, imm=32),
+            Instruction(opcode=O.RSH+O.LONG, dst=0, src=0, off=0, imm=5),
+            Instruction(opcode=O.LSH, dst=0, src=0, off=0, imm=3),
+            Instruction(opcode=O.AND, dst=0, src=0, off=0, imm=120),
+            Instruction(opcode=O.LD+O.B, dst=2, src=10, off=-2, imm=0),
+            Instruction(opcode=O.AND, dst=2, src=0, off=0, imm=-121),
+            Instruction(opcode=O.REG+O.OR, dst=0, src=2, off=0, imm=0),
+            Instruction(opcode=O.B+O.STX, dst=10, src=0, off=-2, imm=0),
+            Instruction(opcode=O.LD+O.B, dst=0, src=10, off=-1, imm=0),
+            Instruction(opcode=O.JSET, dst=0, src=0, off=4, imm=32),
+            Instruction(opcode=O.LD+O.B, dst=0, src=10, off=-2, imm=0),
+            Instruction(opcode=O.AND, dst=0, src=0, off=0, imm=-121),
+            Instruction(opcode=O.OR, dst=0, src=0, off=0, imm=24),
+            Instruction(opcode=O.B+O.STX, dst=10, src=0, off=-2, imm=0),
+            Instruction(opcode=O.LD+O.B, dst=0, src=10, off=-2, imm=0),
+            Instruction(opcode=O.JSET, dst=0, src=0, off=1, imm=120),
+            Instruction(opcode=O.JMP, dst=0, src=0, off=3, imm=0),
+            Instruction(opcode=O.LD+O.B, dst=0, src=10, off=-1, imm=0),
+            Instruction(opcode=O.AND, dst=0, src=0, off=0, imm=-33),
+            Instruction(opcode=O.B+O.STX, dst=10, src=0, off=-1, imm=0)])
+
     def test_local_subprog(self):
         class Local(EBPF):
             a = LocalVar('I')
diff --git a/ebpfcat/ebpfcat.py b/ebpfcat/ebpfcat.py
index ae42165..f0f0b6c 100644
--- a/ebpfcat/ebpfcat.py
+++ b/ebpfcat/ebpfcat.py
@@ -93,25 +93,6 @@ class PacketVar(MemoryDesc):
 
     def set(self, device, value):
         if device.sync_group.current_data is None:
-            if isinstance(self.size, int):
-                try:
-                    bool(value)
-                except RuntimeError:
-                    e = device.sync_group
-                    with e.wtmp:
-                        e.wtmp = super().__get__(device, None)
-                        with value as cond:
-                            e.wtmp |= 1 << self.size
-                        with cond.Else():
-                            e.wtmp &= ~(1 << self.size)
-                        super().__set__(device, e.wtmp)
-                    return
-                else:
-                    old = super().__get__(device, None)
-                    if value:
-                        value = old | (1 << self.size)
-                    else:
-                        value = old & ~(1 << self.size)
             super().__set__(device, value)
         else:
             data = device.sync_group.current_data
@@ -126,10 +107,7 @@ class PacketVar(MemoryDesc):
 
     def get(self, device):
         if device.sync_group.current_data is None:
-            if isinstance(self.size, int):
-                return super().__get__(device, None) & (1 << self.size)
-            else:
-                return super().__get__(device, None)
+            return super().__get__(device, None)
         else:
             data = device.sync_group.current_data
             start = self._start(device)
@@ -143,7 +121,7 @@ class PacketVar(MemoryDesc):
                + self.position
 
     def fmt_addr(self, device):
-        return ("B" if isinstance(self.size, int) else self.size,
+        return ((self.size, 1) if isinstance(self.size, int) else self.size,
                 self._start(device) + Packet.ETHERNET_HEADER)
 
 
-- 
GitLab