From d36d795cb2aefcab774f361d98fec562da12ff04 Mon Sep 17 00:00:00 2001 From: Danilo Ferreira de Lima <danilo.enoque.ferreira.de.lima@xfel.de> Date: Tue, 14 Feb 2023 13:53:22 +0100 Subject: [PATCH] Some clean up --- pes_to_spec/test/offline_analysis.py | 130 +++++++++++++++++---------- 1 file changed, 82 insertions(+), 48 deletions(-) diff --git a/pes_to_spec/test/offline_analysis.py b/pes_to_spec/test/offline_analysis.py index a2c2bf1..637ab78 100755 --- a/pes_to_spec/test/offline_analysis.py +++ b/pes_to_spec/test/offline_analysis.py @@ -4,9 +4,12 @@ import sys sys.path.append('.') sys.path.append('..') +import os +import argparse + import numpy as np -from extra_data import RunDirectory, by_id -from pes_to_spec.model import Model, matching_ids +from extra_data import open_run, by_id +from pes_to_spec.model import Model, matching_two_ids from itertools import product @@ -15,8 +18,7 @@ matplotlib.use('Agg') import matplotlib.pyplot as plt from matplotlib.gridspec import GridSpec -from mpl_toolkits.axes_grid.inset_locator import (inset_axes, InsetPosition, - mark_inset) +from mpl_toolkits.axes_grid1.inset_locator import InsetPosition from typing import Dict, Optional @@ -55,7 +57,14 @@ def plot_pes(filename: str, pes_raw_int: np.ndarray, first: int, last: int): fig.savefig(filename) plt.close(fig) -def plot_result(filename: str, spec_pred: Dict[str, np.ndarray], spec_smooth: np.ndarray, spec_raw_pe: np.ndarray, spec_raw_int: Optional[np.ndarray]=None, pes: Optional[np.ndarray]=None, pes_to_show: Optional[str]="", pes_bin: Optional[np.ndarray]=None): +def plot_result(filename: str, + spec_pred: Dict[str, np.ndarray], + spec_smooth: np.ndarray, + spec_raw_pe: np.ndarray, + spec_raw_int: Optional[np.ndarray]=None, + pes: Optional[np.ndarray]=None, + pes_to_show: Optional[str]="", + pes_bin: Optional[np.ndarray]=None): """ Plot result with uncertainty band. @@ -116,36 +125,58 @@ def main(): """ Main entry point. Reads some data, trains and predicts. 
""" - run_dir = "/gpfs/exfel/exp/SA3/202121/p002935/raw/r0015" - run_dir = "/gpfs/exfel/exp/SQS/202201/p002828/raw/r0206" + parser = argparse.ArgumentParser(prog="offline_analysis", description="Test pes2spec doing an offline analysis of the data.") + parser.add_argument('-p', '--proposal', type=int, metavar='INT', help='Proposal number', default=2828) + parser.add_argument('-r', '--run', type=int, metavar='INT', help='Run number', default=206) + parser.add_argument('-m', '--model', type=str, metavar='FILENAME', default="", help='Model to load. If given, do not train a model and just do inference with this one.') + parser.add_argument('-d', '--directory', type=str, metavar='DIRECTORY', default=".", help='Where to save the results.') + parser.add_argument('-S', '--spec', type=str, metavar='NAME', default="SA3_XTD10_SPECT/MDL/SPECTROMETER_SQS_NAVITAR:output", help='SPEC name') + parser.add_argument('-P', '--pes', type=str, metavar='NAME', default="SA3_XTD10_PES/ADC/1:network", help='PES name') + parser.add_argument('-X', '--xgm', type=str, metavar='NAME', default="SA3_XTD10_XGM/XGM/DOOCS:output", help='XGM name') + parser.add_argument('-o', '--offset', type=int, metavar='INT', default=0, help='Train ID offset') + + args = parser.parse_args() + # get run - run = RunDirectory(run_dir) + run = open_run(proposal=args.proposal, run=args.run) + # ----------------Used in the first tests------------------------- # get train IDs and match them, so we are sure to have information from all needed sources # in this example, there is an offset of -2 in the SPEC train ID, so correct for it spec_offset = -2 spec_name = 'SA3_XTD10_SPECT/MDL/FEL_BEAM_SPECTROMETER_SQS1:output' pes_name = 'SA3_XTD10_PES/ADC/1:network' - xgm_name = 'SA3_XTD10_XGM/XGM/DOOCS:output' spec_offset = 0 spec_name = 'SA3_XTD10_SPECT/MDL/SPECTROMETER_SQS_NAVITAR:output' pes_name = 'SA3_XTD10_PES/ADC/1:network' - xgm_name = 'SA3_XTD10_XGM/XGM/DOOCS:output' + # -------------------End of test setup ---------------------------- + + 
spec_offset = args.offset + spec_name = args.spec + pes_name = args.pes + #xgm_name = args.xgm - spec_tid = spec_offset + run[spec_name, "data.trainId"].ndarray() pes_tid = run[pes_name, "digitizers.trainId"].ndarray() - xgm_tid = run[xgm_name, "data.trainId"].ndarray() - # these are the train ID intersection - # this could have been done by a select call in the RunDirectory, but it would not correct for the spec_offset - tids = matching_ids(spec_tid, pes_tid, xgm_tid) + #xgm_tid = run[xgm_name, "data.trainId"].ndarray() + + if len(args.model) == 0: + spec_tid = spec_offset + run[spec_name, "data.trainId"].ndarray() + # these are the train ID intersection + # this could have been done by a select call in the RunDirectory, but it would not correct for the spec_offset + tids = matching_two_ids(spec_tid, pes_tid) + + # read the spec photon energy and intensity + spec_raw_pe = run[spec_name, "data.photonEnergy"].select_trains(by_id[tids - spec_offset]).ndarray() + spec_raw_int = run[spec_name, "data.intensityDistribution"].select_trains(by_id[tids - spec_offset]).ndarray() + + else: # when doing inference, no need to load SPEC data + tids = pes_tid + + # reserve part of it for the test stage train_tids = tids[:-10] test_tids = tids[-10:] - # read the spec photon energy and intensity - spec_raw_pe = run[spec_name, "data.photonEnergy"].select_trains(by_id[tids - spec_offset]).ndarray() - spec_raw_int = run[spec_name, "data.intensityDistribution"].select_trains(by_id[tids - spec_offset]).ndarray() - # read the PES data for each channel channels = [f"channel_{i}_{l}" for i, l in product(range(1, 5), ["A", "B", "C", "D"])] @@ -163,36 +194,34 @@ def main(): t = list() t_names = list() - # these have been manually selected: - #useful_channels = ["channel_1_D", - # "channel_2_B", - # "channel_3_A", - # "channel_3_B", - # "channel_4_C", - # "channel_4_D"] model = Model() train_idx = np.isin(tids, train_tids) - model.debug_peak_finding(pes_raw, "test_peak_finding.png") - 
print("Fitting") - start = time_ns() - model.fit({k: v[train_idx, :] - for k, v in pes_raw.items()}, - spec_raw_int[train_idx, :], - spec_raw_pe[train_idx, :]) - t += [time_ns() - start] - t_names += ["Fit"] - - print("Saving the model") - start = time_ns() - model.save("model.joblib") - t += [time_ns() - start] - t_names += ["Save"] + model.debug_peak_finding(pes_raw, os.path.join(args.directory, "test_peak_finding.png")) + if len(args.model) == 0: + print("Fitting") + start = time_ns() + model.fit({k: v[train_idx, :] + for k, v in pes_raw.items()}, + spec_raw_int[train_idx, :], + spec_raw_pe[train_idx, :]) + t += [time_ns() - start] + t_names += ["Fit"] + + print("Saving the model") + start = time_ns() + modelFilename = os.path.join(args.directory, "model.joblib") + model.save(modelFilename) + t += [time_ns() - start] + t_names += ["Save"] + else: + print("Model has been given, so I will just load it.") + modelFilename = args.model print("Loading the model") start = time_ns() - model = Model.load("model.joblib") + model = Model.load(modelFilename) t += [time_ns() - start] t_names += ["Load"] @@ -218,7 +247,10 @@ def main(): print(df_time) print("Plotting") - spec_smooth = model.preprocess_high_res(spec_raw_int) + showSpec = False + if len(args.model) == 0: + showSpec = True + spec_smooth = model.preprocess_high_res(spec_raw_int) first, last = model.get_low_resolution_range() first += 10 last -= 100 @@ -226,19 +258,21 @@ def main(): # plot for tid in test_tids: idx = np.where(tid==tids)[0][0] - plot_result(f"test_{tid}.png", + plot_result(os.path.join(args.directory, f"test_{tid}.png"), {k: item[idx, ...] if k != "pca" else item[0, ...] 
for k, item in spec_pred.items()}, - spec_smooth[idx, :], - spec_raw_pe[idx, :], - spec_raw_int[idx, :], + spec_smooth[idx, :] if showSpec else None, + spec_raw_pe[idx, :] if showSpec else None, + spec_raw_int[idx, :] if showSpec else None, pes=-pes_raw[pes_to_show][idx, first:last], pes_to_show=pes_to_show.replace('_', ' '), pes_bin=np.arange(first, last) ) for ch in channels: - plot_pes(f"test_pes_{tid}_{ch}.png", pes_raw[ch][idx, first:last], first, last) + plot_pes(os.path.join(args.directory, f"test_pes_{tid}_{ch}.png"), + pes_raw[ch][idx, first:last], first, last) if __name__ == '__main__': main() + -- GitLab