From 34f1e4f0b91eaeead3d245c3e5d4c137c606a2f8 Mon Sep 17 00:00:00 2001 From: Danilo Ferreira de Lima <danilo.enoque.ferreira.de.lima@xfel.de> Date: Mon, 19 Dec 2022 18:14:06 +0100 Subject: [PATCH] Added automatic peak finding and producing debug plots to test it. --- pes_to_spec/model.py | 75 ++++++++++++++++++++++++++++++++++++++-- scripts/test_analysis.py | 4 ++- 2 files changed, 75 insertions(+), 4 deletions(-) diff --git a/pes_to_spec/model.py b/pes_to_spec/model.py index fc087c4..41014f6 100644 --- a/pes_to_spec/model.py +++ b/pes_to_spec/model.py @@ -4,10 +4,15 @@ from autograd import grad import joblib import h5py from scipy.signal import fftconvolve +from scipy.signal import find_peaks_cwt from scipy.optimize import fmin_l_bfgs_b from sklearn.decomposition import PCA, IncrementalPCA from sklearn.model_selection import train_test_split +import logging + +import matplotlib.pyplot as plt + from typing import Any, Dict, List, Optional def matching_ids(a: np.ndarray, b: np.ndarray, c: np.ndarray) -> np.ndarray: @@ -15,6 +20,15 @@ def matching_ids(a: np.ndarray, b: np.ndarray, c: np.ndarray) -> np.ndarray: unique_ids = list(set(a).intersection(b).intersection(c)) return np.array(unique_ids) +class PromptNotFoundError(Exception): + """ + Exception representing the error condition generated by not finding the prompt peak. + """ + def __init__(self): + pass + def __str__(self) -> str: + return "No prompt peak has been detected." + class Model(object): """ Object representing a previous fit of the model to be used to predict high-resolution @@ -48,7 +62,7 @@ class Model(object): self.n_pca_hr = n_pca_hr # PCA models - self.lr_pca = IncrementalPCA(n_pca_lr, whiten=True, batch_size=n_pca_lr) + self.lr_pca = IncrementalPCA(n_pca_lr, whiten=True) self.hr_pca = PCA(n_pca_hr, whiten=True) # PCA unc. in high resolution @@ -81,8 +95,10 @@ class Model(object): Returns: Concatenated and pre-processed low-resolution data of shape (train_id, features). """ items = [low_res_data[k] for k in self.channels] - if self.tof_start is not None and self.delta_tof is not None: + if self.delta_tof is not None: items = [item[:, self.tof_start:(self.tof_start + self.delta_tof)] for item in items] + else: + items = [item[:, self.tof_start:] for item in items] cat = np.concatenate(items, axis=1) return cat @@ -104,6 +120,53 @@ class Model(object): high_res_gc = fftconvolve(high_res_data, gaussian, mode="same", axes=1) return high_res_gc + def estimate_prompt_peak(self, low_res_data: Dict[str, np.ndarray]) -> int: + """ + Estimate the prompt peak index. + + Args: + low_res_data: Low resolution data with a dictionary containing the channel names. + + Returns: The prompt peak index. + """ + # reduce on channel and on train ID + sum_low_res = - np.mean(sum(list(low_res_data.values())), axis=0) + widths = np.arange(10, 50, step=5) + peak_idx = find_peaks_cwt(sum_low_res, widths) + if len(peak_idx) < 1: + raise PromptNotFoundError() + peak_idx = sorted(peak_idx, key=lambda k: np.fabs(sum_low_res[k]), reverse=True) + return peak_idx[0] + + def debug_peak_finding(self, low_res_data: Dict[str, np.ndarray], filename: str): + """ + Produce image to understand if the peak finding step worked well. + + Args: + low_res_data: Low resolution data with a dictionary containing the channel names. + filename: The file name where to save the plot. + + """ + sum_low_res = - np.mean(sum(list(low_res_data.values())), axis=0) + peak_idx = self.estimate_prompt_peak(low_res_data) + fig = plt.figure(figsize=(8, 16)) + ax = plt.gca() + ax.plot(np.arange(peak_idx-100, peak_idx+300), + sum_low_res[peak_idx-100:peak_idx+300], + c="b", + label="Data") + ax.set(title="", + xlabel="Photon Spectrometer channel", + ylabel="Sum of all Photon Spectrometer channels") + plt.axvline(100, + linewidth=3, + ls="--", + color='r', + label="Peak position") + ax.legend() + plt.savefig(filename) + plt.close(fig) + def fit(self, low_res_data: Dict[str, np.ndarray], high_res_data: np.ndarray, high_res_photon_energy: np.ndarray) -> np.ndarray: """ Train the model. @@ -118,12 +181,18 @@ class Model(object): self.high_res_photon_energy = high_res_photon_energy + print("Find peaks.") + # if the prompt peak has not been given, guess it + if self.tof_start is None: + self.tof_start = self.estimate_prompt_peak(low_res_data) + print("Prompt at", self.tof_start) + print("Pre-processing low") low_res = self.preprocess_low_res(low_res_data) print("Pre-processing high") high_res = self.preprocess_high_res(high_res_data, high_res_photon_energy) # fit PCA - print("PCA low") + print("PCA low", low_res.shape) low_pca = self.lr_pca.fit_transform(low_res) print("PCA high") high_pca = self.hr_pca.fit_transform(high_res) diff --git a/scripts/test_analysis.py b/scripts/test_analysis.py index 93351c3..56b924c 100755 --- a/scripts/test_analysis.py +++ b/scripts/test_analysis.py @@ -111,10 +111,12 @@ def main(): n_pca_hr=20, high_res_sigma=0.2, tof_start=None, - delta_tof=None, + delta_tof=400, validation_size=0.05) train_idx = np.isin(tids, train_tids) + + model.debug_peak_finding(pes_raw, "test_peak_finding.png") print("Fitting") model.fit({k: v[train_idx, :] for k, v in pes_raw.items()}, spec_raw_int[train_idx, :], spec_raw_pe[train_idx, :]) spec_smooth = model.preprocess_high_res(spec_raw_int, spec_raw_pe) -- GitLab