diff --git a/src/toolbox_scs/detectors/hrixs.py b/src/toolbox_scs/detectors/hrixs.py index b3a8cb4d1de3a0155f203e312bcc2b7f616017bb..761ee6e3d17b49400ab9def384679ee89328e63a 100644 --- a/src/toolbox_scs/detectors/hrixs.py +++ b/src/toolbox_scs/detectors/hrixs.py @@ -367,11 +367,21 @@ class hRIXS: return (px + mx) / 2, py - my def __add__(self, other): - ix, iy = self.spectrum(normalize=False) - jx, jy = other.spectrum(normalize=False) + images = split_images(self.images) + split_images(other.images) + return self.__class__(images=images, + norm=self.norm + other.norm) - i_n = self.norm or 0 - j_n = other.norm or 0 - norm = ((i_n + j_n) or 1) - return ix, (iy + jy) / norm +def split_images(images): + """ Split the images by the number of trains + and return the list of their views. """ + + # Check if the images has already been splitted + if isinstance(images, list): + return images + + # Check if the images is a numpy of of 2d image that follows + # the shape (num, y_dim, x_dim) + assert len(images.shape) == 3 + + return [np.squeeze(image) for image in np.vsplit(images, images.shape[0])]