From f00e3c6dd544a7f0e08e3efb67055c901e88427b Mon Sep 17 00:00:00 2001
From: Martin Teichmann <martin.teichmann@xfel.eu>
Date: Wed, 3 Mar 2021 09:53:26 +0000
Subject: [PATCH] introduce a watchdog

---
 ebpfcat/arraymap.py | 19 ++++++++++------
 ebpfcat/devices.py  | 12 ++++++++++
 ebpfcat/ebpf.py     | 46 +++++++++++++++++++------------------
 ebpfcat/ebpfcat.py  | 55 +++++++++++++++++++++------------------------
 ebpfcat/hashmap.py  |  2 +-
 ebpfcat/xdp.py      |  2 +-
 6 files changed, 76 insertions(+), 60 deletions(-)

diff --git a/ebpfcat/arraymap.py b/ebpfcat/arraymap.py
index 1cf4ae8..9552eb3 100644
--- a/ebpfcat/arraymap.py
+++ b/ebpfcat/arraymap.py
@@ -1,7 +1,7 @@
 from itertools import chain
 from struct import pack_into, unpack_from, calcsize
 
-from .ebpf import FuncId, Map, Memory, MemoryDesc, Opcode
+from .ebpf import FuncId, Map, MemoryDesc
 from .bpf import create_map, lookup_elem, MapType, update_elem
 
 
@@ -22,18 +22,24 @@ class ArrayGlobalVarDesc(MemoryDesc):
     def __get__(self, instance, owner):
         if instance is None:
             return self
-        fmt, addr = self.fmt_addr(instance)
         if instance.ebpf.loaded:
+            fmt, addr = self.fmt_addr(instance)
             data = instance.ebpf.__dict__[self.map.name].data
-            return unpack_from(fmt, data, addr)[0]
+            ret = unpack_from(fmt, data, addr)
+            if len(ret) == 1:
+                return ret[0]
+            else:
+                return ret
         else:
             return super().__get__(instance, owner)
 
     def __set__(self, instance, value):
-        fmt, addr = self.fmt_addr(instance)
         if instance.ebpf.loaded:
+            fmt, addr = self.fmt_addr(instance)
+            if not isinstance(value, tuple):
+                value = value,
             pack_into(fmt, instance.ebpf.__dict__[self.map.name].data,
-                      addr, value)
+                      addr, *value)
         else:
             super().__set__(instance, value)
 
@@ -85,7 +91,6 @@ class ArrayMap(Map):
         else:
             return write_size, position
 
-
     def __set_name__(self, owner, name):
         self.name = name
 
@@ -97,7 +102,7 @@ class ArrayMap(Map):
         fd = create_map(MapType.ARRAY, 4, size, 1)
         setattr(ebpf, self.name, ArrayMapAccess(fd, write_size, size))
         with ebpf.save_registers(list(range(6))), ebpf.get_stack(4) as stack:
-            ebpf.append(Opcode.ST, 10, 0, stack, 0)
+            ebpf.mI[ebpf.r10 + stack] = 0
             ebpf.r1 = ebpf.get_fd(fd)
             ebpf.r2 = ebpf.r10 + stack
             ebpf.call(FuncId.map_lookup_elem)
diff --git a/ebpfcat/devices.py b/ebpfcat/devices.py
index 047610c..55e4288 100644
--- a/ebpfcat/devices.py
+++ b/ebpfcat/devices.py
@@ -184,3 +184,15 @@ class Dummy(Device):
 
     def program(self):
         pass
+
+
+class RandomDropper(Device):
+    rate = DeviceVar("I", write=True)
+
+    def program(self):
+        from .xdp import XDPExitCode
+        with self.ebpf.tmp:
+            self.ebpf.tmp = ktime(self.ebpf)
+            self.ebpf.tmp = self.ebpf.tmp * 0xcf019d85 + 1
+            with self.ebpf.tmp & 0xffff < self.rate:
+                self.ebpf.exit(XDPExitCode.DROP)
diff --git a/ebpfcat/ebpf.py b/ebpfcat/ebpf.py
index 7d3df0c..d92041d 100644
--- a/ebpfcat/ebpf.py
+++ b/ebpfcat/ebpf.py
@@ -1,6 +1,6 @@
 from collections import namedtuple
 from contextlib import contextmanager, ExitStack
-from struct import pack, unpack
+from struct import pack, unpack, calcsize
 from enum import Enum
 
 from . import bpf
@@ -429,8 +429,9 @@ class Expression:
     @contextmanager
     def calculate(self, dst, long, signed, force=False):
         with self.ebpf.get_free_register(dst) as dst:
-            with self.get_address(dst, long, signed) as (src, bits):
-                self.ebpf.append(Opcode.LD + bits, dst, src, 0, 0)
+            with self.get_address(dst, long, signed) as (src, fmt):
+                self.ebpf.append(Opcode.LD + Memory.fmt_to_opcode[fmt],
+                                 dst, src, 0, 0)
                 yield dst, long, self.signed
 
     @contextmanager
