From 34f1e4f0b91eaeead3d245c3e5d4c137c606a2f8 Mon Sep 17 00:00:00 2001
From: Danilo Ferreira de Lima <danilo.enoque.ferreira.de.lima@xfel.de>
Date: Mon, 19 Dec 2022 18:14:06 +0100
Subject: [PATCH] Add automatic peak finding and produce debug plots to
 test it.

---
 pes_to_spec/model.py     | 75 ++++++++++++++++++++++++++++++++++++++--
 scripts/test_analysis.py |  4 ++-
 2 files changed, 75 insertions(+), 4 deletions(-)

diff --git a/pes_to_spec/model.py b/pes_to_spec/model.py
index fc087c4..41014f6 100644
--- a/pes_to_spec/model.py
+++ b/pes_to_spec/model.py
@@ -4,10 +4,15 @@ from autograd import grad
 import joblib
 import h5py
 from scipy.signal import fftconvolve
+from scipy.signal import find_peaks_cwt
 from scipy.optimize import fmin_l_bfgs_b
 from sklearn.decomposition import PCA, IncrementalPCA
 from sklearn.model_selection import train_test_split
 
+import logging
+
+import matplotlib.pyplot as plt
+
 from typing import Any, Dict, List, Optional
 
 def matching_ids(a: np.ndarray, b: np.ndarray, c: np.ndarray) -> np.ndarray:
@@ -15,6 +20,15 @@ def matching_ids(a: np.ndarray, b: np.ndarray, c: np.ndarray) -> np.ndarray:
     unique_ids = list(set(a).intersection(b).intersection(c))
     return np.array(unique_ids)
 
+class PromptNotFoundError(Exception):
+    """
+    Exception raised when no prompt peak is found in the low-resolution data.
+    """
+    def __init__(self, message: str = "No prompt peak has been detected."):
+        # Forward the message to Exception so str(), repr() and args all carry it;
+        # the optional argument keeps the exception picklable (rebuilt from args).
+        super().__init__(message)
+
 class Model(object):
     """
     Object representing a previous fit of the model to be used to predict high-resolution
@@ -48,7 +62,7 @@ class Model(object):
         self.n_pca_hr = n_pca_hr
 
         # PCA models
-        self.lr_pca = IncrementalPCA(n_pca_lr, whiten=True, batch_size=n_pca_lr)
+        self.lr_pca = IncrementalPCA(n_pca_lr, whiten=True)
         self.hr_pca = PCA(n_pca_hr, whiten=True)
 
         # PCA unc. in high resolution
@@ -81,8 +95,10 @@ class Model(object):
         Returns: Concatenated and pre-processed low-resolution data of shape (train_id, features).
         """
         items = [low_res_data[k] for k in self.channels]
-        if self.tof_start is not None and self.delta_tof is not None:
+        if self.delta_tof is not None:
             items = [item[:, self.tof_start:(self.tof_start + self.delta_tof)] for item in items]
+        else:
+            items = [item[:, self.tof_start:] for item in items]
         cat = np.concatenate(items, axis=1)
         return cat
 
@@ -104,6 +120,53 @@ class Model(object):
         high_res_gc = fftconvolve(high_res_data, gaussian, mode="same", axes=1)
         return high_res_gc
 
+    def estimate_prompt_peak(self, low_res_data: Dict[str, np.ndarray]) -> int:
+        """
+        Estimate the prompt peak index as the strongest peak of the summed channels.
+
+        Args:
+          low_res_data: Low resolution data with a dictionary containing the channel names.
+
+        Returns: The prompt peak index, as a built-in int.
+        Raises: PromptNotFoundError if the wavelet peak search finds no candidate.
+        """
+        # sum over channels, average over the first axis and flip the sign of the result
+        sum_low_res = - np.mean(sum(low_res_data.values()), axis=0)
+        peak_idx = find_peaks_cwt(sum_low_res, np.arange(10, 50, step=5))
+        if len(peak_idx) < 1:
+            raise PromptNotFoundError()
+        # keep only the candidate with the largest magnitude; int() matches the annotation
+        return int(max(peak_idx, key=lambda k: np.fabs(sum_low_res[k])))
+
+    def debug_peak_finding(self, low_res_data: Dict[str, np.ndarray], filename: str):
+        """
+        Plot the summed signal around the estimated prompt peak to check the peak finding.
+
+        Args:
+          low_res_data: Low resolution data with a dictionary containing the channel names.
+          filename: The file name where to save the plot.
+
+        """
+        sum_low_res = - np.mean(sum(list(low_res_data.values())), axis=0)
+        peak_idx = self.estimate_prompt_peak(low_res_data)
+        # clamp the window to the data so that the x and y arrays always match in length
+        first = max(peak_idx - 100, 0)
+        last = min(peak_idx + 300, len(sum_low_res))
+        fig, ax = plt.subplots(figsize=(8, 16))
+        ax.plot(np.arange(first, last), sum_low_res[first:last], c="b", label="Data")
+        ax.set(title="",
+               xlabel="Photon Spectrometer channel",
+               ylabel="Sum of all Photon Spectrometer channels")
+        # the x axis holds absolute sample indices, so mark the peak at its own index
+        ax.axvline(peak_idx,
+                   linewidth=3,
+                   ls="--",
+                   color='r',
+                   label="Peak position")
+        ax.legend()
+        fig.savefig(filename)
+        plt.close(fig)
+
     def fit(self, low_res_data: Dict[str, np.ndarray], high_res_data: np.ndarray, high_res_photon_energy: np.ndarray) -> np.ndarray:
         """
         Train the model.
@@ -118,12 +181,18 @@ class Model(object):
 
         self.high_res_photon_energy = high_res_photon_energy
 
+        print("Find peaks.")
+        # if the prompt peak has not been given, guess it
+        if self.tof_start is None:
+            self.tof_start = self.estimate_prompt_peak(low_res_data)
+            print("Prompt at", self.tof_start)
+
         print("Pre-processing low")
         low_res = self.preprocess_low_res(low_res_data)
         print("Pre-processing high")
         high_res = self.preprocess_high_res(high_res_data, high_res_photon_energy)
         # fit PCA
-        print("PCA low")
+        print("PCA low", low_res.shape)
         low_pca = self.lr_pca.fit_transform(low_res)
         print("PCA high")
         high_pca = self.hr_pca.fit_transform(high_res)
diff --git a/scripts/test_analysis.py b/scripts/test_analysis.py
index 93351c3..56b924c 100755
--- a/scripts/test_analysis.py
+++ b/scripts/test_analysis.py
@@ -111,10 +111,12 @@ def main():
                  n_pca_hr=20,
                  high_res_sigma=0.2,
                  tof_start=None,
-                 delta_tof=None,
+                 delta_tof=400,
                  validation_size=0.05)
 
     train_idx = np.isin(tids, train_tids)
+
+    model.debug_peak_finding(pes_raw, "test_peak_finding.png")
     print("Fitting")
     model.fit({k: v[train_idx, :] for k, v in pes_raw.items()}, spec_raw_int[train_idx, :], spec_raw_pe[train_idx, :])
     spec_smooth = model.preprocess_high_res(spec_raw_int, spec_raw_pe)
-- 
GitLab