Skip to content
Snippets Groups Projects
Commit 34f1e4f0 authored by Danilo Ferreira de Lima's avatar Danilo Ferreira de Lima
Browse files

Added automatic peak finding and producing debug plots to test it.

parent ff1a99d6
No related branches found
No related tags found
No related merge requests found
...@@ -4,10 +4,15 @@ from autograd import grad ...@@ -4,10 +4,15 @@ from autograd import grad
import joblib import joblib
import h5py import h5py
from scipy.signal import fftconvolve from scipy.signal import fftconvolve
from scipy.signal import find_peaks_cwt
from scipy.optimize import fmin_l_bfgs_b from scipy.optimize import fmin_l_bfgs_b
from sklearn.decomposition import PCA, IncrementalPCA from sklearn.decomposition import PCA, IncrementalPCA
from sklearn.model_selection import train_test_split from sklearn.model_selection import train_test_split
import logging
import matplotlib.pyplot as plt
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
def matching_ids(a: np.ndarray, b: np.ndarray, c: np.ndarray) -> np.ndarray: def matching_ids(a: np.ndarray, b: np.ndarray, c: np.ndarray) -> np.ndarray:
...@@ -15,6 +20,15 @@ def matching_ids(a: np.ndarray, b: np.ndarray, c: np.ndarray) -> np.ndarray: ...@@ -15,6 +20,15 @@ def matching_ids(a: np.ndarray, b: np.ndarray, c: np.ndarray) -> np.ndarray:
unique_ids = list(set(a).intersection(b).intersection(c)) unique_ids = list(set(a).intersection(b).intersection(c))
return np.array(unique_ids) return np.array(unique_ids)
class PromptNotFoundError(Exception):
"""
Exception representing the error condition generated by not finding the prompt peak.
"""
def __init__(self):
pass
def __str__(self) -> str:
return "No prompt peak has been detected."
class Model(object): class Model(object):
""" """
Object representing a previous fit of the model to be used to predict high-resolution Object representing a previous fit of the model to be used to predict high-resolution
...@@ -48,7 +62,7 @@ class Model(object): ...@@ -48,7 +62,7 @@ class Model(object):
self.n_pca_hr = n_pca_hr self.n_pca_hr = n_pca_hr
# PCA models # PCA models
self.lr_pca = IncrementalPCA(n_pca_lr, whiten=True, batch_size=n_pca_lr) self.lr_pca = IncrementalPCA(n_pca_lr, whiten=True)
self.hr_pca = PCA(n_pca_hr, whiten=True) self.hr_pca = PCA(n_pca_hr, whiten=True)
# PCA unc. in high resolution # PCA unc. in high resolution
...@@ -81,8 +95,10 @@ class Model(object): ...@@ -81,8 +95,10 @@ class Model(object):
Returns: Concatenated and pre-processed low-resolution data of shape (train_id, features). Returns: Concatenated and pre-processed low-resolution data of shape (train_id, features).
""" """
items = [low_res_data[k] for k in self.channels] items = [low_res_data[k] for k in self.channels]
if self.tof_start is not None and self.delta_tof is not None: if self.delta_tof is not None:
items = [item[:, self.tof_start:(self.tof_start + self.delta_tof)] for item in items] items = [item[:, self.tof_start:(self.tof_start + self.delta_tof)] for item in items]
else:
items = [item[:, self.tof_start:] for item in items]
cat = np.concatenate(items, axis=1) cat = np.concatenate(items, axis=1)
return cat return cat
...@@ -104,6 +120,53 @@ class Model(object): ...@@ -104,6 +120,53 @@ class Model(object):
high_res_gc = fftconvolve(high_res_data, gaussian, mode="same", axes=1) high_res_gc = fftconvolve(high_res_data, gaussian, mode="same", axes=1)
return high_res_gc return high_res_gc
def estimate_prompt_peak(self, low_res_data: Dict[str, np.ndarray]) -> int:
"""
Estimate the prompt peak index.
Args:
low_res_data: Low resolution data with a dictionary containing the channel names.
Returns: The prompt peak index.
"""
# reduce on channel and on train ID
sum_low_res = - np.mean(sum(list(low_res_data.values())), axis=0)
widths = np.arange(10, 50, step=5)
peak_idx = find_peaks_cwt(sum_low_res, widths)
if len(peak_idx) < 1:
raise PromptNotFoundError()
peak_idx = sorted(peak_idx, key=lambda k: np.fabs(sum_low_res[k]), reverse=True)
return peak_idx[0]
def debug_peak_finding(self, low_res_data: Dict[str, np.ndarray], filename: str):
"""
Produce image to understand if the peak finding step worked well.
Args:
low_res_data: Low resolution data with a dictionary containing the channel names.
filename: The file name where to save the plot.
"""
sum_low_res = - np.mean(sum(list(low_res_data.values())), axis=0)
peak_idx = self.estimate_prompt_peak(low_res_data)
fig = plt.figure(figsize=(8, 16))
ax = plt.gca()
ax.plot(np.arange(peak_idx-100, peak_idx+300),
sum_low_res[peak_idx-100:peak_idx+300],
c="b",
label="Data")
ax.set(title="",
xlabel="Photon Spectrometer channel",
ylabel="Sum of all Photon Spectrometer channels")
plt.axvline(100,
linewidth=3,
ls="--",
color='r',
label="Peak position")
ax.legend()
plt.savefig(filename)
plt.close(fig)
def fit(self, low_res_data: Dict[str, np.ndarray], high_res_data: np.ndarray, high_res_photon_energy: np.ndarray) -> np.ndarray: def fit(self, low_res_data: Dict[str, np.ndarray], high_res_data: np.ndarray, high_res_photon_energy: np.ndarray) -> np.ndarray:
""" """
Train the model. Train the model.
...@@ -118,12 +181,18 @@ class Model(object): ...@@ -118,12 +181,18 @@ class Model(object):
self.high_res_photon_energy = high_res_photon_energy self.high_res_photon_energy = high_res_photon_energy
print("Find peaks.")
# if the prompt peak has not been given, guess it
if self.tof_start is None:
self.tof_start = self.estimate_prompt_peak(low_res_data)
print("Prompt at", self.tof_start)
print("Pre-processing low") print("Pre-processing low")
low_res = self.preprocess_low_res(low_res_data) low_res = self.preprocess_low_res(low_res_data)
print("Pre-processing high") print("Pre-processing high")
high_res = self.preprocess_high_res(high_res_data, high_res_photon_energy) high_res = self.preprocess_high_res(high_res_data, high_res_photon_energy)
# fit PCA # fit PCA
print("PCA low") print("PCA low", low_res.shape)
low_pca = self.lr_pca.fit_transform(low_res) low_pca = self.lr_pca.fit_transform(low_res)
print("PCA high") print("PCA high")
high_pca = self.hr_pca.fit_transform(high_res) high_pca = self.hr_pca.fit_transform(high_res)
......
...@@ -111,10 +111,12 @@ def main(): ...@@ -111,10 +111,12 @@ def main():
n_pca_hr=20, n_pca_hr=20,
high_res_sigma=0.2, high_res_sigma=0.2,
tof_start=None, tof_start=None,
delta_tof=None, delta_tof=400,
validation_size=0.05) validation_size=0.05)
train_idx = np.isin(tids, train_tids) train_idx = np.isin(tids, train_tids)
model.debug_peak_finding(pes_raw, "test_peak_finding.png")
print("Fitting") print("Fitting")
model.fit({k: v[train_idx, :] for k, v in pes_raw.items()}, spec_raw_int[train_idx, :], spec_raw_pe[train_idx, :]) model.fit({k: v[train_idx, :] for k, v in pes_raw.items()}, spec_raw_int[train_idx, :], spec_raw_pe[train_idx, :])
spec_smooth = model.preprocess_high_res(spec_raw_int, spec_raw_pe) spec_smooth = model.preprocess_high_res(spec_raw_int, spec_raw_pe)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment