Skip to content
Snippets Groups Projects
Commit 99b33851 authored by Danilo Ferreira de Lima
Browse files

First draft.

parent d55c33a2
No related branches found
No related tags found
No related merge requests found
...@@ -8,7 +8,7 @@ from scipy.optimize import fmin_l_bfgs_b ...@@ -8,7 +8,7 @@ from scipy.optimize import fmin_l_bfgs_b
from sklearn.decomposition import PCA from sklearn.decomposition import PCA
from sklearn.model_selection import train_test_split from sklearn.model_selection import train_test_split
from typing import Dict, List, Optional from typing import Any, Dict, List, Optional
def matching_ids(a: np.ndarray, b: np.ndarray, c: np.ndarray) -> np.ndarray: def matching_ids(a: np.ndarray, b: np.ndarray, c: np.ndarray) -> np.ndarray:
"""Returns list of train IDs common to sets a, b and c.""" """Returns list of train IDs common to sets a, b and c."""
...@@ -94,7 +94,7 @@ class Model(object): ...@@ -94,7 +94,7 @@ class Model(object):
high_res_gc = fftconvolve(high_res_data, gaussian, mode="same", axes=1)/80.0 high_res_gc = fftconvolve(high_res_data, gaussian, mode="same", axes=1)/80.0
return high_res_gc return high_res_gc
def fit(self, low_res_data: Dict[str, np.ndarray], high_res_data: np.ndarray, high_res_photon_energy: np.ndarray): def fit(self, low_res_data: Dict[str, np.ndarray], high_res_data: np.ndarray, high_res_photon_energy: np.ndarray) -> np.ndarray:
""" """
Train the model. Train the model.
...@@ -102,6 +102,8 @@ class Model(object): ...@@ -102,6 +102,8 @@ class Model(object):
low_res_data: Low resolution data as a dictionary with the key set to `channel_{i}_{k}`, where i is a number between 1 and 4 and k is a letter between A and D. For each dictionary entry, a numpy array is expected with shape (train_id, ToF channel). low_res_data: Low resolution data as a dictionary with the key set to `channel_{i}_{k}`, where i is a number between 1 and 4 and k is a letter between A and D. For each dictionary entry, a numpy array is expected with shape (train_id, ToF channel).
high_res_data: Reference high resolution data with a one-to-one match to the low resolution data in the train_id dimension. Shape (train_id, ToF channel). high_res_data: Reference high resolution data with a one-to-one match to the low resolution data in the train_id dimension. Shape (train_id, ToF channel).
high_res_photon_energy: Photon energy axis for the high-resolution data. high_res_photon_energy: Photon energy axis for the high-resolution data.
Returns: Smoothened high resolution spectrum.
""" """
self.high_res_photon_energy = high_res_photon_energy self.high_res_photon_energy = high_res_photon_energy
...@@ -117,7 +119,9 @@ class Model(object): ...@@ -117,7 +119,9 @@ class Model(object):
self.fit_model.fit(low_pca_train, high_pca_train, low_pca_test, high_pca_test) self.fit_model.fit(low_pca_train, high_pca_train, low_pca_test, high_pca_test)
high_pca_rec = self.hr_pca.inverse_transform(high_pca) high_pca_rec = self.hr_pca.inverse_transform(high_pca)
self.high_pca_unc = np.sqrt(np.mean((high_res - high_pca_rec)**2, axis=0)) self.high_pca_unc = np.sqrt(np.mean((high_res - high_pca_rec)**2, axis=0, keepdims=True))
return high_res
def predict(self, low_res_data: Dict[str, np.ndarray]) -> np.ndarray: def predict(self, low_res_data: Dict[str, np.ndarray]) -> np.ndarray:
""" """
...@@ -127,15 +131,17 @@ class Model(object): ...@@ -127,15 +131,17 @@ class Model(object):
Args: Args:
low_res_data: Low resolution data as in the fit step with shape (train_id, channel, ToF channel). low_res_data: Low resolution data as in the fit step with shape (train_id, channel, ToF channel).
Returns: High resolution data with shape (3, train_id, ToF channel). The component 0 of the first dimension is the predicted spectrum. Components 1 and 2 correspond to two sources of uncertainty. Returns: High resolution data with shape (train_id, ToF channel, 3). The component 0 of the last dimension is the predicted spectrum. Components 1 and 2 correspond to two sources of uncertainty.
""" """
low_res = self.preprocess_low_res(low_res_data) low_res = self.preprocess_low_res(low_res_data)
low_pca = self.lr_pca.transform(low_res) low_pca = self.lr_pca.transform(low_res)
n_trains = low_res.shape[0]
# Get high res. # Get high res.
high_pca = self.fit_model.predict(low_pca, None, None) high_pca = self.fit_model.predict(low_pca)
high_res_predicted = self.hr_pca.inverse_transform(high_pca["Y"]) high_res_predicted = self.hr_pca.inverse_transform(high_pca["Y"])
high_res_unc = self.hr_pca.inverse_transform(high_pca["Y"] + high_pca["Y_eps"]) - high_pca_predicted n_high_res_features = high_res_predicted.shape[1]
result = np.stack((high_res_predicted, high_res_unc, self.high_pca_unc), axis=0) high_res_unc = self.hr_pca.inverse_transform(high_pca["Y"] + high_pca["Y_eps"]) - high_res_predicted
result = np.stack((high_res_predicted, high_res_unc, np.broadcast_to(self.high_pca_unc, (n_trains, n_high_res_features))), axis=2)
return result return result
def save(self, filename: str, lr_pca_filename: str, hr_pca_filename: str): def save(self, filename: str, lr_pca_filename: str, hr_pca_filename: str):
......
#!/usr/bin/env python #!/usr/bin/env python
import sys
sys.path.append('..')
import numpy as np
from extra_data import RunDirectory, by_id from extra_data import RunDirectory, by_id
from pes_to_spec.model import Model, matching_ids from pes_to_spec.model import Model, matching_ids
from itertools import product
import matplotlib import matplotlib
matplotlib.use('Agg') matplotlib.use('Agg')
...@@ -23,10 +29,15 @@ def plot_result(filename: str, spec_pred: np.ndarray, spec_raw_int: np.ndarray, ...@@ -23,10 +29,15 @@ def plot_result(filename: str, spec_pred: np.ndarray, spec_raw_int: np.ndarray,
fig = plt.figure(figsize=(10, 10)) fig = plt.figure(figsize=(10, 10))
gs = GridSpec(1, 1) gs = GridSpec(1, 1)
ax = fig.add_subplot(gs[0, 0]) ax = fig.add_subplot(gs[0, 0])
ax.plot(spec_raw_pe, spec_raw_int, c='b', lw=3, label="High resolution measurement") eps = np.mean(spec_pred[:, 1])
ax.plot(spec_raw_pe, spec_pred[0,:], c='r', lw=3, label="High resolution prediction") ax.plot(spec_raw_pe, spec_raw_int, c='b', lw=3, label="High resolution measurement (smoothened)")
ax.fill_between(spec_raw_pe, spec_pred[0,:] - spec_pred[1,:], spec_pred[0,:] + spec_pred[1,:], fillcolor='red', alpha=0.6, label="68% unc. (stat.)") ax.plot(spec_raw_pe, spec_pred[:, 0], c='r', lw=3, label="High resolution prediction")
ax.fill_between(spec_raw_pe, spec_pred[0,:] - spec_pred[2,:], spec_pred[0,:] + spec_pred[2,:], fillcolor='magenta', alpha=0.6, label="68% unc. (syst., PCA)") ax.fill_between(spec_raw_pe, spec_pred[:, 0] - spec_pred[:, 1], spec_pred[:, 0] + spec_pred[:, 1], facecolor='red', alpha=0.6, label="68% unc. (stat.)")
ax.fill_between(spec_raw_pe, spec_pred[:, 0] - spec_pred[:, 2], spec_pred[:, 0] + spec_pred[:, 2], facecolor='magenta', alpha=0.6, label="68% unc. (syst., PCA)")
ax.legend()
ax.set(title=f"avg(unc) = {eps}",
xlabel="Photon energy [eV]",
ylabel="Intensity")
fig.savefig(filename) fig.savefig(filename)
plt.close(fig) plt.close(fig)
...@@ -56,7 +67,7 @@ def main(): ...@@ -56,7 +67,7 @@ def main():
spec_raw_int = run['SA3_XTD10_SPECT/MDL/FEL_BEAM_SPECTROMETER_SQS1:output', f"data.intensityDistribution"].select_trains(by_id[tids - spec_offset]).ndarray() spec_raw_int = run['SA3_XTD10_SPECT/MDL/FEL_BEAM_SPECTROMETER_SQS1:output', f"data.intensityDistribution"].select_trains(by_id[tids - spec_offset]).ndarray()
# read the PES data for each channel # read the PES data for each channel
channels = [f"channel_{i}_{l}" for i, l in zip(range(1, 5), ["A", "B", "C", "D"])] channels = [f"channel_{i}_{l}" for i, l in product(range(1, 5), ["A", "B", "C", "D"])]
pes_raw = {ch: run['SA3_XTD10_PES/ADC/1:network', f"digitizers.{ch}.raw.samples"].select_trains(by_id[tids]).ndarray() for ch in channels} pes_raw = {ch: run['SA3_XTD10_PES/ADC/1:network', f"digitizers.{ch}.raw.samples"].select_trains(by_id[tids]).ndarray() for ch in channels}
# read the XGM information # read the XGM information
...@@ -66,14 +77,17 @@ def main(): ...@@ -66,14 +77,17 @@ def main():
#retvol_raw_timestamp = run["SA3_XTD10_PES/MDL/DAQ_MPOD", "u212.timestamp"].select_trains(by_id[tids]).ndarray() #retvol_raw_timestamp = run["SA3_XTD10_PES/MDL/DAQ_MPOD", "u212.timestamp"].select_trains(by_id[tids]).ndarray()
model = Model() model = Model()
model.fit({k: v[train_tids,:] for k, v in pes_raw}, spec_raw_int[train_tids,:], spec_raw_pe[train_tids, :]) train_idx = np.isin(tids, train_tids)
model.fit({k: v[train_idx, :] for k, v in pes_raw.items()}, spec_raw_int[train_idx, :], spec_raw_pe[train_idx, :])
spec_smooth = model.preprocess_high_res(spec_raw_int, spec_raw_pe)
# test # test
spec_pred = model.predict(pes_raw) spec_pred = model.predict(pes_raw)
# plot # plot
for tid in test_tids: for tid in test_tids:
plot_result(f"test_{tid}.png", spec_pred[:, tid, :], spec_raw_int[tid, :], spec_raw_pe[0, :]) idx = np.where(tid==tids)[0][0]
plot_result(f"test_{tid}.png", spec_pred[idx, :, :], spec_smooth[idx, :], spec_raw_pe[idx, :], eps)
if __name__ == '__main__': if __name__ == '__main__':
main() main()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment