From bd90b7168354e025d55724c91119ab34b5370647 Mon Sep 17 00:00:00 2001
From: Martin Teichmann <martin.teichmann@xfel.eu>
Date: Tue, 2 Jan 2024 11:16:31 +0000
Subject: [PATCH] use enums for the state machine states

---
 ebpfcat/ebpfcat.py  |  5 ++--
 ebpfcat/ethercat.py | 65 +++++++++++++++++++++++++--------------------
 ebpfcat/scripts.py  |  8 +++---
 3 files changed, 43 insertions(+), 35 deletions(-)

diff --git a/ebpfcat/ebpfcat.py b/ebpfcat/ebpfcat.py
index 015d19b..e77a346 100644
--- a/ebpfcat/ebpfcat.py
+++ b/ebpfcat/ebpfcat.py
@@ -27,7 +27,8 @@ from struct import pack, unpack, calcsize, pack_into, unpack_from
 from time import time
 from .arraymap import ArrayMap, ArrayGlobalVarDesc
 from .ethercat import (
-    ECCmd, EtherCat, Packet, Terminal, EtherCatError, SyncManager)
+    ECCmd, EtherCat, MachineState, Packet, Terminal, EtherCatError,
+    SyncManager)
 from .ebpf import FuncId, MemoryDesc, SubProgram, prandom
 from .xdp import XDP, XDPExitCode, PacketVar as XDPPacketVar
 from .bpf import (
@@ -250,7 +251,7 @@ class EBPFTerminal(Terminal):
                 (self.vendorId, self.productCode) not in self.compatibility):
             raise EtherCatError(
                 f"Incompatible Terminal: {self.vendorId}:{self.productCode}")
-        await self.to_operational(2)
+        await self.to_operational(MachineState.PRE_OPERATIONAL)
         self.pdos = {}
         outbits, inbits = await self.parse_pdos()
         self.pdo_out_sz = int((outbits + 7) // 8)
diff --git a/ebpfcat/ethercat.py b/ebpfcat/ethercat.py
index ad8d39d..397cd56 100644
--- a/ebpfcat/ethercat.py
+++ b/ebpfcat/ethercat.py
@@ -135,6 +135,19 @@ class EEPROM(IntEnum):
     REVISION = 12
     SERIAL_NO = 14
 
+class MachineState(Enum):
+    """The states of the EtherCAT state machine
+
+    The states are in the order in which they should
+    be taken, BOOTSTRAP is at the end as this is a
+    state we usually do not go to.
+    """
+    INIT = 1
+    PRE_OPERATIONAL = 2
+    SAFE_OPERATIONAL = 4
+    OPERATIONAL = 8
+    BOOTSTRAP = 3
+
 class SyncManager(Enum):
     OUT = 2
     IN = 3
@@ -561,43 +574,37 @@ class Terminal:
         return ret
 
     async def get_state(self):
-        """get the current state and error flags"""
-        state, error = await self.ec.roundtrip(ECCmd.FPRD, self.position,
-                                               0x0130, "H2xH")
-        return state, error
+        """get the current state, error flag and status word"""
+        state, status = await self.ec.roundtrip(ECCmd.FPRD, self.position,
+                                                0x0130, "H2xH")
+        return MachineState(state & 0xf), bool(state & 0x10), status
 
-    async def to_operational(self, target=8):
+    async def to_operational(self, target=MachineState.OPERATIONAL):
         """try to bring the terminal to operational state
 
         this tries to push the terminal through its state machine to the
-        operational state. Note that even if it reaches there, the terminal
+        target state. Note that even if it reaches there, the terminal
         will quickly return to pre-operational if no packets are sent to keep
-        it operational. """
-        order = [1, 2, 4, 8]
-        ret, error = await self.ec.roundtrip(
-                ECCmd.FPRD, self.position, 0x0130, "H2xH")
-        if ret & 0x10:
+        it operational.
+
+        return the state, error flag and status before the operation."""
+        order = list(MachineState)
+        state, error, status = ret = await self.get_state()
+        if error:
             await self.ec.roundtrip(ECCmd.FPWR, self.position,
                                     0x0120, "H", 0x11)
-            ret, error = await self.ec.roundtrip(ECCmd.FPRD, self.position,
-                                                 0x0130, "H2xH")
-        pos = order.index(ret)
-        s = 0x11
-        for state in order[pos+1:]:
+            state = MachineState.INIT
+        pos = order.index(state) + 1
+        state = None
+        for current in order[pos:]:
             await self.ec.roundtrip(ECCmd.FPWR, self.position,
-                                    0x0120, "H", state)
-            while s != state:
-                s, error = await self.ec.roundtrip(ECCmd.FPRD, self.position,
-                                                   0x0130, "H2xH")
-                if error != 0:
-                    raise EtherCatError(f"AL register {error}")
-            if state >= target:
-                return
-
-    async def get_error(self):
-        """read the error register"""
-        return (await self.ec.roundtrip(ECCmd.FPRD, self.position,
-                                        0x0134, "H"))[0]
+                                    0x0120, "H", current.value)
+            while current is not state:
+                state, error, status = await self.get_state()
+                if error:
+                    raise EtherCatError(f"AL register error {status}")
+            if state.value >= target.value:
+                return ret
 
     async def read(self, start, *args, **kwargs):
         """read data from the terminal at offset `start`
diff --git a/ebpfcat/scripts.py b/ebpfcat/scripts.py
index 7b2de85..0dd6d50 100644
--- a/ebpfcat/scripts.py
+++ b/ebpfcat/scripts.py
@@ -6,7 +6,7 @@ from pprint import PrettyPrinter
 from struct import unpack
 import sys
 
-from .ethercat import EtherCat, Terminal, ECCmd, EtherCatError
+from .ethercat import EtherCat, MachineState, Terminal, ECCmd, EtherCatError
 
 def entrypoint(func):
     @wraps(func)
@@ -73,7 +73,7 @@ async def info():
                 print(f"{k:2}: {v}\n    {v.hex()}")
 
         if args.sdo:
-            await t.to_operational(2)
+            await t.to_operational(MachineState.PRE_OPERATIONAL)
             ret = await t.read_ODlist()
             for k, v in ret.items():
                 print(f"{k:X}:")
@@ -91,7 +91,7 @@ async def info():
                                 print(f"        {r}")
                                 print(f"        {r!r}")
         if args.pdo:
-            await t.to_operational(2)
+            await t.to_operational(MachineState.PRE_OPERATIONAL)
             await t.parse_pdos()
             for (idx, subidx), (sm, pos, fmt) in t.pdos.items():
                 print(f"{idx:4X}:{subidx:02X} {sm} {pos} {fmt}")
@@ -168,7 +168,7 @@ async def create_test():
         await t.initialize(-i, await ec.find_free_address())
         sdo = {}
         if t.has_mailbox():
-            await t.to_operational(2)
+            await t.to_operational(MachineState.PRE_OPERATIONAL)
             odlist = await t.read_ODlist()
 
             for k, v in odlist.items():
-- 
GitLab