diff --git a/pes_to_spec/__init__.py b/pes_to_spec/__init__.py
index 5f570b94cfed4d61f71bb2662183295df81f4978..1f1123e7fdf98c4a55066351cbb13b8a9b6162ff 100644
--- a/pes_to_spec/__init__.py
+++ b/pes_to_spec/__init__.py
@@ -2,4 +2,4 @@
 Estimate high-resolution photon spectrometer data from low-resolution non-invasive measurements.
 """
 
-VERSION = "0.2.2"
+VERSION = "0.2.3"
diff --git a/pes_to_spec/model.py b/pes_to_spec/model.py
index d04881b96351fa87407b546f54686c2eb0363644..2e1d360af294006dac8bce4253efad73ce74460c 100644
--- a/pes_to_spec/model.py
+++ b/pes_to_spec/model.py
@@ -478,7 +478,7 @@ class SelectRelevantLowResolution(TransformerMixin, BaseEstimator):
         self.mean = dict()
         self.std = dict()
 
-    def transform(self, X: Dict[str, np.ndarray], keep_dictionary_structure: bool=False) -> np.ndarray:
+    def transform(self, X: Dict[str, np.ndarray], keep_dictionary_structure: bool=False, pulse_spacing: Optional[Dict[str, List[int]]]=None) -> np.ndarray:
         """
         Get a dictionary with the channel names for the inut low resolution data and output
         only the relevant input data in an array.
@@ -487,21 +487,25 @@ class SelectRelevantLowResolution(TransformerMixin, BaseEstimator):
           X: Dictionary with keys named channel_{i}_{k},
              where i is a number between 1 and 4 and k is a letter between A and D.
           keep_dictionary_structure: Whether to concatenate all channels, or keep them as a dictionary.
+          pulse_spacing: Distances between pulses in multi-pulse data. If there is only one pulse, set it to a list containing only the element zero.
 
-        Returns: Concatenated and pre-processed low-resolution data of shape (train_id, features).
+        Returns: Concatenated and pre-processed low-resolution data of shape (train_id, pulse_id, features).
         """
         if self.tof_start is None:
             raise NotImplementedError("The low-resolution data cannot be transformed before the prompt has been identified. Call the fit function first.")
+        if pulse_spacing is None:
+            pulse_spacing = {ch: [0] for ch in X.keys()}
         y = X
         if self.delta_tof is not None:
             first = max(0, self.tof_start - self.delta_tof)
             last = min(X[self.channels[0]].shape[1], self.tof_start + self.delta_tof)
-            y = {channel: item[:, first:last] for channel, item in X.items()}
+            y = {channel: np.stack([item[:, (first + delta):(last + delta)] for delta in pulse_spacing[channel]], axis=1)
+                 for channel, item in X.items()}
         if not keep_dictionary_structure:
             selected = list(y.values())
             if self.poly:
                 selected += [np.sqrt(np.fabs(v)) for v in y.values()]
-            return np.concatenate(selected, axis=1)
+            return np.concatenate(selected, axis=-1)
         return y
 
     def estimate_prompt_peak(self, X: Dict[str, np.ndarray]) -> int:
@@ -645,18 +649,6 @@ class MultiOutputWithStd(MetaEstimatorMixin, BaseEstimator):
         y_std = np.sqrt(sigmas_squared_data + self.fast_inv_alpha)
         return y, y_std
 
-        #n_jobs = self.n_jobs
-        #y = Parallel(n_jobs=n_jobs, prefer="threads")(
-        #    delayed(e.predict)(X, return_std) for e in self.estimators_
-        #    #delayed(e.predict)(X) for e in self.estimators_
-        #)
-        #if return_std:
-        #    y, unc = zip(*y)
-        #    return np.asarray(y).T, np.asarray(unc).T
-
-        #return np.asarray(y).T
-
-
 class UncorrelatedDeviation(OutlierMixin, BaseEstimator):
     """
     Detect outliers from uncorrelated inputs.
@@ -900,6 +892,10 @@ class Model(TransformerMixin, BaseEstimator):
             weights = np.ones(high_res_data.shape[0])
         print("Fitting PCA on low-resolution data.")
         low_res_select = self.x_select.fit_transform(low_res_data)
+        # keep the number of pulses
+        B, P, _ = low_res_select.shape
+        low_res_select = low_res_select.reshape((B*P, -1))
+
         n_components = min(self.x_model["pca"].n_components, low_res_select.shape[0])
         self.x_model.set_params(pca__n_components=n_components)
         x_t = self.x_model.fit_transform(low_res_select)
@@ -1004,7 +1000,7 @@ class Model(TransformerMixin, BaseEstimator):
         #fig.savefig("tmp.png")
         #plt.close(fig)
 
-        Hmod = np.real(np.absolute(H))
+        #Hmod = np.real(np.absolute(H))
         Gdir = np.fft.fftshift(np.fft.ifft(G))
         self.wiener_filter = Gdir
         self.wiener_filter_ft = G