@@ -636,12 +637,10 @@ class Memory(Expression):
     bits_to_opcode = {32: Opcode.W, 16: Opcode.H, 8: Opcode.B, 64: Opcode.DW}
     fmt_to_opcode = {'I': Opcode.W, 'H': Opcode.H, 'B': Opcode.B, 'Q': Opcode.DW,
                      'i': Opcode.W, 'h': Opcode.H, 'b': Opcode.B, 'q': Opcode.DW}
-    fmt_to_size = {'I': 4, 'H': 2, 'B': 1, 'Q': 8,
-                   'i': 4, 'h': 2, 'b': 1, 'q': 8}
 
-    def __init__(self, ebpf, bits, address, signed=False, long=False):
+    def __init__(self, ebpf, fmt, address, signed=False, long=False):
         self.ebpf = ebpf
-        self.bits = bits
+        self.fmt = fmt
         self.address = address
         self.signed = signed
         self.long = long
@@ -656,7 +655,7 @@ class Memory(Expression):
     def calculate(self, dst, long, signed, force=False):
         if isinstance(self.address, Sum):
             with self.ebpf.get_free_register(dst) as dst:
-                self.ebpf.append(Opcode.LD + self.bits, dst,
+                self.ebpf.append(Opcode.LD + self.fmt_to_opcode[self.fmt], dst,
                                  self.address.left.no, self.address.right, 0)
                 yield dst, self.long, self.signed
         else:
@@ -666,7 +665,7 @@ class Memory(Expression):
     @contextmanager
     def get_address(self, dst, long, signed, force=False):
         with self.address.calculate(dst, True, None) as (src, _, _):
-            yield src, self.bits
+            yield src, self.fmt
 
     def contains(self, no):
         return self.address.contains(no)
@@ -677,7 +676,7 @@ class MemoryDesc:
         if instance is None:
             return self
         fmt, addr = self.fmt_addr(instance)
-        return Memory(instance.ebpf, Memory.fmt_to_opcode[fmt],
+        return Memory(instance.ebpf, fmt,
                       instance.ebpf.r[self.base_register] + addr,
                       fmt.islower())
 
@@ -711,7 +710,7 @@ class LocalVar(MemoryDesc):
         self.fmt = fmt
 
     def __set_name__(self, owner, name):
-        size = Memory.fmt_to_size[self.fmt]
+        size = calcsize(self.fmt)
         owner.stack -= size
         owner.stack &= -size
         self.relative_addr = owner.stack
@@ -725,9 +724,9 @@ class LocalVar(MemoryDesc):
 
 
 class MemoryMap:
-    def __init__(self, ebpf, bits, signed=False, long=False):
+    def __init__(self, ebpf, fmt, signed=False, long=False):
         self.ebpf = ebpf
-        self.bits = bits
+        self.fmt = fmt
         self.long = long
         self.signed = signed
 
@@ -741,7 +740,8 @@ class MemoryMap:
                         addr.calculate(None, True, None))
                 offset = 0
             if isinstance(value, int):
-                self.ebpf.append(Opcode.ST + self.bits, dst, 0, offset, value)
+                self.ebpf.append(Opcode.ST + Memory.fmt_to_opcode[self.fmt],
+                                 dst, 0, offset, value)
                 return
             elif isinstance(value, IAdd):
                 value = value.value
@@ -749,18 +749,20 @@ class MemoryMap:
                     with self.ebpf.get_free_register(None) as src:
                         self.ebpf.r[src] = value
                         self.ebpf.append(
-                            Opcode.XADD + self.bits, dst, src, offset, 0)
+                            Opcode.XADD + Memory.fmt_to_opcode[self.fmt],
+                            dst, src, offset, 0)
                     return
                 opcode = Opcode.XADD
             else:
                 opcode = Opcode.STX
             with value.calculate(None, None, None) as (src, _, _):
-                self.ebpf.append(opcode + self.bits, dst, src, offset, 0)
+                self.ebpf.append(opcode + Memory.fmt_to_opcode[self.fmt],
+                                 dst, src, offset, 0)
 
     def __getitem__(self, addr):
         if isinstance(addr, Register):
             addr = addr + 0
-        return Memory(self.ebpf, self.bits, addr, self.signed, self.long)
+        return Memory(self.ebpf, self.fmt, addr, self.signed, self.long)
 
 
 class Map:
@@ -890,11 +892,11 @@ class EBPF:
             self.name = name
         self.loaded = False
 
-        self.mB = MemoryMap(self, Opcode.B)
-        self.mH = MemoryMap(self, Opcode.H)
-        self.mI = MemoryMap(self, Opcode.W)
-        self.mA = MemoryMap(self, Opcode.W, False, True)
-        self.mQ = MemoryMap(self, Opcode.DW, False, True)
+        self.mB = MemoryMap(self, "B")
+        self.mH = MemoryMap(self, "H")
+        self.mI = MemoryMap(self, "I")
+        self.mA = MemoryMap(self, "I", False, True)
+        self.mQ = MemoryMap(self, "Q", False, True)
 
         self.r = RegisterArray(self, True, False)
         self.sr = RegisterArray(self, True, True)
