diff --git a/src/cal_tools/agipdlib.py b/src/cal_tools/agipdlib.py index e935a3c84d4f6651e645b9b430894004399a440f..2c3201a796ec4a9a6b1e72da4e10d9873a97a838 100644 --- a/src/cal_tools/agipdlib.py +++ b/src/cal_tools/agipdlib.py @@ -28,7 +28,6 @@ from cal_tools.agipdutils import ( melt_snowy_pixels, ) from cal_tools.enums import AgipdGainMode, BadPixels, SnowResolution -from cal_tools.litpx_counter import LitPixelCounter from logging import warning @@ -588,7 +587,6 @@ class AgipdCorrections: # Output parameters self.compress_fields = ['gain', 'mask'] self.recast_image_fields = {} - self.write_extra_data = [] # Shared variables for data and constants self.shared_dict = [] @@ -778,13 +776,15 @@ class AgipdCorrections: instrument_channels.append(f"{agipd_base}/image") # backward compatibility END - # resolve names for extra source - extra_sources = [] - for name in self.write_extra_data: - processor = data_dict[name] - src_name = processor.source_name(karabo_id, det_channel) - extra_sources.append((processor, src_name)) - instrument_channels.append(f"{src_name}/{processor.channel}") + # create metadata: lit-pixel counter + if self.corr_bools.get("count_lit_pixels"): + litpx_output_fields = [ + "cellId", "pulseId", "trainId", + "litPixels", "unmaskedPixels", "totalIntensity" + ] + litpx_index_group = "litpx" + litpx_src_name = f"{karabo_id}/LITPX/{det_channel}" + instrument_channels.append(f"{litpx_src_name}/{litpx_index_group}") with DataFile.from_details(out_folder, agg, runno, seqno) as outfile: outfile.create_metadata( @@ -825,13 +825,14 @@ class AgipdCorrections: field, shape=arr.shape, dtype=arr.dtype, **kw ) - # create extra sources - required_data = [] - for processor, src_name in extra_sources: - src = outfile.require_instrument_source(src_name) - processor.set_num_images(n_img) - processor.create_schema(src, trains, count) - required_data.append((processor, src)) + # create schema: lit-pixel counter + if self.corr_bools.get("count_lit_pixels"): + litpx_src = outfile.require_instrument_source(litpx_src_name) + litpx_src.create_index(**{litpx_index_group: count}) + litpx_group = litpx_src.require_group(litpx_index_group) + for field in litpx_output_fields: + litpx_group.create_dataset( + field, shape=(n_img,), dtype=data_dict[field].dtype) # Write the corrected data for field in image_fields: @@ -842,9 +843,10 @@ class AgipdCorrections: else: image_grp[field][:] = data_dict[field][:n_img] - # write extra data - for processor, src in required_data: - processor.write(src) + # write data: lit-pixel counter + if self.corr_bools.get("count_lit_pixels"): + for field in litpx_output_fields: + litpx_group[field][:] = data_dict[field][:n_img] def _write_compressed_frames(self, dataset, arr): """Compress gain/mask frames in multiple threads, and save their data @@ -1495,7 +1497,6 @@ class AgipdCorrections: :param shape: Shape of expected data (nImg, x, y) :param n_cores_files: Number of files, handled in parallel """ - self.write_extra_data = [] self.shared_dict = [] for i in range(n_cores_files): self.shared_dict.append({}) @@ -1521,21 +1522,30 @@ class AgipdCorrections: self.shared_dict[i]["nimg_in_trains"] = sharedmem.empty(1024, dtype="i8") # noqa if self.corr_bools.get("count_lit_pixels"): - self.shared_dict[i]["litpx_counter"] = LitPixelCounter( - self.shared_dict[i], threshold=self.litpx_threshold) - - if self.corr_bools.get("count_lit_pixels"): - self.write_extra_data.append("litpx_counter") + self.shared_dict[i]["unmaskedPixels"] = sharedmem.full(shape[0], 0, int) + self.shared_dict[i]["litPixels"] = sharedmem.full(shape[0], 0, int) + self.shared_dict[i]["totalIntensity"] = sharedmem.full(shape[0], 0, np.float32) if self.corr_bools.get("round_photons"): self.shared_hist_preround = sharedmem.empty(len(self.hist_bins_preround) - 1, dtype="i8") self.shared_hist_postround = sharedmem.empty(len(self.hist_bins_postround) - 1, dtype="i8") def litpixel_counter(self, i_proc: int, first: int, last: int): - counter = self.shared_dict[i_proc].get("litpx_counter") - if counter: - counter.set_num_images(self.shared_dict[i_proc]["nImg"][0]) - counter.process(slice(first, last)) + """Lit-pixel counter: counts pixels with a signal above a threshold. + + :param i_proc: Index of shared memory array to process + :param first: Index of the first image to be corrected + :param last: Index of the last image to be corrected + """ + data = self.shared_dict[i_proc] + + for i in range(first, last): + mask = data["mask"][i] == 0 + image = data["data"][i] + + data["totalIntensity"][i] = np.sum(image, initial=0, where=mask) + data["litPixels"][i] = np.sum(image > self.litpx_threshold, initial=0, where=mask) + data["unmaskedPixels"][i] = np.sum(mask) def validate_selected_pulses( diff --git a/src/cal_tools/litpx_counter.py b/src/cal_tools/litpx_counter.py deleted file mode 100644 index dcd30a126d0c714dc3b506bd1697b90820a278f3..0000000000000000000000000000000000000000 --- a/src/cal_tools/litpx_counter.py +++ /dev/null @@ -1,142 +0,0 @@ -import numpy as np -import sharedmem - - -class LitPixelCounter: - """Lit-pixel counter: counts pixels with a signal above a threshold.""" - - channel = "litpx" - output_fields = [ - "cellId", "pulseId", "trainId", - "litPixels", "unmaskedPixels", "totalIntensity" - ] - required_data = [ - "data", "mask", "cellId", "pulseId", "trainId" - ] - - def __init__(self, data, threshold=0.7): - """Initialize the instance of lit-pixel analysis addon. - - Parameters - ---------- - data: dict - Dictionary with image data. It must include the fields: - `data`, `mask`, `cellId`, `pulseId`, `trainId` - threshold: float - The pixel intensity value, which if exceeded means - that the pixel is illuminated. - """ - required_data = set(self.required_data) | {"pulseId"} - for name in required_data: - if name not in data: - raise ValueError(f"The field '{name}' is missed in 'data'") - - self.data = data.copy() - self.max_images = data["pulseId"].shape[0] - self.num_images = self.max_images - - # specific members - self.image = data["data"] - self.mask = data["mask"] - - self.threshold = threshold - self.num_unmasked_px = sharedmem.full(self.max_images, 0, int) - self.num_lit_px = sharedmem.full(self.max_images, 0, int) - self.total_intensity = sharedmem.full(self.max_images, 0, np.float32) - - self.data["litPixels"] = self.num_lit_px - self.data["unmaskedPixels"] = self.num_unmasked_px - self.data["totalIntensity"] = self.total_intensity - - def set_num_images(self, num_images): - """Sets the actual number of images in data. - - Parameters - ---------- - num_images: int - The actual number of images in data - """ - self.num_images = num_images - - def process(self, chunk): - """Processes the image data. - - Parameters - ---------- - chunk: slice, sequence or array - The indices of images in `data` to process - """ - ix = range(*chunk.indices(self.num_images)) - for i in ix: - mask = self.mask[i] == 0 - self.total_intensity[i] = np.sum( - self.image[i], initial=0, where=mask) - self.num_lit_px[i] = np.sum( - self.image[i] > self.threshold, initial=0, where=mask) - self.num_unmasked_px[i] = np.sum(mask) - - def source_name(self, karabo_id, channel): - """Returns the source name. - - Parameters - ---------- - karabo_id: str - The detector Karabo Id, e.g. SPB_DET_AGIPD1M-1 - channel: str - The detector channel Id, e.g. 15CH0:output - - Returns - ------- - source_name: str - The source name for EXDF files - """ - return f"{karabo_id}/LITPX/{channel}" - - def create_schema(self, source, file_trains=None, count=None): - """Creates the indices and keys in the source. - - Parameters - ---------- - source: InstrumentSource - The source, where to create the keys - file_trains: sequence or array - The list of trains in the file - count: array - The count of entry per train for this channel. If None, - then addons create it using `trainId` field in `data`. - """ - if file_trains is None: - file_trains = source.file["INDEX/trainId"][:] - - if count is None: - tid = self.data["trainId"][:self.num_images] - trains, count = np.unique(tid, return_counts=True) - count = count[np.in1d(trains, file_trains)] - - if len(file_trains) != len(count): - raise ValueError( - "The length of data count does not match the number of trains") - if np.sum(count) != self.num_images: - raise ValueError( - "The sum of data count does not match " - "the total number of data entries") - - source.create_index(**{self.channel: count}) - for key in self.output_fields: - source.create_dataset( - f"{self.channel}/{key}", - shape=(self.num_images,), - dtype=self.data[key].dtype - ) - - def write(self, source): - """Writes data to file. - - Parameters - ---------- - source: InstrumentSource - The source, where to write data - """ - channel = source[self.channel] - for key in self.output_fields: - channel[key][:] = self.data[key][:self.num_images]