From 564bcc68b903e33aca90c2ea07d1528b4da4961d Mon Sep 17 00:00:00 2001
From: Danilo Ferreira de Lima <danilo.enoque.ferreira.de.lima@xfel.de>
Date: Mon, 9 Oct 2023 14:14:47 +0200
Subject: [PATCH] Updated pre and postprocessing logic to get number of
 components and smoothing automatically.

---
 pes_to_spec/model.py                 | 104 +++++++++++++++++++--------
 pes_to_spec/test/offline_analysis.py |  25 +++++--
 pes_to_spec/test/prepare_plots.py    |   4 +-
 3 files changed, 97 insertions(+), 36 deletions(-)

diff --git a/pes_to_spec/model.py b/pes_to_spec/model.py
index 99734a3..a58e954 100644
--- a/pes_to_spec/model.py
+++ b/pes_to_spec/model.py
@@ -89,7 +89,7 @@ class HighResolutionSmoother(TransformerMixin, BaseEstimator):
     Smoothens out the high resolution data.
 
     Args:
-      high_res_sigma: Energy resolution in eV.
+      high_res_sigma: Energy resolution in eV. If None, guess.
     """
     def __init__(self,
                  high_res_sigma: float=0.2
@@ -610,47 +610,43 @@ class Model(TransformerMixin, BaseEstimator):
 
     Args:
       channels: Selected channels to use as an input for the low resolution data.
-      n_pca_lr: Number of low-resolution data PCA components.
-      n_pca_hr: Number of high-resolution data PCA components.
-      high_res_sigma: Resolution of the high-resolution spectrometer in electron-Volts.
+      pca_threshold: Fraction of the cumulative explained variance used to choose the number of PCA components.
+      high_res_fwhm: Resolution of the high-resolution spectrometer in electron-Volts.
       tof_start: Start looking at this index from the low-resolution spectrometer data.
                  Set to None to perform no selection
       delta_tof: Number of components to take from the low-resolution spectrometer.
                  Set to None to perform no selection.
-      validation_size: Fraction (number between 0 and 1) of the data to take for
-                       validation and systematic uncertainty estimate.
       model_type: Which model to use. "bnn" for a BNN, "bnn_rvm" for a BNN with RVM, "ridge" for Ridge and "ard" for ARD.
-      n_peaks: Minimum numbr of peaks in the grating spectrometer.
+      n_peaks: Minimum number of peaks in the grating spectrometer.
       n_bnn_epochs: Number of BNN epochs for training.
 
     """
     def __init__(self,
                  channels:List[str]=[f"channel_{j}_{k}"
                                      for j, k in product(range(1, 5), ["A", "B", "C", "D"])],
-                 n_pca_lr: int=600,
-                 n_pca_hr: int=20,
-                 high_res_sigma: float=0.20,
+                 pca_threshold: float=0.90,
+                 high_res_fwhm: float=0,
                  tof_start: Optional[int]=None,
                  delta_tof: Optional[int]=300,
-                 validation_size: float=0.05,
                  model_type: Literal["bnn", "bnn_rvm", "ridge", "ard"]="ard",
                  n_peaks: int=0,
                  n_bnn_epochs: int=500,
                 ):
         if model_type in ["bnn", "bnn_rvm"] and not is_bnn_available:
             raise MethodNotAvailableException("The BNN model requires a PyTorch installation. Please do `pip install torch` or `conda install pytorch` to be able to use the BNN model.")
-        self.high_res_sigma = high_res_sigma
+        self.pca_threshold = pca_threshold
+        self.high_res_fwhm = high_res_fwhm
         # models
         self.x_select = SelectRelevantLowResolution(channels, tof_start, delta_tof, poly=False) #(model_type not in ["bnn", "bnn_rvm"]))
         x_model_steps = list()
         x_model_steps += [
-                          ('pca', PCA(n_pca_lr, whiten=True)),
+                          ('pca', PCA(None, whiten=True)),
                           ('unc', UncertaintyHolder()),
                           ]
         self.x_model = Pipeline(x_model_steps)
         self.y_model = Pipeline([
-                                ('smoothen', HighResolutionSmoother(high_res_sigma)),
-                                ('pca', PCA(n_pca_hr, whiten=True)),
+                                ('smoothen', HighResolutionSmoother(high_res_fwhm)),
+                                ('pca', PCA(None, whiten=True)),
                                 ('unc', UncertaintyHolder()),
                                 ])
         self.ood = {ch: UncorrelatedDeviation(sigma=5)
@@ -675,30 +671,34 @@ class Model(TransformerMixin, BaseEstimator):
         self.channel_pca = {ch: IncrementalPCA(n_pca_lr_per_channel, whiten=True)
                             for ch in channels}
 
-        # size of the test subset
-        self.validation_size = validation_size
-
         # minimum number of peaks
         self.n_peaks = n_peaks
 
         # other characteristics for inspection and validation to be set in self.fit(...)
+        self.auto_corr_virt = None
+        self.auto_corr_hr = None
+        self.fwhm_hr = None
+        self.fwhm_virt = None
         self.resolution = None
+
         self.wiener_filter = None
         self.wiener_filter_ft = None
         self.wiener_energy = None
         self.wiener_energy_ft = None
         self.transfer_function = None
         self.impulse_response = None
-        self.auto_corr = None
 
         self.extra_options = ["mu_xgm", "sigma_xgm",
                               "wiener_filter_ft", "wiener_filter",
                               "wiener_energy_ft", "wiener_energy",
-                              "resolution",
+                              "resolution", "fwhm_virt", "fwhm_hr",
                               "transfer_function", "impulse_response",
-                              "auto_corr",
+                              "auto_corr_virt", "auto_corr_hr",
                               "model_type",
-                              "n_obs"]
+                              "n_obs",
+                              "pca_threshold",
+                              "high_res_fwhm",
+                              ]
 
     def n_pars(self) -> float:
         """Get number of parameters."""
@@ -793,6 +793,7 @@ class Model(TransformerMixin, BaseEstimator):
             high_res_data: np.ndarray, high_res_photon_energy: np.ndarray,
             weights: Optional[np.ndarray]=None,
             pulse_energy: Optional[np.ndarray]=None,
+            ood: bool=True
             ) -> np.ndarray:
         """
         Train the model.
@@ -805,6 +806,9 @@ class Model(TransformerMixin, BaseEstimator):
           high_res_data: Reference high resolution data with a one-to-one match to the
                          low resolution data in the train_id dimension. Shape (train_id, ToF channel).
           high_res_photon_energy: Photon energy axis for the high-resolution data.
+          weights: If set, use them to weigh the data.
+          pulse_energy: XGM intensity.
+          ood: Whether to fit the out-of-sample detection used to test data compatibility later. If False, fit returns before that step, which is considerably faster.
 
         Returns: Smoothened high resolution spectrum.
         """
@@ -820,13 +824,31 @@ class Model(TransformerMixin, BaseEstimator):
         B, P, _ = low_res_select.shape
         low_res_select = low_res_select.reshape((B*P, -1))
 
-        n_components = min(self.x_model["pca"].n_components, low_res_select.shape[0])
-        print(f"Using {n_components} comp. for PES PCA (asked for {self.x_model['pca'].n_components}, out of {low_res_select.shape[1]}, in {low_res_select.shape[0]} samples).")
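+        # estimate number of PCA components from pca_threshold; NOTE(review): np.where yields a 0-based index, so this looks one short of the count -- confirm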
+        pca_test = PCA(None, whiten=True)
+        pca_test.fit(low_res_select)
+        n_components = np.where(np.cumsum(pca_test.explained_variance_ratio_) > self.pca_threshold)[0]
+        if len(n_components) > 0:
+            n_components = n_components[0]
+
+        print(f"Using {n_components} comp. for PES PCA.")
         self.x_model.set_params(pca__n_components=n_components)
         x_t = self.x_model.fit_transform(low_res_select)
         #print("PCA fraction of variance (LR): ", np.cumsum(self.x_model["pca"].explained_variance_ratio_))
+
         print("Fitting PCA on high-resolution data.")
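+        # estimate number of PCA components from pca_threshold; NOTE(review): np.where yields a 0-based index, so this looks one short of the count -- confirm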
+        pca_test = PCA(None, whiten=True)
+        pca_test.fit(high_res_data)
+        n_components_hr = np.where(np.cumsum(pca_test.explained_variance_ratio_) > self.pca_threshold)[0]
+        if len(n_components_hr) > 0:
+            n_components_hr = n_components_hr[0]
+
+        print(f"Using {n_components_hr} comp. for grating spec. PCA.")
+        self.y_model.set_params(pca__n_components=n_components_hr)
+
         y_t = self.y_model.fit_transform(high_res_data, smoothen__energy=high_res_photon_energy)
+
         #print("PCA fraction of variance (HR): ", np.cumsum(self.y_model["pca"].explained_variance_ratio_))
 
         print("Fitting outlier detection")
@@ -887,15 +909,37 @@ class Model(TransformerMixin, BaseEstimator):
         h = np.fft.fftshift(np.fft.ifft(H))
         self.impulse_response = h
 
+        # get grating spec. resolution
+        mean_y = np.mean(y, keepdims=True, axis=0)
+        self.auto_corr_hr = np.mean(np.fft.fftshift(np.fft.ifft(np.absolute(np.fft.fft(y - mean_y))**2), axes=(-1,)), axis=0)
+        self.auto_corr_hr = np.real(self.auto_corr_hr)
+        self.auto_corr_hr /= np.amax(self.auto_corr_hr)
+        try:
+            self.fwhm_hr = fwhm(e_axis, self.auto_corr_hr)
+        except:
+            self.fwhm_hr = -1.0
+
+        # get virtual spectrometer resolution
         mean_y_hat = np.mean(y_hat, keepdims=True, axis=0)
-        self.auto_corr = np.mean(np.fft.fftshift(np.fft.ifft(np.absolute(np.fft.fft(y_hat - mean_y_hat))**2), axes=(-1,)), axis=0)
-        self.auto_corr = np.real(self.auto_corr)
-        self.auto_corr /= np.amax(self.auto_corr)
+        self.auto_corr_virt = np.mean(np.fft.fftshift(np.fft.ifft(np.absolute(np.fft.fft(y_hat - mean_y_hat))**2), axes=(-1,)), axis=0)
+        self.auto_corr_virt = np.real(self.auto_corr_virt)
+        self.auto_corr_virt /= np.amax(self.auto_corr_virt)
         try:
-            self.resolution = fwhm(e_axis, self.auto_corr)
+            self.fwhm_virt = fwhm(e_axis, self.auto_corr_virt)
         except:
-            self.resolution = -1.0
-        print("Resolution:", self.resolution)
+            self.fwhm_virt = -1.0
+
+        self.resolution = -1.0
+        if self.fwhm_hr > 0 and self.fwhm_virt > self.fwhm_hr:
+            self.resolution = np.sqrt(self.fwhm_virt**2 - self.fwhm_hr**2)
+        if self.resolution < 0:
+            print("Warning: Resolution calculation failed. The model can still be used, but this is probably a red flag!")
+        else:
+            print("Resolution:", self.resolution)
+
+        # returning here skips the per-channel consistency check below, which speeds things up considerably when it is not needed
+        if not ood:
+            return high_res.reshape((B, P, -1))
 
         # for consistency check per channel
         selection_model = self.x_select
diff --git a/pes_to_spec/test/offline_analysis.py b/pes_to_spec/test/offline_analysis.py
index ed80b67..65b2beb 100755
--- a/pes_to_spec/test/offline_analysis.py
+++ b/pes_to_spec/test/offline_analysis.py
@@ -240,13 +240,30 @@ def main():
     t = list()
     t_names = list()
 
-    model = Model(channels=channels, model_type=args.model_type)
-
     train_idx = np.isin(tids, train_tids) & (xgm_flux[:,0] > args.xgm_cut)
+
     # we just need this for training and we need to avoid copying it, which blows up the memoray usage
     for k in pes_raw.keys():
         pes_raw[k] = pes_raw[k][train_idx]
 
+    # fit model with no smearing to estimate maximum resolution
+    print("Fitting model only to detect resolution")
+    start = time_ns()
+    model = Model(channels=channels, model_type=args.model_type)
+    model.uniformize(xgm_flux[train_idx])
+    model.fit(pes_raw,
+              spec_raw_int[train_idx],
+              spec_raw_pe[train_idx],
+              pulse_energy=xgm_flux[train_idx],
+              )
+    t += [time_ns() - start]
+    t_names += ["Resolution estimate"]
+    resolution = model.resolution
+    print(f"Resolution detected: {resolution} eV")
+
+    # do it again with smearing, but now with the knowledge of the resolution
+    model = Model(channels=channels, model_type=args.model_type, high_res_fwhm=resolution)
+
     model.debug_peak_finding(pes_raw, os.path.join(args.directory, "test_peak_finding.pdf"))
     if len(args.model) == 0:
         print("Fitting")
@@ -289,14 +306,14 @@ def main():
     pca = PCA(None, whiten=True)
     pca.fit(pes_raw_select)
     df = pd.DataFrame(dict(variance_ratio=pca.explained_variance_ratio_,
-                           n_comp=600*np.ones_like(pca.explained_variance_ratio_),
+                           n_comp=model.x_model.get_params()["pca__n_components"]*np.ones_like(pca.explained_variance_ratio_),
                            ))
     df.to_csv(os.path.join(args.directory, "pca_pes.csv"))
 
     pca_spec = PCA(None, whiten=True)
     pca_spec.fit(spec_raw_int[train_idx])
     df = pd.DataFrame(dict(variance_ratio=pca_spec.explained_variance_ratio_,
-                           n_comp=20*np.ones_like(pca_spec.explained_variance_ratio_),
+                           n_comp=model.y_model.get_params()["pca__n_components"]*np.ones_like(pca_spec.explained_variance_ratio_),
                            ))
     df.to_csv(os.path.join(args.directory, "pca_spec.csv"))
 
diff --git a/pes_to_spec/test/prepare_plots.py b/pes_to_spec/test/prepare_plots.py
index 288b66f..6846b2d 100755
--- a/pes_to_spec/test/prepare_plots.py
+++ b/pes_to_spec/test/prepare_plots.py
@@ -115,7 +115,7 @@ def plot_chi2_intensity(df: pd.DataFrame, filename: str):
     #            fill=True,
     #            ax=ax)
     sns.scatterplot(x=df.chi2_prepca/df.ndof.iloc[0], y=df.xgm_flux_t*1e-3,
-                    s=5,
+                    s=20,
                     #alpha=0.4,
                     c="tab:red",
                     #size=df.root_mean_squared_pca_unc,
@@ -360,7 +360,7 @@ def plot_pes(df: pd.DataFrame,
              label: Optional[Dict[str, str]]=None,
              refs: Optional[Dict[str, Dict[int, float]]]=None,
              counts_to_mv: Optional[float]=None,
-             tof: bool=False,
+             tof: bool=True,
              ):
     """
     Plot low-resolution spectrum.
-- 
GitLab