diff --git a/ebpfcat/ebpfcat.py b/ebpfcat/ebpfcat.py
index d10ca31..62d2ea0 100644
--- a/ebpfcat/ebpfcat.py
+++ b/ebpfcat/ebpfcat.py
@@ -6,7 +6,6 @@ from .arraymap import ArrayMap, ArrayGlobalVarDesc
 from .ethercat import ECCmd, EtherCat, Packet, Terminal
 from .ebpf import FuncId, MemoryDesc, SubProgram
 from .xdp import XDP, XDPExitCode
-from .hashmap import HashMap
 from .bpf import (
     ProgType, MapType, create_map, update_elem, prog_test_run, lookup_elem)
 
@@ -224,36 +223,21 @@ class EBPFTerminal(Terminal):
         pass
 
 
-class EBPFCat(XDP):
-    vars = HashMap()
-
-    count = vars.globalVar()
-    ptype = vars.globalVar()
-
-    def program(self):
-        #with self.If(self.packet16[12] != 0xA488):
-        #    self.exit(2)
-        self.count += 1
-        #self.ptype = self.packet32[18]
-        self.exit(2)
-
-
 class EtherXDP(XDP):
     license = "GPL"
 
-    variables = HashMap()
-    count = variables.globalVar()
-    allcount = variables.globalVar()
+    variables = ArrayMap()
+    counters = variables.globalVar("64I", write=True)
 
     def program(self):
-        e = self
-        with e.packetSize > 24 as p, p.pH[12] == 0xA488, p.pB[16] == 0:
-            e.count += 1
-            e.r2 = e.get_fd(self.programs)
-            e.r3 = p.pI[18]
-            e.call(FuncId.tail_call)
-        e.allcount += 1
-        e.exit(XDPExitCode.PASS)
+        with self.packetSize > 24 as p, p.pH[12] == 0xA488, p.pB[16] == 0:
+            self.r3 = p.pI[18]
+            with self.counters.get_address(None, False, False) as (dst, _), \
+                    self.r3 < FastEtherCat.MAX_PROGS:
+                self.mI[self.r[dst] + 4 * self.r3] += 1
+            self.r2 = self.get_fd(self.programs)
+            self.call(FuncId.tail_call)
+        self.exit(XDPExitCode.PASS)
 
 
 class SimpleEtherCat(EtherCat):
@@ -275,20 +259,34 @@ class FastEtherCat(SimpleEtherCat):
         self.programs = create_map(MapType.PROG_ARRAY, 4, 4, self.MAX_PROGS)
         self.sync_groups = {}
 
-    def register_sync_group(self, sg):
+    def register_sync_group(self, sg, packet):
         index = len(self.sync_groups)
         while index in self.sync_groups:
             index = (index + 1) % self.MAX_PROGS
         fd, _ = sg.load(log_level=1)
         update_elem(self.programs, pack("<I", index), pack("<I", fd), 0)
         self.sync_groups[index] = sg
+        sg.assembled = packet.assemble(index)
         return index
 
+    async def watchdog(self):
+        while True:
+            t0 = time()
+            self.ebpf.counters = (0,) * self.MAX_PROGS
+            self.ebpf.variables.readwrite()
+            counts = self.ebpf.counters
+            for i, sg in self.sync_groups.items():
+                if counts[i] == 0:
+                    self.send_packet(sg.assembled)
+            await sleep(0.001)
+
     async def connect(self):
         await super().connect()
         self.ebpf = EtherXDP()
         self.ebpf.programs = self.programs
         self.fd = await self.ebpf.attach(self.addr[0])
+        ensure_future(self.watchdog())
+
 
 class SyncGroupBase:
     def __init__(self, ec, devices, **kwargs):
@@ -353,8 +351,7 @@ class FastSyncGroup(SyncGroupBase, XDP):
 
     def start(self):
         self.allocate()
-        index = self.ec.register_sync_group(self)
-        self.ec.send_packet(self.packet.assemble(index))
+        self.ec.register_sync_group(self, self.packet)
         self.monitor = ensure_future(gather(*[t.to_operational()
                                               for t in self.terminals]))
         return self.monitor
diff --git a/ebpfcat/hashmap.py b/ebpfcat/hashmap.py
index 3221b6c..963afdc 100644
--- a/ebpfcat/hashmap.py
+++ b/ebpfcat/hashmap.py
@@ -29,7 +29,7 @@ class HashGlobalVar(Expression):
                                  0, 0, 0)
             else:
                 dst = 0
-        yield dst, Memory.fmt_to_opcode[self.fmt]
+        yield dst, self.fmt
 
 
 class HashGlobalVarDesc:
diff --git a/ebpfcat/xdp.py b/ebpfcat/xdp.py
index d61ab1d..53304d4 100644
--- a/ebpfcat/xdp.py
+++ b/ebpfcat/xdp.py
@@ -5,7 +5,7 @@ from socket import AF_NETLINK, NETLINK_ROUTE, if_nametoindex
 import socket
 from struct import pack, unpack
 
-from .ebpf import EBPF, Expression, Memory, Opcode, Comparison
+from .ebpf import EBPF
 from .bpf import ProgType
 
 
-- 
GitLab