@@ -1017,7 +1013,7 @@ class Model(TransformerMixin, BaseEstimator):
         energy_mu = np.sum(e_axis*hmod)/np.sum(hmod)
         energy_var = np.sum(((e_axis - energy_mu)**2)*hmod)/np.sum(hmod)
         self.resolution = np.sqrt(energy_var)
-        print("Resolution:", self.resolution)
+        #print("Resolution:", self.resolution)
 
         # get intensity effect
         intensity = np.sum(z, axis=1)
@@ -1029,53 +1025,54 @@ class Model(TransformerMixin, BaseEstimator):
         selection_model = self.x_select
         low_res_selected = selection_model.transform(low_res_data, keep_dictionary_structure=True)
         for channel in self.get_channels():
+            B, P, _ = low_res_selected[channel].shape
             print(f"Calculate PCA on {channel}")
-            low_pca = self.channel_pca[channel].fit_transform(low_res_selected[channel])
+            low_pca = self.channel_pca[channel].fit_transform(low_res_selected[channel].reshape(B*P, -1))
             self.ood[channel].fit(low_pca)
 
         print("End of fit.")
 
-        return high_res
+        return high_res.reshape((B, P, -1))
 
-    def check_compatibility_per_channel(self, low_res_data: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
+    def check_compatibility_per_channel(self, low_res_data: Dict[str, np.ndarray], pulse_spacing: Optional[Dict[str, List[int]]]=None) -> Dict[str, np.ndarray]:
         """
         Check if a new low-resolution data source is compatible with the one used in training, by
         comparing the effect of the trained PCA model on it, but do it per channel.
 
         Args:
           low_res_data: Low resolution data as in the fit step with shape (train_id, channel, ToF channel).
+          pulse_spacing: The pulse spacing in multi-pulse data.
 
         Returns: Outlier score. If it is bigger than 0, this is an outlier.
         """
         selection_model = self.x_select
         channels = list(low_res_data.keys())
         # check if each channel is close to the mean
-        low_res_selected = selection_model.transform(low_res_data, keep_dictionary_structure=True)
+        low_res_selected = selection_model.transform(low_res_data, keep_dictionary_structure=True, pulse_spacing=pulse_spacing)
 
         def is_inlier(in_data, ch: str) -> np.ndarray:
-            data_pca = self.channel_pca[ch].transform(in_data)
-            return self.ood[ch].predict(data_pca)
-
-        #result = Parallel(n_jobs=-1)(
-        #    delayed(is_inlier)(low_res_selected[ch], ch) for ch in channels
-        #)
-        #result = dict(result)
-        return {ch: is_inlier(low_res_selected[ch], ch) for ch in channels}
+            B, P, _ = in_data.shape
+            data_pca = self.channel_pca[ch].transform(in_data.reshape((B*P, -1)))
+            return self.ood[ch].predict(data_pca).reshape((B, P))
+        result = {ch: is_inlier(low_res_selected[ch], ch) for ch in channels}
+        return result
 
-    def check_compatibility(self, low_res_data: Dict[str, np.ndarray]) -> np.ndarray:
+    def check_compatibility(self, low_res_data: Dict[str, np.ndarray], pulse_spacing: Optional[Dict[str, List[int]]]=None) -> np.ndarray:
         """
         Check if a new low-resolution data source is compatible with the one used in training, by
         using a robust covariance matrix estimate of the data
 
         Args:
           low_res_data: Low resolution data as in the fit step with shape (train_id, channel, ToF channel).
+          pulse_spacing: The pulse spacing in multi-pulse data.
 
         Returns: An outlier score: if it is greater than 0, this is an outlier.
         """
-        low_res = self.x_select.transform(low_res_data)
+        low_res = self.x_select.transform(low_res_data, pulse_spacing=pulse_spacing)
+        B, P, _ = low_res.shape
         pca_model = self.x_model
-        low_pca = pca_model.transform(low_res)
-        return self.ood['full'].predict(low_pca)
+        low_pca = pca_model.transform(low_res.reshape((B*P, -1)))
+        return self.ood['full'].predict(low_pca).reshape((B, P))
 
     def xgm_profile(self) -> gaussian_kde:
         """Get KDE for the XGM intensity."""
@@ -1085,13 +1082,14 @@ class Model(TransformerMixin, BaseEstimator):
         """Get KDE for the predicted intensity."""
         return self.kde_intensity
 
-    def predict(self, low_res_data: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
+    def predict(self, low_res_data: Dict[str, np.ndarray], pulse_spacing: Optional[Dict[str, List[int]]]=None) -> Dict[str, np.ndarray]:
         """
         Predict a high-resolution spectrum from a low resolution given one.
         The output includes the uncertainty in its second and third entries of the first dimension.
 
         Args:
           low_res_data: Low resolution data as in the fit step with shape (train_id, channel, ToF channel).
+          pulse_spacing: Where each pulse starts, relative to the first pulse (which is at 0). For single pulse data, set this to [0].
 
         Returns: High resolution data with shape (train_id, energy channel) in a dictionary containing
                  the expected prediction in key "expected", the stat. uncertainty in key "unc" and
@@ -1102,7 +1100,9 @@ class Model(TransformerMixin, BaseEstimator):
         #t += [time_ns()*1e-9]
         #n += ["Initial"]
 
-        low_res_pre = self.x_select.transform(low_res_data)
+        low_res_pre = self.x_select.transform(low_res_data, pulse_spacing=pulse_spacing)
+        B, P, _ = low_res_pre.shape
+        low_res_pre = low_res_pre.reshape((B*P, -1))
         #t += [time_ns()*1e-9]
         #n += ["Select"]
 
@@ -1138,10 +1138,10 @@ class Model(TransformerMixin, BaseEstimator):
         #print("Times")
         #print(dict(zip(n, t)))
 
-        return dict(expected=expected,
-                    unc=unc,
+        return dict(expected=expected.reshape((B, P, -1)),
+                    unc=unc.reshape((B, P, -1)),
                     pca=pca_unc,
-                    total_unc=total_unc,
+                    total_unc=total_unc.reshape((B, P, -1)),
                     )
 
     def deconvolve(self, expected: np.ndarray) -> np.ndarray:
@@ -1153,7 +1153,12 @@ class Model(TransformerMixin, BaseEstimator):
 
         Returns: The Wiener filter-corrected spectrum.
         """
-        return np.real(np.absolute(np.fft.ifft(np.fft.fft(expected, axis=1) * self.wiener_filter_ft.reshape(1, -1))))
+        W = self.wiener_filter_ft
+        if len(expected.shape) == 2:
+            W = W.reshape(1, -1)
+        elif len(expected.shape) == 3:
+            W = W.reshape(1, 1, -1)
+        return np.real(np.absolute(np.fft.ifft(np.fft.fft(expected, axis=-1) * W, axis=-1)))
 
     def save(self, filename: str):
         """
diff --git a/pes_to_spec/test/offline_analysis.py b/pes_to_spec/test/offline_analysis.py
index 0b3f2fc50a6fe68c2b4f0b6dec81acdb9b0e8d65..34ba12e303534f51bc1e161e24c3472b2a9e398d 100755
--- a/pes_to_spec/test/offline_analysis.py
+++ b/pes_to_spec/test/offline_analysis.py
@@ -300,7 +300,7 @@ def main():
         # chi2 w.r.t XGM intensity
         erange = spec_raw_pe[0,-1] - spec_raw_pe[0,0]
         de = (spec_raw_pe[0,1] - spec_raw_pe[0,0])
-        chi2 = np.sum((spec_smooth - spec_pred["expected"])**2/(spec_pred["total_unc"]**2), axis=1)
+        chi2 = np.sum((spec_smooth[:, np.newaxis, :] - spec_pred["expected"])**2/(spec_pred["total_unc"]**2), axis=(-1, -2))
         ndof = float(spec_smooth.shape[1]) - 1.0
         fig = plt.figure(figsize=(12, 8))
         gs = GridSpec(1, 1)
@@ -369,7 +369,7 @@ def main():
         plt.close(fig)
 
         # rmse
-        rmse = np.sqrt(np.mean((spec_smooth - spec_pred["expected"])**2, axis=1))
+        rmse = np.sqrt(np.mean((spec_smooth[:, np.newaxis, :] - spec_pred["expected"])**2, axis=(-1, -2)))
         fig = plt.figure(figsize=(12, 8))
         gs = GridSpec(1, 1)
         ax = fig.add_subplot(gs[0, 0])
@@ -418,7 +418,7 @@ def main():
         fig = plt.figure(figsize=(12, 8))
         gs = GridSpec(1, 1)
         ax = fig.add_subplot(gs[0, 0])
-        sns.regplot(x=np.sum(spec_raw_int, axis=1)*de, y=np.sum(spec_pred["expected"], axis=1)*de, color='r', robust=True, ax=ax)
+        sns.regplot(x=np.sum(spec_raw_int, axis=-1)*de, y=np.sum(spec_pred["expected"], axis=(-1, -2))*de, color='r', robust=True, ax=ax)
         ax.set(title=f"",
                xlabel="SPEC (raw) integral",
                ylabel="Predicted integral",
@@ -429,7 +429,7 @@ def main():
         fig = plt.figure(figsize=(12, 8))
         gs = GridSpec(1, 1)
         ax = fig.add_subplot(gs[0, 0])
-        sns.regplot(x=np.sum(spec_pred["expected"], axis=1)*de, y=xgm_flux[:,0], color='r', robust=True, ax=ax)
+        sns.regplot(x=np.sum(spec_pred["expected"], axis=(-1, -2))*de, y=xgm_flux[:,0], color='r', robust=True, ax=ax)
         ax.set(title=f"",
                xlabel="Predicted integral",
                ylabel="XGM intensity [uJ]",
@@ -445,7 +445,7 @@ def main():
     for tid in test_tids:
         idx = np.where(tid==tids)[0][0]
         plot_result(os.path.join(args.directory, f"test_{tid}.png"),
-                   {k: item[idx, ...] if k != "pca"
+                   {k: item[idx, 0, ...] if k != "pca"
                        else item[0, ...]
                        for k, item in spec_pred.items()},
                     spec_smooth[idx, :] if showSpec else None,