Skip to content
Snippets Groups Projects
Commit d36d795c authored by Danilo Ferreira de Lima's avatar Danilo Ferreira de Lima
Browse files

Some clean up

parent ecb6b083
No related branches found
No related tags found
1 merge request!7Improve fit
...@@ -4,9 +4,12 @@ import sys ...@@ -4,9 +4,12 @@ import sys
sys.path.append('.') sys.path.append('.')
sys.path.append('..') sys.path.append('..')
import os
import argparse
import numpy as np import numpy as np
from extra_data import RunDirectory, by_id from extra_data import open_run, by_id
from pes_to_spec.model import Model, matching_ids from pes_to_spec.model import Model, matching_two_ids
from itertools import product from itertools import product
...@@ -15,8 +18,7 @@ matplotlib.use('Agg') ...@@ -15,8 +18,7 @@ matplotlib.use('Agg')
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec from matplotlib.gridspec import GridSpec
from mpl_toolkits.axes_grid.inset_locator import (inset_axes, InsetPosition, from mpl_toolkits.axes_grid.inset_locator import InsetPosition
mark_inset)
from typing import Dict, Optional from typing import Dict, Optional
...@@ -55,7 +57,14 @@ def plot_pes(filename: str, pes_raw_int: np.ndarray, first: int, last: int): ...@@ -55,7 +57,14 @@ def plot_pes(filename: str, pes_raw_int: np.ndarray, first: int, last: int):
fig.savefig(filename) fig.savefig(filename)
plt.close(fig) plt.close(fig)
def plot_result(filename: str, spec_pred: Dict[str, np.ndarray], spec_smooth: np.ndarray, spec_raw_pe: np.ndarray, spec_raw_int: Optional[np.ndarray]=None, pes: Optional[np.ndarray]=None, pes_to_show: Optional[str]="", pes_bin: Optional[np.ndarray]=None): def plot_result(filename: str,
spec_pred: Dict[str, np.ndarray],
spec_smooth: np.ndarray,
spec_raw_pe: np.ndarray,
spec_raw_int: Optional[np.ndarray]=None,
pes: Optional[np.ndarray]=None,
pes_to_show: Optional[str]="",
pes_bin: Optional[np.ndarray]=None):
""" """
Plot result with uncertainty band. Plot result with uncertainty band.
...@@ -116,36 +125,58 @@ def main(): ...@@ -116,36 +125,58 @@ def main():
""" """
Main entry point. Reads some data, trains and predicts. Main entry point. Reads some data, trains and predicts.
""" """
run_dir = "/gpfs/exfel/exp/SA3/202121/p002935/raw/r0015" parser = argparse.ArgumentParser(prog="offline_analysis", description="Test pes2spec doing an offline analysis of the data.")
run_dir = "/gpfs/exfel/exp/SQS/202201/p002828/raw/r0206" parser.add_argument('-p', '--proposal', type=int, meta='INT', help='Proposal number', default=2828)
parser.add_argument('-r', '--run', type=int, meta='INT', help='Run number', default=206)
parser.add_argument('-m', '--model', type=str, meta='FILENAME', default="", help='Model to load. If given, do not train a model and just do inference with this one.')
parser.add_argument('-d', '--directory', type=str, meta='DIRECTORY', default=".", help='Where to save the results.')
parser.add_argument('-S', '--spec', type=str, meta='NAME', default="SA3_XTD10_SPECT/MDL/SPECTROMETER_SQS_NAVITAR:output", help='SPEC name')
parser.add_argument('-P', '--pes', type=str, meta='NAME', default="SA3_XTD10_PES/ADC/1:network", help='PES name')
parser.add_argument('-X', '--xgm', type=str, meta='NAME', default="SA3_XTD10_XGM/XGM/DOOCS:output", help='XGM name')
parser.add_argument('-o', '--offset', type=int, meta='INT', default=0, help='Train ID offset')
args = parser.parse_args()
# get run # get run
run = RunDirectory(run_dir) run = open_run(proposal=args.proposal, run=args.run)
# ----------------Used in the first tests-------------------------
# get train IDs and match them, so we are sure to have information from all needed sources # get train IDs and match them, so we are sure to have information from all needed sources
# in this example, there is an offset of -2 in the SPEC train ID, so correct for it # in this example, there is an offset of -2 in the SPEC train ID, so correct for it
spec_offset = -2 spec_offset = -2
spec_name = 'SA3_XTD10_SPECT/MDL/FEL_BEAM_SPECTROMETER_SQS1:output' spec_name = 'SA3_XTD10_SPECT/MDL/FEL_BEAM_SPECTROMETER_SQS1:output'
pes_name = 'SA3_XTD10_PES/ADC/1:network' pes_name = 'SA3_XTD10_PES/ADC/1:network'
xgm_name = 'SA3_XTD10_XGM/XGM/DOOCS:output'
spec_offset = 0 spec_offset = 0
spec_name = 'SA3_XTD10_SPECT/MDL/SPECTROMETER_SQS_NAVITAR:output' spec_name = 'SA3_XTD10_SPECT/MDL/SPECTROMETER_SQS_NAVITAR:output'
pes_name = 'SA3_XTD10_PES/ADC/1:network' pes_name = 'SA3_XTD10_PES/ADC/1:network'
xgm_name = 'SA3_XTD10_XGM/XGM/DOOCS:output' # -------------------End of test setup ----------------------------
spec_offset = args.offset
spec_name = args.spec
pes_name = args.pes
#xgm_name = args.xgm
spec_tid = spec_offset + run[spec_name, "data.trainId"].ndarray()
pes_tid = run[pes_name, "digitizers.trainId"].ndarray() pes_tid = run[pes_name, "digitizers.trainId"].ndarray()
xgm_tid = run[xgm_name, "data.trainId"].ndarray() #xgm_tid = run[xgm_name, "data.trainId"].ndarray()
# these are the train ID intersection
# this could have been done by a select call in the RunDirectory, but it would not correct for the spec_offset if len(args.model) == 0:
tids = matching_ids(spec_tid, pes_tid, xgm_tid) spec_tid = spec_offset + run[spec_name, "data.trainId"].ndarray()
# these are the train ID intersection
# this could have been done by a select call in the RunDirectory, but it would not correct for the spec_offset
tids = matching_two_ids(spec_tid, pes_tid)
# read the spec photon energy and intensity
spec_raw_pe = run[spec_name, "data.photonEnergy"].select_trains(by_id[tids - spec_offset]).ndarray()
spec_raw_int = run[spec_name, "data.intensityDistribution"].select_trains(by_id[tids - spec_offset]).ndarray()
else: # when doing inference, no need to load SPEC data
tids = pes_tid
# reserve part of it for the test stage
train_tids = tids[:-10] train_tids = tids[:-10]
test_tids = tids[-10:] test_tids = tids[-10:]
# read the spec photon energy and intensity
spec_raw_pe = run[spec_name, "data.photonEnergy"].select_trains(by_id[tids - spec_offset]).ndarray()
spec_raw_int = run[spec_name, "data.intensityDistribution"].select_trains(by_id[tids - spec_offset]).ndarray()
# read the PES data for each channel # read the PES data for each channel
channels = [f"channel_{i}_{l}" channels = [f"channel_{i}_{l}"
for i, l in product(range(1, 5), ["A", "B", "C", "D"])] for i, l in product(range(1, 5), ["A", "B", "C", "D"])]
...@@ -163,36 +194,34 @@ def main(): ...@@ -163,36 +194,34 @@ def main():
t = list() t = list()
t_names = list() t_names = list()
# these have been manually selected:
#useful_channels = ["channel_1_D",
# "channel_2_B",
# "channel_3_A",
# "channel_3_B",
# "channel_4_C",
# "channel_4_D"]
model = Model() model = Model()
train_idx = np.isin(tids, train_tids) train_idx = np.isin(tids, train_tids)
model.debug_peak_finding(pes_raw, "test_peak_finding.png") model.debug_peak_finding(pes_raw, os.path.join(args.directory, "test_peak_finding.png"))
print("Fitting") if len(args.model) == 0:
start = time_ns() print("Fitting")
model.fit({k: v[train_idx, :] start = time_ns()
for k, v in pes_raw.items()}, model.fit({k: v[train_idx, :]
spec_raw_int[train_idx, :], for k, v in pes_raw.items()},
spec_raw_pe[train_idx, :]) spec_raw_int[train_idx, :],
t += [time_ns() - start] spec_raw_pe[train_idx, :])
t_names += ["Fit"] t += [time_ns() - start]
t_names += ["Fit"]
print("Saving the model")
start = time_ns() print("Saving the model")
model.save("model.joblib") start = time_ns()
t += [time_ns() - start] modelFilename = os.path.join(args.directory, "model.joblib")
t_names += ["Save"] model.save(modelFilename)
t += [time_ns() - start]
t_names += ["Save"]
else:
print("Model has been given, so I will just load it.")
modelFilename = args.model
print("Loading the model") print("Loading the model")
start = time_ns() start = time_ns()
model = Model.load("model.joblib") model = Model.load(modelFilename)
t += [time_ns() - start] t += [time_ns() - start]
t_names += ["Load"] t_names += ["Load"]
...@@ -218,7 +247,10 @@ def main(): ...@@ -218,7 +247,10 @@ def main():
print(df_time) print(df_time)
print("Plotting") print("Plotting")
spec_smooth = model.preprocess_high_res(spec_raw_int) showSpec = False
if len(args.model) == 0:
showSpec = True
spec_smooth = model.preprocess_high_res(spec_raw_int)
first, last = model.get_low_resolution_range() first, last = model.get_low_resolution_range()
first += 10 first += 10
last -= 100 last -= 100
...@@ -226,19 +258,21 @@ def main(): ...@@ -226,19 +258,21 @@ def main():
# plot # plot
for tid in test_tids: for tid in test_tids:
idx = np.where(tid==tids)[0][0] idx = np.where(tid==tids)[0][0]
plot_result(f"test_{tid}.png", plot_result(os.path.join(args.directory, f"test_{tid}.png"),
{k: item[idx, ...] if k != "pca" {k: item[idx, ...] if k != "pca"
else item[0, ...] else item[0, ...]
for k, item in spec_pred.items()}, for k, item in spec_pred.items()},
spec_smooth[idx, :], spec_smooth[idx, :] if showSpec else None,
spec_raw_pe[idx, :], spec_raw_pe[idx, :] if showSpec else None,
spec_raw_int[idx, :], spec_raw_int[idx, :] if showSpec else None,
pes=-pes_raw[pes_to_show][idx, first:last], pes=-pes_raw[pes_to_show][idx, first:last],
pes_to_show=pes_to_show.replace('_', ' '), pes_to_show=pes_to_show.replace('_', ' '),
pes_bin=np.arange(first, last) pes_bin=np.arange(first, last)
) )
for ch in channels: for ch in channels:
plot_pes(f"test_pes_{tid}_{ch}.png", pes_raw[ch][idx, first:last], first, last) plot_pes(os.path.join(args.directory, f"test_pes_{tid}_{ch}.png"),
pes_raw[ch][idx, first:last], first, last)
if __name__ == '__main__': if __name__ == '__main__':
main() main()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment