From d839b25d435d53177d089ee4a1369e055a87f4ae Mon Sep 17 00:00:00 2001
From: David Hammer <dhammer@mailbox.org>
Date: Wed, 18 Oct 2023 09:33:24 +0200
Subject: [PATCH] Use xarray to assemble partial data, drop preallocated
 buffers

---
 src/calng/DetectorAssembler.py | 73 ++++++++++++++++------------------
 1 file changed, 34 insertions(+), 39 deletions(-)

diff --git a/src/calng/DetectorAssembler.py b/src/calng/DetectorAssembler.py
index 31def2c0..878298ce 100644
--- a/src/calng/DetectorAssembler.py
+++ b/src/calng/DetectorAssembler.py
@@ -22,6 +22,7 @@ from karabo.bound import (
     Unit,
 )
 import numpy as np
+import xarray as xr
 from TrainMatcher import TrainMatcher
 
 from . import geom_utils, preview_utils, scenes, schemas, utils
@@ -199,7 +200,6 @@ class DetectorAssembler(TrainMatcher.TrainMatcher):
         self._image_data_path = self.get("imageDataPath")
         self._image_mask_path = self.get("imageMaskPath")
         self._geometry = None
-        self._stack_input_buffer = None
         # set up source to index mapping
         self._merge_source_to_index_from_regex()
         self._set_source_to_index_from_table()
@@ -312,15 +312,7 @@ class DetectorAssembler(TrainMatcher.TrainMatcher):
             self._geometry = geom_utils.deserialize_geometry(serialized_geometry)
         except Exception as e:
             self.log.WARN(f"Failed to deserialize geometry; {e}")
-        # TODO: allow multiple memory cells (extra geom notion of extra dimensions)
-        self._stack_input_buffer = np.ma.masked_array(
-            data=np.zeros(self._geometry.expected_data_shape, dtype=np.float32),
-            mask=False,
-        )
-        self._assemble_buffer = np.ma.masked_array(
-            data=self._geometry.output_array_for_position_fast(),
-            mask=False,
-        )
+        # TODO: test with multiple memory cells (extra_geom's notion of extra dims)
 
     def on_matched_data(self, train_id, sources):
         ts_start = default_timer()
@@ -333,10 +325,12 @@ class DetectorAssembler(TrainMatcher.TrainMatcher):
             self.unsafe_get("outputForBridgeOutput")
         )
 
-        module_indices_unfilled = set(range(self._stack_input_buffer.shape[0]))
         earliest_source_timestamp = float("inf")
-        self._stack_input_buffer.mask.fill(False)
+        image_datas, image_masks, module_indices = [], [], []
         for source, (data, source_timestamp) in sources.items():
+            earliest_source_timestamp = min(
+                earliest_source_timestamp, source_timestamp.toTimestamp()
+            )
             # regular TrainMatcher output
             self.output.write(
                 data, ChannelMetaData(source, source_timestamp), copyAllData=False
@@ -344,51 +338,52 @@ class DetectorAssembler(TrainMatcher.TrainMatcher):
             if bridge_output_choice is BridgeOutputOptions.MATCHED:
                 self.zmq_output.write(source, data, source_timestamp)
 
-            if source not in self._source_to_index:
-                continue
-            # prepare for assembly
-            # TODO: handle failure to "parse" source, get data out
-            if not data.has(self._image_data_path):
+            if (
+                source not in self._source_to_index
+                or not data.has(self._image_data_path)
+            ):
                 continue
             image_data = data[self._image_data_path]
             if isinstance(image_data, ImageData):
                 # TODO: maybe glance encoding here
                 image_data = image_data.getData()
-            image_data = image_data.astype(np.float32, copy=False)  # TODO: set dtype based on input?
+            # TODO: set dtype based on input?
+            image_datas.append(image_data.astype(np.float32, copy=False))
             if data.has(self._image_mask_path):
                 image_mask = data[self._image_mask_path]
                 if isinstance(image_mask, ImageData):
                     image_mask = image_mask.getData()
             else:
-                image_mask = False
-            module_index = self._source_to_index(source)
-            masked = np.ma.masked_array(
-                data=image_data,
-                mask=image_mask,
-            )
-            self._stack_input_buffer[module_index] = masked
-            module_indices_unfilled.discard(module_index)
-            earliest_source_timestamp = min(
-                earliest_source_timestamp, source_timestamp.toTimestamp()
-            )
+                image_mask = np.zeros_like(image_data, dtype=np.uint8)
+            image_masks.append(image_mask)
+            module_indices.append(self._source_to_index[source])
 
         self.output.update(safeNDArray=True)
         if bridge_output_choice is BridgeOutputOptions.MATCHED:
             self.zmq_output.update()
 
-        for module_index in module_indices_unfilled:
-            self._stack_input_buffer.mask[module_index].fill(True)
-            # consider configurable treatment of missing modules
-
-        assembled, _ = self._geometry.position_modules_fast(
-            self._stack_input_buffer, out=self._assemble_buffer
+        dims = ["module", "slow_scan", "fast_scan"]
+        coords = {"module": module_indices}
+        assembled_data, _ = self._geometry.position_modules(
+            xr.DataArray(
+                data=image_datas,
+                dims=dims,
+                coords=coords,
+            )
+        )
+        assembled_mask, _ = self._geometry.position_modules(
+            xr.DataArray(
+                data=image_masks,
+                dims=dims,
+                coords=coords,
+            )
         )
-        assembled.mask |= (~np.isfinite(assembled.data))
+        assembled_mask |= (~np.isfinite(assembled_data))
         output_hash = Hash(
             "image.data",
-            assembled,
+            assembled_data,
             "image.mask",
-            assembled.mask,
+            assembled_mask,
         )
         self.assembled_output.write(
             output_hash,
@@ -404,7 +399,7 @@ class DetectorAssembler(TrainMatcher.TrainMatcher):
 
         self._preview_friend.write_outputs(
             my_timestamp,
-            assembled,
+            np.ma.masked_array(data=assembled_data, mask=assembled_mask),
         )
 
         self._processing_time_tracker.update(default_timer() - ts_start)
-- 
GitLab