Skip to content
Snippets Groups Projects
Commit d55c33a2 authored by Danilo Ferreira de Lima's avatar Danilo Ferreira de Lima
Browse files

Saving PCA and plotting.

parent c9f1bc6e
No related branches found
No related tags found
No related merge requests found
import numpy as np import numpy as np
from autograd import numpy as anp from autograd import numpy as anp
from autograd import grad from autograd import grad
import joblib
import h5py import h5py
from scipy.signal import fftconvolve from scipy.signal import fftconvolve
from scipy.optimize import fmin_l_bfgs_b from scipy.optimize import fmin_l_bfgs_b
...@@ -137,12 +138,14 @@ class Model(object): ...@@ -137,12 +138,14 @@ class Model(object):
result = np.stack((high_res_predicted, high_res_unc, self.high_pca_unc), axis=0) result = np.stack((high_res_predicted, high_res_unc, self.high_pca_unc), axis=0)
return result return result
def save(self, filename: str): def save(self, filename: str, lr_pca_filename: str, hr_pca_filename: str):
""" """
Save the fit model in a file. Save the fit model in a file.
Args: Args:
filename: H5 file name where to save this. filename: H5 file name where to save this.
lr_pca_filename: Name of the file where to save the low-resolution PCA decomposition.
hr_pca_filename: Name of the file where to save the high-resolution PCA decomposition.
""" """
with h5py.File(filename, 'w') as hf: with h5py.File(filename, 'w') as hf:
d = self.fit_model.as_dict() d = self.fit_model.as_dict()
...@@ -151,19 +154,26 @@ class Model(object): ...@@ -151,19 +154,26 @@ class Model(object):
hf.attrs[key] = value hf.attrs[key] = value
else: else:
hf.create_dataset(key, data=value) hf.create_dataset(key, data=value)
joblib.dump(self.lr_pca, lr_pca_filename)
joblib.dump(self.hr_pca, hr_pca_filename)
def load(self, filename: str):
def load(self, filename: str, lr_pca_filename: str, hr_pca_filename: str):
""" """
Load model from a file. Load model from a file.
Args: Args:
filename: Name of the file where to read the model from. filename: Name of the file where to read the model from.
lr_pca_filename: Name of the file from where to load the low-resolution PCA decomposition.
hr_pca_filename: Name of the file from where to load the high-resolution PCA decomposition.
""" """
with h5py.File(filename, 'r') as hf: with h5py.File(filename, 'r') as hf:
d = {k: hf[k][()] for k in hf.keys()} d = {k: hf[k][()] for k in hf.keys()}
d.update({k: hf.attrs[k] for k in hf.attrs}) d.update({k: hf.attrs[k] for k in hf.attrs})
self.fit_model.from_dict(d) self.fit_model.from_dict(d)
self.lr_pca = joblib.load(lr_pca_filename)
self.hr_pca = joblib.load(hr_pca_filename)
class FitModel(object): class FitModel(object):
""" """
......
...@@ -4,4 +4,5 @@ scikit-learn ...@@ -4,4 +4,5 @@ scikit-learn
extra_data extra_data
autograd autograd
h5py h5py
joblib
matplotlib matplotlib
...@@ -7,6 +7,28 @@ import matplotlib ...@@ -7,6 +7,28 @@ import matplotlib
matplotlib.use('Agg') matplotlib.use('Agg')
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
def plot_result(filename: str, spec_pred: np.ndarray, spec_raw_int: np.ndarray, spec_raw_pe: np.ndarray):
"""
Plot result with uncertainty band.
Args:
filename: Output file name.
spec_pred: Predicted result with uncertainty bands in a shape of (3, features).
spec_raw_int: True expected result with shape (features,).
spec_raw_pe: x axis with the photon energy in eV.
"""
fig = plt.figure(figsize=(10, 10))
gs = GridSpec(1, 1)
ax = fig.add_subplot(gs[0, 0])
ax.plot(spec_raw_pe, spec_raw_int, c='b', lw=3, label="High resolution measurement")
ax.plot(spec_raw_pe, spec_pred[0,:], c='r', lw=3, label="High resolution prediction")
ax.fill_between(spec_raw_pe, spec_pred[0,:] - spec_pred[1,:], spec_pred[0,:] + spec_pred[1,:], fillcolor='red', alpha=0.6, label="68% unc. (stat.)")
ax.fill_between(spec_raw_pe, spec_pred[0,:] - spec_pred[2,:], spec_pred[0,:] + spec_pred[2,:], fillcolor='magenta', alpha=0.6, label="68% unc. (syst., PCA)")
fig.savefig(filename)
plt.close(fig)
def main(): def main():
""" """
...@@ -26,6 +48,8 @@ def main(): ...@@ -26,6 +48,8 @@ def main():
# these are the train ID intersection # these are the train ID intersection
# this could have been done by a select call in the RunDirectory, but it would not correct for the spec_offset # this could have been done by a select call in the RunDirectory, but it would not correct for the spec_offset
tids = matching_ids(spec_tid, pes_tid, xgm_tid) tids = matching_ids(spec_tid, pes_tid, xgm_tid)
train_tids = tids[:-10]
test_tids = tids[-10:]
# read the spec photon energy and intensity # read the spec photon energy and intensity
spec_raw_pe = run['SA3_XTD10_SPECT/MDL/FEL_BEAM_SPECTROMETER_SQS1:output', f"data.photonEnergy"].select_trains(by_id[tids - spec_offset]).ndarray() spec_raw_pe = run['SA3_XTD10_SPECT/MDL/FEL_BEAM_SPECTROMETER_SQS1:output', f"data.photonEnergy"].select_trains(by_id[tids - spec_offset]).ndarray()
...@@ -36,17 +60,20 @@ def main(): ...@@ -36,17 +60,20 @@ def main():
pes_raw = {ch: run['SA3_XTD10_PES/ADC/1:network', f"digitizers.{ch}.raw.samples"].select_trains(by_id[tids]).ndarray() for ch in channels} pes_raw = {ch: run['SA3_XTD10_PES/ADC/1:network', f"digitizers.{ch}.raw.samples"].select_trains(by_id[tids]).ndarray() for ch in channels}
# read the XGM information # read the XGM information
xgm_pressure = run['SA3_XTD10_XGM/XGM/DOOCS', f"pressure.pressureFiltered.value"].select_trains(by_id[tids]).ndarray() #xgm_pressure = run['SA3_XTD10_XGM/XGM/DOOCS', f"pressure.pressureFiltered.value"].select_trains(by_id[tids]).ndarray()
xgm_pe = run['SA3_XTD10_XGM/XGM/DOOCS:output', f"data.intensitySa3TD"].select_trains(by_id[tids]).ndarray() #xgm_pe = run['SA3_XTD10_XGM/XGM/DOOCS:output', f"data.intensitySa3TD"].select_trains(by_id[tids]).ndarray()
#retvol_raw = run["SA3_XTD10_PES/MDL/DAQ_MPOD", "u212.value"].select_trains(by_id[tids]).ndarray()
retvol_raw = run["SA3_XTD10_PES/MDL/DAQ_MPOD", "u212.value"].select_trains(by_id[tids]).ndarray() #retvol_raw_timestamp = run["SA3_XTD10_PES/MDL/DAQ_MPOD", "u212.timestamp"].select_trains(by_id[tids]).ndarray()
retvol_raw_timestamp = run["SA3_XTD10_PES/MDL/DAQ_MPOD", "u212.timestamp"].select_trains(by_id[tids]).ndarray()
model = Model() model = Model()
model.fit(pes_raw, spec_raw_int) model.fit({k: v[train_tids,:] for k, v in pes_raw}, spec_raw_int[train_tids,:], spec_raw_pe[train_tids, :])
# test # test
model.predict(pes_raw) spec_pred = model.predict(pes_raw)
# plot
for tid in test_tids:
plot_result(f"test_{tid}.png", spec_pred[:, tid, :], spec_raw_int[tid, :], spec_raw_pe[0, :])
if __name__ == '__main__': if __name__ == '__main__':
main() main()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment