From 5e57e05c7af07d9bca969f9c6399f26cff7c92a0 Mon Sep 17 00:00:00 2001 From: Egor Sobolev <egor.sobolev@xfel.eu> Date: Mon, 24 Jun 2024 15:41:13 +0200 Subject: [PATCH] Remove addon abstraction --- src/cal_tools/agipdlib.py | 42 ++++++++------ src/cal_tools/litpx_counter.py | 102 ++++++++++++++------------------- 2 files changed, 66 insertions(+), 78 deletions(-) diff --git a/src/cal_tools/agipdlib.py b/src/cal_tools/agipdlib.py index 7290f54f7..c492e7a8d 100644 --- a/src/cal_tools/agipdlib.py +++ b/src/cal_tools/agipdlib.py @@ -28,7 +28,7 @@ from cal_tools.agipdutils import ( melt_snowy_pixels, ) from cal_tools.enums import AgipdGainMode, BadPixels, SnowResolution -from cal_tools.litpx_counter import AnalysisAddon, LitPixelCounter +from cal_tools.litpx_counter import LitPixelCounter from logging import warning @@ -73,7 +73,7 @@ class AgipdCtrl: else: raise ValueError(f"No raw images found for {self.image_src}") - return ncell + return ncell def get_num_cells(self) -> int: """Read number of memory cells from fast data.""" @@ -84,7 +84,7 @@ class AgipdCtrl: # the function returns wrong value. ncell = self._get_num_cells_instr() - return ncell + return ncell def _get_acq_rate_ctrl(self) -> Optional[float]: """Get acquisition (repetition) rate from CONTROL source.""" @@ -588,6 +588,7 @@ class AgipdCorrections: # Output parameters self.compress_fields = ['gain', 'mask'] self.recast_image_fields = {} + self.write_extra_data = [] # Shared variables for data and constants self.shared_dict = [] @@ -777,12 +778,13 @@ class AgipdCorrections: instrument_channels.append(f"{agipd_base}/image") # backward compatibility END - addons = [] - for name, elem in data_dict.items(): - if isinstance(elem, AnalysisAddon): - src_name = elem.source_name(karabo_id, det_channel) - addons.append((elem, src_name)) - instrument_channels.append(f"{src_name}/{elem.channel}") + # resolve names for extra source + extra_sources = [] + for name in self.write_extra_data: + processor = data_dict[name] + src_name = processor.source_name(karabo_id, det_channel) + extra_sources.append((processor, src_name)) + instrument_channels.append(f"{src_name}/{processor.channel}") with DataFile.from_details(out_folder, agg, runno, seqno) as outfile: outfile.create_metadata( @@ -823,13 +825,13 @@ class AgipdCorrections: field, shape=arr.shape, dtype=arr.dtype, **kw ) - # create addon sources - required_addon_data = [] - for addon, src_name in addons: + # create extra sources + required_data = [] + for processor, src_name in extra_sources: src = outfile.create_instrument_source(src_name) - addon.set_num_images(n_img) - addon.create_schema(src, trains, count) - required_addon_data.append((addon, src)) + processor.set_num_images(n_img) + processor.create_schema(src, trains, count) + required_data.append((processor, src)) # Write the corrected data for field in image_fields: @@ -840,9 +842,9 @@ class AgipdCorrections: else: image_grp[field][:] = data_dict[field][:n_img] - # write addon data - for addon, src in required_addon_data: - addon.write(src) + # write extra data + for processor, src in required_data: + processor.write(src) def _write_compressed_frames(self, dataset, arr): """Compress gain/mask frames in multiple threads, and save their data @@ -1493,6 +1495,7 @@ class AgipdCorrections: :param shape: Shape of expected data (nImg, x, y) :param n_cores_files: Number of files, handled in parallel """ + self.write_extra_data = [] self.shared_dict = [] for i in range(n_cores_files): self.shared_dict.append({}) @@ -1521,6 +1524,9 @@ class AgipdCorrections: self.shared_dict[i]["litpx_counter"] = LitPixelCounter( self.shared_dict[i], threshold=self.litpx_threshold) + if self.corr_bools.get("count_lit_pixels"): + self.write_extra_data.append("litpx_counter") + if self.corr_bools.get("round_photons"): self.shared_hist_preround = sharedmem.empty(len(self.hist_bins_preround) - 1, dtype="i8") self.shared_hist_postround = sharedmem.empty(len(self.hist_bins_postround) - 1, dtype="i8") diff --git a/src/cal_tools/litpx_counter.py b/src/cal_tools/litpx_counter.py index aec381bf0..680292444 100644 --- a/src/cal_tools/litpx_counter.py +++ b/src/cal_tools/litpx_counter.py @@ -2,16 +2,30 @@ import numpy as np import sharedmem -class AnalysisAddon: - """Base class for analysis addons""" +class LitPixelCounter: + """Lit-pixel counter: counts pixels with a signal above a threshold.""" - channel = "data" + channel = "litpx" output_fields = [ - "cellId", "pulseId", "trainId"] + "cellId", "pulseId", "trainId", + "litPixels", "unmaskedPixels", "totalIntensity" + ] required_data = [ - "cellId", "pulseId", "trainId"] + "data", "mask", "cellId", "pulseId", "trainId" + ] - def __init__(self, data): + def __init__(self, data, threshold=0.7): + """Initialize the instance of lit-pixel analysis addon. + + Parameters + ---------- + data: dict + Dictionary with image data. It must include the fields: + `data`, `mask`, `cellId`, `pulseId`, `trainId` + threshold: float + The pixel intensity value, which if exceeded means + that the pixel is illuminated. + """ required_data = set(self.required_data) | {"pulseId"} for name in required_data: if name not in data: @@ -21,6 +35,19 @@ class AnalysisAddon: self.max_images = data["pulseId"].shape[0] self.num_images = self.max_images + # specific members + self.image = data["data"] + self.mask = data["mask"] + + self.threshold = threshold + self.num_unmasked_px = sharedmem.full(self.max_images, 0, int) + self.num_lit_px = sharedmem.full(self.max_images, 0, int) + self.total_intensity = sharedmem.full(self.max_images, 0, int) + + self.data["litPixels"] = self.num_lit_px + self.data["unmaskedPixels"] = self.num_unmasked_px + self.data["totalIntensity"] = self.total_intensity + def set_num_images(self, num_images): """Sets the actual number of images in data. @@ -39,7 +66,14 @@ class AnalysisAddon: chunk: slice, sequence or array The indices of images in `data` to process """ - raise NotImplementedError + ix = range(*chunk.indices(self.num_images)) + for i in ix: + mask = self.mask[i] == 0 + self.total_intensity[i] = np.sum( + self.image[i], initial=0, where=mask) + self.num_lit_px[i] = np.sum( + self.image[i] > self.threshold, initial=0, where=mask) + self.num_unmasked_px[i] = np.sum(mask) def source_name(self, karabo_id, channel): """Returns the source name. @@ -56,7 +90,7 @@ class AnalysisAddon: source_name: str The source name for EXDF files """ - return f"{karabo_id}/DANA/{channel}" + return f"{karabo_id}/LITPX/{channel}" def create_schema(self, source, file_trains=None, count=None): """Creates the indices and keys in the source. @@ -106,55 +140,3 @@ class AnalysisAddon: channel = source[self.channel] for key in self.output_fields: channel[key][:] = self.data[key][:self.num_images] - - -class LitPixelCounter(AnalysisAddon): - """Lit-pixel counter analysis addon.""" - - channel = "litpx" - output_fields = [ - "cellId", "pulseId", "trainId", - "litPixels", "unmaskedPixels", "totalIntensity" - ] - required_data = [ - "data", "mask", "cellId", "pulseId", "trainId" - ] - - def __init__(self, data, threshold=0.8): - """Initialize the instance of lit-pixel analysis addon. - - Parameters - ---------- - data: dict - Dictionary with image data. It must include the fields: - `data`, `mask`, `cellId`, `pulseId`, `trainId` - threshold: float - The pixel intensity value, which if exceeded means - that the pixel is illuminated. - """ - super().__init__(data) - - self.image = data["data"] - self.mask = data["mask"] - - self.threshold = threshold - self.num_unmasked_px = sharedmem.full(self.max_images, 0, int) - self.num_lit_px = sharedmem.full(self.max_images, 0, int) - self.total_intensity = sharedmem.full(self.max_images, 0, int) - - self.data["litPixels"] = self.num_lit_px - self.data["unmaskedPixels"] = self.num_unmasked_px - self.data["totalIntensity"] = self.total_intensity - - def process(self, chunk): - ix = range(*chunk.indices(self.num_images)) - for i in ix: - mask = self.mask[i] == 0 - self.total_intensity[i] = np.sum( - self.image[i], initial=0, where=mask) - self.num_lit_px[i] = np.sum( - self.image[i] > self.threshold, initial=0, where=mask) - self.num_unmasked_px[i] = np.sum(mask) - - def source_name(self, karabo_id, channel): - return f"{karabo_id}/LITPX/{channel}" -- GitLab