From d73df29b36512bea61cb435218756ec49cd8adc9 Mon Sep 17 00:00:00 2001
From: Danilo Ferreira de Lima <danilo.enoque.ferreira.de.lima@xfel.de>
Date: Wed, 8 Mar 2023 13:29:38 +0100
Subject: [PATCH] Handle multi-pulse data.

---
 pes_to_spec/model.py | 19 +++++++++++--------
 1 file changed, 11 insertions(+), 8 deletions(-)

diff --git a/pes_to_spec/model.py b/pes_to_spec/model.py
index 695dc68..ac27237 100644
--- a/pes_to_spec/model.py
+++ b/pes_to_spec/model.py
@@ -478,7 +478,7 @@ class SelectRelevantLowResolution(TransformerMixin, BaseEstimator):
         self.mean = dict()
         self.std = dict()
 
-    def transform(self, X: Dict[str, np.ndarray], keep_dictionary_structure: bool=False) -> np.ndarray:
+    def transform(self, X: Dict[str, np.ndarray], keep_dictionary_structure: bool=False, pulse_spacing: List[int]=[0]) -> np.ndarray:
         """
         Get a dictionary with the channel names for the inut low resolution data and output
         only the relevant input data in an array.
@@ -487,8 +487,9 @@ class SelectRelevantLowResolution(TransformerMixin, BaseEstimator):
           X: Dictionary with keys named channel_{i}_{k},
              where i is a number between 1 and 4 and k is a letter between A and D.
           keep_dictionary_structure: Whether to concatenate all channels, or keep them as a dictionary.
+          pulse_spacing: Distances between pulses in multi-pulse data. If there is only one pulse, set it to a list containing only the element zero.
 
-        Returns: Concatenated and pre-processed low-resolution data of shape (train_id, features).
+        Returns: Concatenated and pre-processed low-resolution data of shape (train_id, pulse_id, features).
         """
         if self.tof_start is None:
             raise NotImplementedError("The low-resolution data cannot be transformed before the prompt has been identified. Call the fit function first.")
@@ -496,12 +497,13 @@ class SelectRelevantLowResolution(TransformerMixin, BaseEstimator):
         if self.delta_tof is not None:
             first = max(0, self.tof_start - self.delta_tof)
             last = min(X[self.channels[0]].shape[1], self.tof_start + self.delta_tof)
-            y = {channel: item[:, first:last] for channel, item in X.items()}
+            y = {channel: np.stack([item[:, (first + delta):(last + delta)] for delta in pulse_spacing], axis=1)
+                 for channel, item in X.items()}
         if not keep_dictionary_structure:
             selected = list(y.values())
             if self.poly:
                 selected += [np.sqrt(np.fabs(v)) for v in y.values()]
-            return np.concatenate(selected, axis=1)
+            return np.concatenate(selected, axis=-1)
         return y
 
     def estimate_prompt_peak(self, X: Dict[str, np.ndarray]) -> int:
@@ -992,7 +994,7 @@ class Model(TransformerMixin, BaseEstimator):
         #fig.savefig("tmp.png")
         #plt.close(fig)
 
-        Hmod = np.real(np.absolute(H))
+        #Hmod = np.real(np.absolute(H))
         Gdir = np.fft.fftshift(np.fft.ifft(G))
         self.wiener_filter = Gdir
         self.wiener_filter_ft = G
@@ -1005,7 +1007,7 @@ class Model(TransformerMixin, BaseEstimator):
         energy_mu = np.sum(e_axis*hmod)/np.sum(hmod)
         energy_var = np.sum(((e_axis - energy_mu)**2)*hmod)/np.sum(hmod)
         self.resolution = np.sqrt(energy_var)
-        print("Resolution:", self.resolution)
+        #print("Resolution:", self.resolution)
 
         # get intensity effect
         intensity = np.sum(z, axis=1)
@@ -1069,13 +1071,14 @@ class Model(TransformerMixin, BaseEstimator):
         """Get KDE for the predicted intensity."""
         return self.kde_intensity
 
-    def predict(self, low_res_data: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
+    def predict(self, low_res_data: Dict[str, np.ndarray], pulse_spacing: List[int]=[0]) -> Dict[str, np.ndarray]:
         """
         Predict a high-resolution spectrum from a low resolution given one.
         The output includes the uncertainty in its second and third entries of the first dimension.
 
         Args:
           low_res_data: Low resolution data as in the fit step with shape (train_id, channel, ToF channel).
+          pulse_spacing: Where each pulse starts, relative to the first pulse (which is at 0). For single pulse data, set this to [0].
 
         Returns: High resolution data with shape (train_id, energy channel) in a dictionary containing
                  the expected prediction in key "expected", the stat. uncertainty in key "unc" and
@@ -1086,7 +1089,7 @@ class Model(TransformerMixin, BaseEstimator):
         #t += [time_ns()*1e-9]
         #n += ["Initial"]
 
-        low_res_pre = self.x_select.transform(low_res_data)
+        low_res_pre = self.x_select.transform(low_res_data, pulse_spacing=pulse_spacing)
         #t += [time_ns()*1e-9]
         #n += ["Select"]
 
-- 
GitLab