From 15d2116f299cefb556b544eba4c8102567d36cb0 Mon Sep 17 00:00:00 2001
From: David Hammer <dhammer@mailbox.org>
Date: Wed, 17 Nov 2021 23:35:10 +0100
Subject: [PATCH] Restructure input handler, DRY slightly

---
 src/calng/AgipdCorrection.py | 176 ++++++++++++-----------------------
 src/calng/DsscCorrection.py  | 138 ++++++++++-----------------
 src/calng/base_correction.py | 139 +++++++++++++++++++++++----
 3 files changed, 227 insertions(+), 226 deletions(-)

diff --git a/src/calng/AgipdCorrection.py b/src/calng/AgipdCorrection.py
index e4d4be15..19169e26 100644
--- a/src/calng/AgipdCorrection.py
+++ b/src/calng/AgipdCorrection.py
@@ -1,5 +1,3 @@
-import timeit
-
 import numpy as np
 from karabo.bound import (
     BOOL_ELEMENT,
@@ -36,14 +34,12 @@ class AgipdCorrection(BaseCorrection):
     _calcat_friend_class = AgipdCalcatFriend
     _constant_enum_class = AgipdConstants
     _managed_keys = BaseCorrection._managed_keys | {
-        "overrideInputAxisOrder",
         "sendGainMap",
     }
 
     # this is just extending (not mandatory)
     _schema_cache_fields = BaseCorrection._schema_cache_fields | {
         "sendGainMap",
-        "overrideInputAxisOrder",
     }
 
     @staticmethod
@@ -52,13 +48,6 @@ class AgipdCorrection(BaseCorrection):
         expected.setDefaultValue("dataFormat.memoryCells", 352)
         expected.setDefaultValue("preview.selectionMode", "cell")
         (
-            BOOL_ELEMENT(expected)
-            .key("overrideInputAxisOrder")
-            .displayedName("Override input axis order")
-            .assignmentOptional()
-            .defaultValue(False)
-            .reconfigurable()
-            .commit(),
             STRING_ELEMENT(expected)
             .key("gainMode")
             .displayedName("Gain mode")
@@ -222,7 +211,6 @@ class AgipdCorrection(BaseCorrection):
         }
 
         self._shmem_buffer_gain_map = None
-        self._update_shapes()
 
         # configurability: overriding md_additional_offset
         if config.get("corrections.relGainPc.overrideMdAdditionalOffset"):
@@ -234,125 +222,89 @@ class AgipdCorrection(BaseCorrection):
 
         # configurability: disabling subset of bad pixel masking bits
         self._has_updated_bad_pixel_selection = False
-        self._update_bad_pixel_selection()
+        self.registerInitialFunction(self._update_bad_pixel_selection)
 
         self.updateState(State.ON)
 
-    def process_input(self, data, metadata):
-        """Registered for dataInput, handles all processing and sending"""
-
-        source = metadata.get("source")
-
-        if source not in self.sources:
-            self.log_status_info(f"Ignoring hash with unknown source {source}")
-            return
-
-        if not data.has("image"):
-            self.log_status_info("Ignoring hash without image node")
-            return
+    def process_data(
+        self,
+        data_hash,
+        metadata,
+        source,
+        train_id,
+        image_data,
+        cell_table,
+        do_generate_preview,
+    ):
+        """Called by input_handler for each data hash. Should correct data, optionally
+        compute preview, write data output, and optionally write preview outputs."""
+        # original shape: memory_cell, data/raw_gain, x, y
 
-        time_start = timeit.default_timer()
-        self._last_processing_started = time_start
+        # TODO: add pulse filter back in
+        pulse_table = np.squeeze(data_hash.get("image.pulseId"))
 
-        train_id = metadata.getAttribute("timestamp", "tid")
-        cell_table = np.squeeze(data.get("image.cellId"))
-        if len(cell_table.shape) == 0:
-            self.log_status_warn(
-                "cellId had 0 dimensions. DAQ may not be sending data."
-            )
+        try:
+            self.gpu_runner.load_data(image_data)
+        except ValueError as e:
+            self.log_status_warn(f"Failed to load data: {e}")
             return
-        # original shape: memory_cell, data/raw_gain, x, y
-        image_data = data.get("image.data")
-        if image_data.shape[0] != self._schema_cache["dataFormat.memoryCells"]:
-            self.log_status_info(
-                f"Updating input shapes based on received {image_data.shape}"
-            )
-            self.set("dataFormat.memoryCells", image_data.shape[0])
-            with self._buffer_lock:
-                # TODO: pulse filter update after reimplementation
-                self._update_shapes()
-
-        if not self._schema_cache["state"] is State.PROCESSING:
-            self.updateState(State.PROCESSING)
-            self.log_status_info("Processing data")
-
-        correction_cell_num = self._schema_cache["constantParameters.memoryCells"]
-        do_generate_preview = (
-            train_id % self._schema_cache["preview.trainIdModulo"] == 0
-            and self._schema_cache["preview.enable"]
+        except Exception as e:
+            self.log_status_warn(f"Unknown exception when loading data to GPU: {e}")
+
+        buffer_handle, buffer_array = self._shmem_buffer.next_slot()
+        self.gpu_runner.load_cell_table(cell_table)
+        self.gpu_runner.correct(self._correction_flag_enabled)
+        self.gpu_runner.reshape(
+            output_order=self._schema_cache["dataFormat.outputAxisOrder"],
+            out=buffer_array,
         )
-
-        if self._schema_cache["overrideInputAxisOrder"]:
-            expected_shape = self.input_data_shape
-            if expected_shape != image_data.shape:
-                image_data.shape = expected_shape
-
-        with self._buffer_lock:
-            # cell_table = cell_table[self.pulse_filter]
-            pulse_table = np.squeeze(data.get("image.pulseId"))  # [self.pulse_filter]
-            cell_table_max = np.max(cell_table)
-            if cell_table_max >= correction_cell_num:
-                self.log_status_info(
-                    f"Max cell ID ({cell_table_max}) exceeds range for loaded "
-                    f"constants ({correction_cell_num} cells). Some frames will not be "
-                    "corrected."
-                )
-
-            try:
-                self.gpu_runner.load_data(image_data)
-            except ValueError as e:
-                self.log_status_warn(f"Failed to load data: {e}")
-                return
-
-            buffer_handle, buffer_array = self._shmem_buffer.next_slot()
-            self.gpu_runner.load_cell_table(cell_table)
-            self.gpu_runner.correct(self._correction_flag_enabled)
-            self.gpu_runner.reshape(
-                output_order=self._schema_cache["dataFormat.outputAxisOrder"],
-                out=buffer_array,
+        # after reshape, data for dataOutput is now safe in its own buffer
+        if do_generate_preview:
+            if self._correction_flag_enabled != self._correction_flag_preview:
+                self.gpu_runner.correct(self._correction_flag_preview)
+            (
+                preview_slice_index,
+                preview_cell,
+                preview_pulse,
+            ) = utils.pick_frame_index(
+                self._schema_cache["preview.selectionMode"],
+                self._schema_cache["preview.index"],
+                cell_table,
+                pulse_table,
+                warn_func=self.log_status_warn,
             )
-            # after reshape, data for dataOutput is now safe in its own buffer
-            if do_generate_preview:
-                if self._correction_flag_enabled != self._correction_flag_preview:
-                    self.gpu_runner.correct(self._correction_flag_preview)
-                (
-                    preview_slice_index,
-                    preview_cell,
-                    preview_pulse,
-                ) = utils.pick_frame_index(
-                    self._schema_cache["preview.selectionMode"],
-                    self._schema_cache["preview.index"],
-                    cell_table,
-                    pulse_table,
-                    warn_func=self.log_status_warn,
-                )
-                preview_raw, preview_corrected = self.gpu_runner.compute_preview(
+            (
+                preview_raw,
+                preview_corrected,
+            ) = self.gpu_runner.compute_preview(preview_slice_index)
+            if self._schema_cache["sendGainMap"]:
+                preview_gain = self.gpu_runner.compute_preview_gain(
                     preview_slice_index
                 )
-                if self._schema_cache["sendGainMap"]:
-                    preview_gain = self.gpu_runner.compute_preview_gain(
-                        preview_slice_index
-                    )
 
-        data.set("image.data", buffer_handle)
+        # reusing input data hash for sending
+        data_hash.set("image.data", buffer_handle)
         if self._schema_cache["sendGainMap"]:
-            buffer_handle, buffer_array = self._shmem_buffer_gain_map.next_slot()
+            (
+                buffer_handle,
+                buffer_array,
+            ) = self._shmem_buffer_gain_map.next_slot()
             self.gpu_runner.get_gain_map(
                 output_order=self._schema_cache["dataFormat.outputAxisOrder"],
                 out=buffer_array,
             )
-            data.set(
+            data_hash.set(
                 "image.gainMap",
                 buffer_handle,
             )
-            data.set("calngShmemPaths", ["image.data", "image.gainMap"])
+            data_hash.set("calngShmemPaths", ["image.data", "image.gainMap"])
         else:
-            data.set("calngShmemPaths", ["image.data"])
+            data_hash.set("calngShmemPaths", ["image.data"])
 
-        data.set("image.cellId", cell_table[:, np.newaxis])
-        data.set("image.pulseId", pulse_table[:, np.newaxis])
+        data_hash.set("image.cellId", cell_table[:, np.newaxis])
+        data_hash.set("image.pulseId", pulse_table[:, np.newaxis])
 
-        self._write_output(data, metadata)
+        self._write_output(data_hash, metadata)
         if do_generate_preview:
             if self._schema_cache["sendGainMap"]:
                 self._write_combiner_previews(
@@ -375,12 +327,6 @@ class AgipdCorrection(BaseCorrection):
                     source,
                 )
 
-        # update rate etc.
-        self._buffered_status_update.set("trainId", train_id)
-        self._rate_tracker.update()
-        time_spent = timeit.default_timer() - time_start
-        self._processing_time_ema.update(time_spent)
-
     def _load_constant_to_gpu(self, constant, constant_data):
         # TODO: encode correction / constant dependencies in a clever way
         if constant is AgipdConstants.ThresholdsDark:
diff --git a/src/calng/DsscCorrection.py b/src/calng/DsscCorrection.py
index 2b74bed7..2db4d915 100644
--- a/src/calng/DsscCorrection.py
+++ b/src/calng/DsscCorrection.py
@@ -1,5 +1,3 @@
-import timeit
-
 import numpy as np
 from karabo.bound import KARABO_CLASSINFO, VECTOR_STRING_ELEMENT
 from karabo.common.states import State
@@ -53,93 +51,57 @@ class DsscCorrection(BaseCorrection):
         super().__init__(config)
         self.updateState(State.ON)
 
-    def process_input(self, data, metadata):
-        """Registered for dataInput, handles all processing and sending"""
-
-        source = metadata.get("source")
-
-        if source not in self.sources:
-            self.log_status_info(f"Ignoring hash with unknown source {source}")
-            return
-
-        if not data.has("image"):
-            self.log_status_info("Ignoring hash without image node")
-            return
-
-        time_start = timeit.default_timer()
-        self._last_processing_started = time_start
-
-        train_id = metadata.getAttribute("timestamp", "tid")
-        cell_table = np.squeeze(data.get("image.cellId"))
-        assert isinstance(cell_table, np.ndarray), "image.cellId should be ndarray"
-        if len(cell_table.shape) == 0:
-            self.log_status_warn(
-                "cellId had 0 dimensions. DAQ may not be sending data."
-            )
+    def process_data(
+        self,
+        data_hash,
+        metadata,
+        source,
+        train_id,
+        image_data,
+        cell_table,
+        do_generate_preview,
+    ):
+        # cell_table = cell_table[self.pulse_filter]
+        pulse_table = np.squeeze(data_hash.get("image.pulseId"))  # [self.pulse_filter]
+
+        try:
+            self.gpu_runner.load_data(image_data)
+        except ValueError as e:
+            self.log_status_warn(f"Failed to load data: {e}")
             return
-        # original shape: 400, 1, 128, 512 (memory cells, something, y, x)
-        image_data = data.get("image.data")
-        if image_data.shape[0] != self._schema_cache["dataFormat.memoryCells"]:
-            self.log_status_info(
-                f"Updating input shapes based on received {image_data.shape}"
-            )
-            self.set("dataFormat.memoryCells", image_data.shape[0])
-            with self._buffer_lock:
-                # TODO: pulse filter update after reimplementation
-                self._update_shapes()
-
-        if not self._schema_cache["state"] is State.PROCESSING:
-            self.updateState(State.PROCESSING)
-            self.log_status_info("Processing data")
-
-        correction_cell_num = self._schema_cache["constantParameters.memoryCells"]
-        do_generate_preview = (
-            train_id % self._schema_cache["preview.trainIdModulo"] == 0
-            and self._schema_cache["preview.enable"]
+        except Exception as e:
+            self.log_status_warn(f"Unknown exception when loading data to GPU: {e}")
+
+        buffer_handle, buffer_array = self._shmem_buffer.next_slot()
+        self.gpu_runner.load_cell_table(cell_table)
+        self.gpu_runner.correct(self._correction_flag_enabled)
+        self.gpu_runner.reshape(
+            output_order=self._schema_cache["dataFormat.outputAxisOrder"],
+            out=buffer_array,
         )
-
-        with self._buffer_lock:
-            # cell_table = cell_table[self.pulse_filter]
-            pulse_table = np.squeeze(data.get("image.pulseId"))  # [self.pulse_filter]
-            cell_table_max = np.max(cell_table)
-            if cell_table_max >= correction_cell_num:
-                self.log_status_info(
-                    f"Max cell ID ({cell_table_max}) exceeds range for loaded "
-                    f"constant ({correction_cell_num} cells). Some frames will not be "
-                    "corrected."
-                )
-
-            self.gpu_runner.load_data(image_data)
-            buffer_handle, buffer_array = self._shmem_buffer.next_slot()
-            self.gpu_runner.load_cell_table(cell_table)
-            self.gpu_runner.correct(self._correction_flag_enabled)
-            self.gpu_runner.reshape(
-                output_order=self._schema_cache["dataFormat.outputAxisOrder"],
-                out=buffer_array,
+        if do_generate_preview:
+            if self._correction_flag_enabled != self._correction_flag_preview:
+                self.gpu_runner.correct(self._correction_flag_preview)
+            (
+                preview_slice_index,
+                preview_cell,
+                preview_pulse,
+            ) = utils.pick_frame_index(
+                self._schema_cache["preview.selectionMode"],
+                self._schema_cache["preview.index"],
+                cell_table,
+                pulse_table,
+                warn_func=self.log_status_warn,
+            )
+            preview_raw, preview_corrected = self.gpu_runner.compute_preview(
+                preview_slice_index,
             )
-            if do_generate_preview:
-                if self._correction_flag_enabled != self._correction_flag_preview:
-                    self.gpu_runner.correct(self._correction_flag_preview)
-                (
-                    preview_slice_index,
-                    preview_cell,
-                    preview_pulse,
-                ) = utils.pick_frame_index(
-                    self._schema_cache["preview.selectionMode"],
-                    self._schema_cache["preview.index"],
-                    cell_table,
-                    pulse_table,
-                    warn_func=self.log_status_warn,
-                )
-                preview_raw, preview_corrected = self.gpu_runner.compute_preview(
-                    preview_slice_index,
-                )
 
-        data.set("image.data", buffer_handle)
-        data.set("image.cellId", cell_table[:, np.newaxis])
-        data.set("image.pulseId", pulse_table[:, np.newaxis])
-        data.set("calngShmemPaths", ["image.data"])
-        self._write_output(data, metadata)
+        data_hash.set("image.data", buffer_handle)
+        data_hash.set("image.cellId", cell_table[:, np.newaxis])
+        data_hash.set("image.pulseId", pulse_table[:, np.newaxis])
+        data_hash.set("calngShmemPaths", ["image.data"])
+        self._write_output(data_hash, metadata)
         if do_generate_preview:
             self._write_combiner_previews(
                 (
@@ -150,12 +112,6 @@ class DsscCorrection(BaseCorrection):
                 source,
             )
 
-        # update rate etc.
-        self._buffered_status_update.set("trainId", train_id)
-        self._rate_tracker.update()
-        time_spent = timeit.default_timer() - time_start
-        self._processing_time_ema.update(time_spent)
-
     def _update_pulse_filter(self, filter_string):
         """Called whenever the pulse filter changes, typically followed by
         _update_shapes"""
diff --git a/src/calng/base_correction.py b/src/calng/base_correction.py
index 2dccd1e2..47f695f0 100644
--- a/src/calng/base_correction.py
+++ b/src/calng/base_correction.py
@@ -1,12 +1,12 @@
 import pathlib
 import threading
-import timeit
+from timeit import default_timer
 
 import dateutil.parser
 import numpy as np
 from karabo.bound import (
     BOOL_ELEMENT,
-    FLOAT_ELEMENT,
+    DOUBLE_ELEMENT,
     INPUT_CHANNEL,
     INT32_ELEMENT,
     INT64_ELEMENT,
@@ -50,6 +50,7 @@ class BaseCorrection(PythonDevice):
         "outputShmemBufferSize",
         "dataFormat.outputAxisOrder",
         "dataFormat.outputImageDtype",
+        "dataFormat.overrideInputAxisOrder",
         "preview.enable",
         "preview.index",
         "preview.selectionMode",
@@ -63,6 +64,7 @@ class BaseCorrection(PythonDevice):
         "dataFormat.pixelsX",
         "dataFormat.pixelsY",
         "dataFormat.outputAxisOrder",
+        "dataFormat.overrideInputAxisOrder",
         "preview.enable",
         "preview.index",
         "preview.selectionMode",
@@ -235,6 +237,20 @@ class BaseCorrection(PythonDevice):
             .key("dataFormat")
             .displayedName("Data format (in/out)")
             .commit(),
+            BOOL_ELEMENT(expected)
+            .key("dataFormat.overrideInputAxisOrder")
+            .displayedName("Override input axis order")
+            .description(
+                "The shape of the image data ndarray as received from the "
+                "DataAggregator is sometimes wrong - the axes are actually in a "
+                "different order than the ndarray shape suggests. If this flag is on, "
+                "the shape of the ndarray will be overridden with the axis order we "
+                "expect."
+            )
+            .assignmentOptional()
+            .defaultValue(True)
+            .reconfigurable()
+            .commit(),
             STRING_ELEMENT(expected)
             .key("dataFormat.inputImageDtype")
             .displayedName("Input image data dtype")
@@ -368,8 +384,8 @@ class BaseCorrection(PythonDevice):
             .description(
                 "The value of preview.index can be used in multiple ways, controlled "
                 "by this value. If this is set to 'frame', preview.index is sliced "
-                "directly from data. If 'cell' (or 'pulse') is selected, I will look at "
-                "cell (or pulse) table for the requested cell (or pulse ID)."
+                "directly from data. If 'cell' (or 'pulse') is selected, I will look "
+                "at cell (or pulse) table for the requested cell (or pulse ID)."
             )
             .options("frame,cell,pulse")
             .assignmentOptional()
@@ -404,13 +420,9 @@ class BaseCorrection(PythonDevice):
             .key("performance")
             .displayedName("Performance measures")
             .commit(),
-            FLOAT_ELEMENT(expected)
-            .key("performance.processingDuration")
+            DOUBLE_ELEMENT(expected)
+            .key("performance.processingTime")
             .displayedName("Processing time")
-            .description(
-                "Exponential moving average over time spent processing individual "
-                "trains Time includes generating preview and sending data."
-            )
             .unit(Unit.SECOND)
             .metricPrefix(MetricPrefix.MILLI)
             .readOnly()
@@ -419,7 +431,7 @@ class BaseCorrection(PythonDevice):
             .info("Processing not fast enough for full speed")
             .needsAcknowledging(False)
             .commit(),
-            FLOAT_ELEMENT(expected)
+            DOUBLE_ELEMENT(expected)
             .key("performance.rate")
             .displayedName("Rate")
             .description(
@@ -445,20 +457,22 @@ class BaseCorrection(PythonDevice):
             k: config.get(k) for k in self._schema_cache_fields if config.has(k)
         }
         super().__init__(config)
+        self.updateState(State.INIT)
 
         if not sorted(config.get("dataFormat.outputAxisOrder")) == ["c", "x", "y"]:
             # TODO: figure out how to get this information to operator
             self.log_status_error("Invalid output axis order string")
             return
 
-        self.KARABO_ON_DATA("dataInput", self.process_input)
+        self.KARABO_ON_INPUT("dataInput", self.input_handler)
         self.KARABO_ON_EOS("dataInput", self.handle_eos)
 
         self.sources = set(config.get("fastSources"))
 
         self.input_data_dtype = np.dtype(config.get("dataFormat.inputImageDtype"))
         self.output_data_dtype = np.dtype(config.get("dataFormat.outputImageDtype"))
-        self.gpu_runner = None  # must call _update_shapes() in subclass init
+        self.gpu_runner = None  # must call _update_shapes()
+        self.registerInitialFunction(self._update_shapes)
 
         self.calcat_friend = self._calcat_friend_class(
             self, pathlib.Path.cwd() / "calibration-client-secrets.json"
@@ -477,10 +491,10 @@ class BaseCorrection(PythonDevice):
             0,
             "performance.rate",
             0,
-            "performance.processingDuration",
+            "performance.processingTime",
             0,
         )
-        self._last_processing_started = 0  # not input handler should put timestamp
+        self._last_processing_started = 0  # input handler sets this to the processing start time
         self._rate_update_timer = utils.RepeatingTimer(
             interval=1,
             callback=self._update_rate_and_state,
@@ -697,18 +711,103 @@ class BaseCorrection(PythonDevice):
 
         self._has_updated_shapes = True
 
+    def input_handler(self, input_channel):
+        """Main handler for data input: Do a few simple checks to determine whether to
+        even try processing. If yes, will pass data and information to subclass'
+        process_data function.
+        """
+
+        # Is device even ready for this?
+        state = self._schema_cache["state"]
+        if state is State.ERROR:
+            # in this case, we should have already issued warning
+            return
+        elif self.gpu_runner is None:
+            self.log_status_warn("Received data, but have not initialized kernels yet")
+            return
+
+        all_metadata = input_channel.getMetaData()
+        for input_index in range(input_channel.size()):
+            self._last_processing_started = default_timer()
+            data_hash = input_channel.read(input_index)
+            metadata = all_metadata[input_index]
+            source = metadata.get("source")
+
+            if source not in self.sources:
+                self.log_status_info(f"Ignoring hash with unknown source {source}")
+                continue
+            elif not data_hash.has("image"):
+                self.log_status_info("Ignoring hash without image node")
+                continue
+
+            train_id = metadata.getAttribute("timestamp", "tid")
+            cell_table = np.squeeze(data_hash.get("image.cellId"))
+            if len(cell_table.shape) == 0:
+                self.log_status_warn(
+                    "cellId had 0 dimensions. DAQ may not be sending data."
+                )
+                continue
+
+            # no more common reasons to skip input, so go to processing
+            if state is State.ON:
+                self.updateState(State.PROCESSING)
+                self.log_status_info("Processing data")
+
+            correction_cell_num = self._schema_cache["constantParameters.memoryCells"]
+            cell_table_max = np.max(cell_table)
+            if cell_table_max >= correction_cell_num:
+                self.log_status_info(
+                    f"Max cell ID ({cell_table_max}) exceeds range for loaded "
+                    f"constants ({correction_cell_num} cells). Some frames will not be "
+                    "corrected."
+                )
+
+            image_data = data_hash.get("image.data")
+            if image_data.shape[0] != self._schema_cache["dataFormat.memoryCells"]:
+                self.log_status_info(
+                    f"Received new input shape {image_data.shape}, updating buffers"
+                )
+                self.set("dataFormat.memoryCells", image_data.shape[0])
+                with self._buffer_lock:
+                    # TODO: pulse filter update after reimplementation
+                    self._update_shapes()
+
+            # DataAggregator typically tells us the wrong axis order
+            if self._schema_cache["dataFormat.overrideInputAxisOrder"]:
+                expected_shape = self.input_data_shape
+                if expected_shape != image_data.shape:
+                    image_data.shape = expected_shape
+
+            do_generate_preview = (
+                train_id % self._schema_cache["preview.trainIdModulo"] == 0
+                and self._schema_cache["preview.enable"]
+            )
+
+            with self._buffer_lock:
+                self.process_data(
+                    data_hash,
+                    metadata,
+                    source,
+                    train_id,
+                    image_data,
+                    cell_table,
+                    do_generate_preview,
+                )
+            self._buffered_status_update.set("trainId", train_id)
+            self._processing_time_ema.update(
+                default_timer() - self._last_processing_started
+            )
+            self._rate_tracker.update()
+
     def _update_rate_and_state(self):
         self._buffered_status_update.set("performance.rate", self._rate_tracker.get())
         self._buffered_status_update.set(
-            "performance.processingDuration", self._processing_time_ema.get() * 1000
+            "performance.processingTime", self._processing_time_ema.get() * 1000
         )
         # trainId should be set on _buffered_status_update in input handler
         self.set(self._buffered_status_update)
 
-        if (
-            timeit.default_timer() - self._last_processing_started
-            > PROCESSING_STATE_TIMEOUT
-        ):
+        if default_timer() - self._last_processing_started > PROCESSING_STATE_TIMEOUT:
             if self.get("state") is State.PROCESSING:
                 self.updateState(State.ON)
                 self.log_status_info(
-- 
GitLab