From d839b25d435d53177d089ee4a1369e055a87f4ae Mon Sep 17 00:00:00 2001 From: David Hammer <dhammer@mailbox.org> Date: Wed, 18 Oct 2023 09:33:24 +0200 Subject: [PATCH] Use xarray to assemble partial data, drop preallocated buffers --- src/calng/DetectorAssembler.py | 73 ++++++++++++++++------------------ 1 file changed, 34 insertions(+), 39 deletions(-) diff --git a/src/calng/DetectorAssembler.py b/src/calng/DetectorAssembler.py index 31def2c0..878298ce 100644 --- a/src/calng/DetectorAssembler.py +++ b/src/calng/DetectorAssembler.py @@ -22,6 +22,7 @@ from karabo.bound import ( Unit, ) import numpy as np +import xarray as xr from TrainMatcher import TrainMatcher from . import geom_utils, preview_utils, scenes, schemas, utils @@ -199,7 +200,6 @@ class DetectorAssembler(TrainMatcher.TrainMatcher): self._image_data_path = self.get("imageDataPath") self._image_mask_path = self.get("imageMaskPath") self._geometry = None - self._stack_input_buffer = None # set up source to index mapping self._merge_source_to_index_from_regex() self._set_source_to_index_from_table() @@ -312,15 +312,7 @@ class DetectorAssembler(TrainMatcher.TrainMatcher): self._geometry = geom_utils.deserialize_geometry(serialized_geometry) except Exception as e: self.log.WARN(f"Failed to deserialize geometry; {e}") - # TODO: allow multiple memory cells (extra geom notion of extra dimensions) - self._stack_input_buffer = np.ma.masked_array( - data=np.zeros(self._geometry.expected_data_shape, dtype=np.float32), - mask=False, - ) - self._assemble_buffer = np.ma.masked_array( - data=self._geometry.output_array_for_position_fast(), - mask=False, - ) + # TODO: test with multiple memory cells (extra geom notion of extra dimensions) def on_matched_data(self, train_id, sources): ts_start = default_timer() @@ -333,10 +325,12 @@ class DetectorAssembler(TrainMatcher.TrainMatcher): self.unsafe_get("outputForBridgeOutput") ) - module_indices_unfilled = set(range(self._stack_input_buffer.shape[0])) earliest_source_timestamp = float("inf") - self._stack_input_buffer.mask.fill(False) + image_datas, image_masks, module_indices = [], [], [] for source, (data, source_timestamp) in sources.items(): + earliest_source_timestamp = min( + earliest_source_timestamp, source_timestamp.toTimestamp() + ) # regular TrainMatcher output self.output.write( data, ChannelMetaData(source, source_timestamp), copyAllData=False @@ -344,51 +338,52 @@ class DetectorAssembler(TrainMatcher.TrainMatcher): if bridge_output_choice is BridgeOutputOptions.MATCHED: self.zmq_output.write(source, data, source_timestamp) - if source not in self._source_to_index: - continue - # prepare for assembly - # TODO: handle failure to "parse" source, get data out - if not data.has(self._image_data_path): + if ( + source not in self._source_to_index + or not data.has(self._image_data_path) + ): continue image_data = data[self._image_data_path] if isinstance(image_data, ImageData): # TODO: maybe glance encoding here image_data = image_data.getData() - image_data = image_data.astype(np.float32, copy=False) # TODO: set dtype based on input? + # TODO: set dtype based on input? + image_datas.append(image_data.astype(np.float32, copy=False)) if data.has(self._image_mask_path): image_mask = data[self._image_mask_path] if isinstance(image_mask, ImageData): image_mask = image_mask.getData() else: - image_mask = False - module_index = self._source_to_index(source) - masked = np.ma.masked_array( - data=image_data, - mask=image_mask, - ) - self._stack_input_buffer[module_index] = masked - module_indices_unfilled.discard(module_index) - earliest_source_timestamp = min( - earliest_source_timestamp, source_timestamp.toTimestamp() - ) + image_mask = np.zeros_like(image_data, dtype=np.uint8) + image_masks.append(image_mask) + module_indices.append(self._source_to_index[source]) self.output.update(safeNDArray=True) if bridge_output_choice is BridgeOutputOptions.MATCHED: self.zmq_output.update() - for module_index in module_indices_unfilled: - self._stack_input_buffer.mask[module_index].fill(True) - # consider configurable treatment of missing modules - - assembled, _ = self._geometry.position_modules_fast( - self._stack_input_buffer, out=self._assemble_buffer + dims = ["module", "slow_scan", "fast_scan"] + coords = {"module": module_indices} + assembled_data, _ = self._geometry.position_modules( + xr.DataArray( + data=image_datas, + dims=dims, + coords=coords, + ) + ) + assembled_mask, _ = self._geometry.position_modules( + xr.DataArray( + data=image_masks, + dims=dims, + coords=coords, + ) ) - assembled.mask |= (~np.isfinite(assembled.data)) + assembled_mask |= (~np.isfinite(assembled_data)) output_hash = Hash( "image.data", - assembled, + assembled_data, "image.mask", - assembled.mask, + assembled_mask, ) self.assembled_output.write( output_hash, @@ -404,7 +399,7 @@ class DetectorAssembler(TrainMatcher.TrainMatcher): self._preview_friend.write_outputs( my_timestamp, - assembled, + np.ma.masked_array(data=assembled_data, mask=assembled_mask), ) self._processing_time_tracker.update(default_timer() - ts_start) -- GitLab