diff --git a/ebpfcat/ebpfcat.py b/ebpfcat/ebpfcat.py index 290556973b36c73653b0b11999fabb3945887296..b83f14339324a308fdb19dd77caa3e9ffe55fdbd 100644 --- a/ebpfcat/ebpfcat.py +++ b/ebpfcat/ebpfcat.py @@ -16,12 +16,15 @@ # 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. """The high-level API for EtherCAT loops""" +import asyncio from asyncio import ( - CancelledError, ensure_future, gather, sleep, wait_for, TimeoutError) + CancelledError, TimeoutError, ensure_future, gather, + get_event_loop, sleep, wait_for) from collections import defaultdict from contextlib import asynccontextmanager, AsyncExitStack, contextmanager from enum import Enum import logging +from multiprocessing import Array, Process, Value, get_context import os from random import randint import shutil @@ -32,7 +35,8 @@ from .arraymap import ArrayMap, ArrayGlobalVarDesc from .ethercat import ( ECCmd, EtherCat, MachineState, Packet, Terminal, EtherCatError, Struct, SyncManager) -from .ebpf import FuncId, MemoryDesc, SubProgram, prandom +from .ebpf import ( + EBPFBase, FuncId, MemoryDesc, SimulatedEBPF, SubProgram, prandom) from .lock import ParallelMailboxLock, LockFile from .xdp import XDP, XDPExitCode, PacketVar as XDPPacketVar from .bpf import ( @@ -238,13 +242,13 @@ class DeviceVar(ArrayGlobalVarDesc): def __get__(self, instance, owner): if instance is None: return self - elif isinstance(instance.sync_group, FastSyncGroup): + elif isinstance(instance.sync_group, EBPFBase): return super().__get__(instance, owner) else: return instance.__dict__.get(self.name, 0) def __set__(self, instance, value): - if isinstance(instance.sync_group, FastSyncGroup): + if isinstance(instance.sync_group, EBPFBase): super().__set__(instance, value) else: instance.__dict__[self.name] = value @@ -534,6 +538,12 @@ class ParallelEtherCat(FastEtherCat): os.remove(programs) self.mbx_lock_file.remove() + def __getstate__(self): + return self.addr[0] + + def __setstate__(self, network): + self.__init__(network) + class SterilePacket(Packet): """a sterile packet has all its sets exchanged by NOPs""" @@ -588,6 +598,7 @@ class BaseType(Enum): class SyncGroupBase: missed_counter = 0 + running = True cycletime = 0.01 # cycle time of the PLC loop task = None @@ -651,7 +662,7 @@ class SyncGroupBase: task = ensure_future(self.to_operational()) try: lasttime = monotonic() - while True: + while self.running: self.ec.send_packet(data) try: data = await wait_for( @@ -706,7 +717,7 @@ class SyncGroup(SyncGroupBase): packet_index = 1000 def update_devices(self, data): - self.current_data = bytearray(data) + self.current_data[:] = data for pos, counts in self.packet.counters.items(): if data[pos] not in counts: logging.warning( @@ -725,10 +736,81 @@ class SyncGroup(SyncGroupBase): SyncGroup.packet_index += 1 self.asm_packet = self.packet.assemble(self.packet_index, self.ec.ethertype) + self.current_data = bytearray(self.asm_packet) self.task = ensure_future(self.run()) return self.task +class ProcessSyncGroup(SyncGroup, SimulatedEBPF): + + properties = ArrayMap() + + def __init__(self, ec, devices, **kwargs): + self.ctx = get_context('spawn') + super().__init__(ec, devices, subprograms=devices, **kwargs) + + def get_array(self, size): + return self.ctx.Array('B', size).get_obj() + + @property + def running(self): + return self.runningValue.value + + def subprocess_run(self): + asyncio.run(self.subprocess_loop()) + + async def subprocess_loop(self): + async with self.ec.run(): + self.asm_packet = self.packet.assemble(self.packet_index, + self.ec.ethertype) + print('just before run') + try: + await self.run() + except Exception as e: + print('error in run', e) + raise + finally: + print('after run') + + async def wait_for_process(self): + fd = os.pidfd_open(self.process.pid) + loop = get_event_loop() + error = None + while True: + future = loop.create_future() + loop.add_reader(fd, future.set_result, None) + try: + await future + except CancelledError as error: + self.runningValue.value = False + else: + if error is None: + print('process terminated', self.process.is_alive()) + return + else: + raise error + finally: + loop.remove_reader(fd) + + @property + def current_data(self): + return memoryview(self._current_data.get_obj()).cast('B') + + def start(self): + assert isinstance(self.ec, ParallelEtherCat) + self.runningValue = self.ctx.Value('B') + self.runningValue.value = True + self.allocate() + self.packet_index = SyncGroup.packet_index + SyncGroup.packet_index += 1 + self.task = None + self._current_data = self.ctx.Array('B', max(46, self.packet.size)) + self.process = self.ctx.Process(target=self.subprocess_run) + self.process.start() + self.task = ensure_future(self.wait_for_process()) + return self.task + + class FastSyncGroup(SyncGroupBase, XDP): license = "GPL"