From c0b9a8011e5f09063f577d968db9c415a4460a8c Mon Sep 17 00:00:00 2001
From: Martin Teichmann <martin.teichmann@gmail.com>
Date: Mon, 31 Jul 2023 18:15:07 +0100
Subject: [PATCH] factor out a SterilePacket for better testability

---
 ebpfcat/ebpfcat.py       | 36 +++++++++++++++++++++++++-----------
 ebpfcat/ethercat_test.py | 19 ++++++++++++++++++-
 2 files changed, 43 insertions(+), 12 deletions(-)

diff --git a/ebpfcat/ebpfcat.py b/ebpfcat/ebpfcat.py
index 6108600..c54c441 100644
--- a/ebpfcat/ebpfcat.py
+++ b/ebpfcat/ebpfcat.py
@@ -264,9 +264,8 @@ class EBPFTerminal(Terminal):
                           self.position, self.pdo_in_off)
         if readwrite and self.pdo_out_sz:
             bases[2] = packet.size + packet.DATAGRAM_HEADER
-            packet.on_the_fly.append((packet.size, ECCmd.FPWR))
-            packet.append(ECCmd.FPWR, b"\0" * self.pdo_out_sz, 0,
-                          self.position, self.pdo_out_off)
+            packet.append_writer(ECCmd.FPWR, b"\0" * self.pdo_out_sz, 0,
+                                 self.position, self.pdo_out_off)
         return bases
 
     def update(self, data):
@@ -376,6 +375,26 @@ class FastEtherCat(SimpleEtherCat):
                     v.cancel()
 
 
+class SterilePacket(Packet):
+    """a sterile packet has all its sets exchanges by NOPs"""
+    def __init__(self):
+        super().__init__()
+        self.on_the_fly = []  # list of sterilized positions
+
+    def append_writer(self, cmd, *args):
+        self.on_the_fly.append((self.size, cmd))
+        super().append(cmd, *args)
+
+    def sterile(self, index):
+        ret = bytearray(self.assemble(index))
+        for pos, cmd in self.on_the_fly:
+            ret[pos] = ECCmd.NOP.value
+        return ret
+
+    def activate(self, ebpf):
+        for pos, cmd in self.on_the_fly:
+            ebpf.pB[pos + self.ETHERNET_HEADER] = cmd.value
+
 class SyncGroupBase:
     missed_counter = 0
 
@@ -413,8 +432,7 @@ class SyncGroupBase:
             data = self.update_devices(data)
 
     def allocate(self):
-        self.packet = Packet()
-        self.packet.on_the_fly = []
+        self.packet = SterilePacket()
         self.terminals = {t: t.allocate(self.packet, rw)
                           for t, rw in self.terminals.items()}
 
@@ -450,18 +468,14 @@ class FastSyncGroup(SyncGroupBase, XDP):
 
     def program(self):
         with self.packetSize >= self.packet.size + Packet.ETHERNET_HEADER as p:
-            for pos, cmd in self.packet.on_the_fly:
-                p.pB[pos + Packet.ETHERNET_HEADER] = cmd.value
+            self.packet.activate(p)
             for dev in self.devices:
                 dev.program()
         self.exit(XDPExitCode.TX)
 
     async def run(self):
         with self.ec.register_sync_group(self) as self.packet_index:
-            self.asm_packet = bytearray(
-                self.packet.assemble(self.packet_index))
-            for pos, cmd in self.packet.on_the_fly:
-                self.asm_packet[pos] = ECCmd.NOP.value
+            self.asm_packet = self.packet.sterile(self.packet_index)
             # prime the pump: two packets to get things going
             self.ec.send_packet(self.asm_packet)
             self.ec.send_packet(self.asm_packet)
diff --git a/ebpfcat/ethercat_test.py b/ebpfcat/ethercat_test.py
index 79fe8fe..d3f1a40 100644
--- a/ebpfcat/ethercat_test.py
+++ b/ebpfcat/ethercat_test.py
@@ -28,7 +28,7 @@ from .terminals import EL4104, EL3164, EK1814, Skip
 from .ethercat import ECCmd, Terminal
 from .ebpfcat import (
     FastSyncGroup, SyncGroup, TerminalVar, Device, EBPFTerminal, PacketDesc,
-    EtherCatBase)
+    EtherCatBase, SterilePacket)
 from .ebpf import Instruction, Opcode as O
 
 
@@ -395,5 +395,22 @@ class Tests(TestCase):
         self.assertEqual(sg.opcodes, [])
 
 
+class UnitTests(TestCase):
+    def test_sterile(self):
+        p = SterilePacket()
+        p.append(ECCmd.LRD, b"asdf", 0x33, 0x654321)
+        p.append_writer(ECCmd.FPRD, b"fdsa", 0x44, 0x55, 0x66)
+        self.assertEqual(p.assemble(0x77),
+                         H("2e10"
+                           "0000770000000280000000000000"
+                           "0a332143650004800000617364660000"
+                           "04445500660004000000666473610000"))
+        self.assertEqual(p.sterile(0x77),
+                         H("2e10"
+                           "0000770000000280000000000000"
+                           "0a332143650004800000617364660000"
+                           "00445500660004000000666473610000"))
+        self.assertEqual(p.on_the_fly, [(32, ECCmd.FPRD)])
+
 if __name__ == "__main__":
     main()
-- 
GitLab