diff --git a/src/cal_tools/agipdlib.py b/src/cal_tools/agipdlib.py index 262defdcbdfd1b0b3297527babf5b45900030c77..1783c0b81dd2bd37726a3503db44a3c3d8a3da34 100644 --- a/src/cal_tools/agipdlib.py +++ b/src/cal_tools/agipdlib.py @@ -28,6 +28,7 @@ from cal_tools.agipdutils import ( melt_snowy_pixels, ) from cal_tools.enums import AgipdGainMode, BadPixels, SnowResolution +from cal_tools.litpx_counter import LitPixelCounter from logging import warning @@ -1542,71 +1543,6 @@ class AgipdCorrections: counter.process(slice(first, last)) -class LitPixelCounter: - channel = "litpx" - output_fields = [ - "cellId", "pulseId", "trainId", "litPixels", "goodPixels"] - - def __init__(self, data, threshold=0.8): - self.data = data.copy() - for name in ["data", "mask", "cellId", "pulseId", "trainId"]: - assert name in data - - self.image = data["data"] - self.mask = data["mask"] - - self.threshold = threshold - self.max_images = data["data"].shape[0] - self.num_images = self.max_images - - self.num_good_px = sharedmem.full(self.max_images, 0, int) - self.num_lit_px = sharedmem.full(self.max_images, 0, int) - - self.data["litPixels"] = self.num_lit_px - self.data["goodPixels"] = self.num_good_px - - def set_num_images(self, num_images): - self.num_images = num_images - - def process(self, chunk): - ix = range(*chunk.indices(self.num_images)) - for i in ix: - mask = self.mask[i] == 0 - self.num_lit_px[i] = np.sum( - self.image[i] > self.threshold, initial=0, where=mask) - self.num_good_px[i] = np.sum(mask) - - def create_schema(self, source, file_trains=None, count=None): - if file_trains is None: - file_trains = source.file["INDEX/trainId"][:] - - if count is None: - tid = self.data["trainId"][:self.num_images] - trains, count = np.unique(tid, return_counts=True) - count = count[np.in1d(trains, file_trains)] - - if len(file_trains) != len(count): - raise ValueError( - "The length of data count does not match the number of trains") - if np.sum(count) != self.num_images: - raise ValueError( - "The sum of data count does not match " - "the total number of data entries") - - source.create_index(**{self.channel: count}) - for key in self.output_fields: - source.create_dataset( - f"{self.channel}/{key}", - shape=(self.num_images,), - dtype=self.data[key].dtype - ) - - def write(self, source): - channel = source[self.channel] - for key in self.output_fields: - channel[key][:] = self.data[key][:self.num_images] - - def validate_selected_pulses( max_pulses: List[int], max_cells: int ) -> List[int]: diff --git a/src/cal_tools/litpx_counter.py b/src/cal_tools/litpx_counter.py new file mode 100644 index 0000000000000000000000000000000000000000..660ddbf72eaea1a42dc083925d09e08e0188a650 --- /dev/null +++ b/src/cal_tools/litpx_counter.py @@ -0,0 +1,85 @@ +import numpy as np +import sharedmem + + +class AnalysisAddon: + channel = "data" + output_fields = [ + "cellId", "pulseId", "trainId"] + required_data = [ + "cellId", "pulseId", "trainId"] + + def __init__(self, data): + required_data = set(self.required_data) | {"pulseId"} + for name in required_data: + if name not in data: + raise ValueError(f"The field '{name}' is missed in 'data'") + + self.data = data.copy() + self.max_images = data["pulseId"].shape[0] + self.num_images = self.max_images + + def set_num_images(self, num_images): + self.num_images = num_images + + def process(self, chunk): + raise NotImplementedError + + def create_schema(self, source, file_trains=None, count=None): + if file_trains is None: + file_trains = source.file["INDEX/trainId"][:] + + if count is None: + tid = self.data["trainId"][:self.num_images] + trains, count = np.unique(tid, return_counts=True) + count = count[np.in1d(trains, file_trains)] + + if len(file_trains) != len(count): + raise ValueError( + "The length of data count does not match the number of trains") + if np.sum(count) != self.num_images: + raise ValueError( + "The sum of data count does not match " + "the total number of data entries") + + source.create_index(**{self.channel: count}) + for key in self.output_fields: + source.create_dataset( + f"{self.channel}/{key}", + shape=(self.num_images,), + dtype=self.data[key].dtype + ) + + def write(self, source): + channel = source[self.channel] + for key in self.output_fields: + channel[key][:] = self.data[key][:self.num_images] + + +class LitPixelCounter(AnalysisAddon): + channel = "litpx" + output_fields = [ + "cellId", "pulseId", "trainId", "litPixels", "goodPixels"] + required_data = [ + "data", "mask", "cellId", "pulseId", "trainId"] + + def __init__(self, data, threshold=0.8): + super().__init__(data) + + self.image = data["data"] + self.mask = data["mask"] + + self.threshold = threshold + self.num_good_px = sharedmem.full(self.max_images, 0, int) + self.num_lit_px = sharedmem.full(self.max_images, 0, int) + + self.data["litPixels"] = self.num_lit_px + self.data["goodPixels"] = self.num_good_px + + def process(self, chunk): + ix = range(*chunk.indices(self.num_images)) + for i in ix: + mask = self.mask[i] == 0 + self.num_lit_px[i] = np.sum( + self.image[i] > self.threshold, initial=0, where=mask) + self.num_good_px[i] = np.sum(mask)