From d36d795cb2aefcab774f361d98fec562da12ff04 Mon Sep 17 00:00:00 2001
From: Danilo Ferreira de Lima <danilo.enoque.ferreira.de.lima@xfel.de>
Date: Tue, 14 Feb 2023 13:53:22 +0100
Subject: [PATCH] Some clean up

---
 pes_to_spec/test/offline_analysis.py | 130 +++++++++++++++++----------
 1 file changed, 82 insertions(+), 48 deletions(-)

diff --git a/pes_to_spec/test/offline_analysis.py b/pes_to_spec/test/offline_analysis.py
index a2c2bf1..637ab78 100755
--- a/pes_to_spec/test/offline_analysis.py
+++ b/pes_to_spec/test/offline_analysis.py
@@ -4,9 +4,12 @@ import sys
 sys.path.append('.')
 sys.path.append('..')
 
+import os
+import argparse
+
 import numpy as np
-from extra_data import RunDirectory, by_id
-from pes_to_spec.model import Model, matching_ids
+from extra_data import open_run, by_id
+from pes_to_spec.model import Model, matching_two_ids
 
 from itertools import product
 
@@ -15,8 +18,7 @@ matplotlib.use('Agg')
 
 import matplotlib.pyplot as plt
 from matplotlib.gridspec import GridSpec
-from mpl_toolkits.axes_grid.inset_locator import (inset_axes, InsetPosition,
-                                                  mark_inset)
+from mpl_toolkits.axes_grid.inset_locator import InsetPosition
 
 from typing import Dict, Optional
 
@@ -55,7 +57,14 @@ def plot_pes(filename: str, pes_raw_int: np.ndarray, first: int, last: int):
     fig.savefig(filename)
     plt.close(fig)
 
-def plot_result(filename: str, spec_pred: Dict[str, np.ndarray], spec_smooth: np.ndarray, spec_raw_pe: np.ndarray, spec_raw_int: Optional[np.ndarray]=None, pes: Optional[np.ndarray]=None, pes_to_show: Optional[str]="", pes_bin: Optional[np.ndarray]=None):
+def plot_result(filename: str,
+                spec_pred: Dict[str, np.ndarray],
+                spec_smooth: np.ndarray,
+                spec_raw_pe: np.ndarray,
+                spec_raw_int: Optional[np.ndarray]=None,
+                pes: Optional[np.ndarray]=None,
+                pes_to_show: Optional[str]="",
+                pes_bin: Optional[np.ndarray]=None):
     """
     Plot result with uncertainty band.
 
@@ -116,36 +125,58 @@ def main():
     """
     Main entry point. Reads some data, trains and predicts.
     """
-    run_dir = "/gpfs/exfel/exp/SA3/202121/p002935/raw/r0015"
-    run_dir = "/gpfs/exfel/exp/SQS/202201/p002828/raw/r0206"
+    parser = argparse.ArgumentParser(prog="offline_analysis", description="Test pes2spec doing an offline analysis of the data.")
+    parser.add_argument('-p', '--proposal', type=int, meta='INT', help='Proposal number', default=2828)
+    parser.add_argument('-r', '--run', type=int, meta='INT', help='Run number', default=206)
+    parser.add_argument('-m', '--model', type=str, meta='FILENAME', default="", help='Model to load. If given, do not train a model and just do inference with this one.')
+    parser.add_argument('-d', '--directory', type=str, meta='DIRECTORY', default=".", help='Where to save the results.')
+    parser.add_argument('-S', '--spec', type=str, meta='NAME', default="SA3_XTD10_SPECT/MDL/SPECTROMETER_SQS_NAVITAR:output", help='SPEC name')
+    parser.add_argument('-P', '--pes', type=str, meta='NAME', default="SA3_XTD10_PES/ADC/1:network", help='PES name')
+    parser.add_argument('-X', '--xgm', type=str, meta='NAME', default="SA3_XTD10_XGM/XGM/DOOCS:output", help='XGM name')
+    parser.add_argument('-o', '--offset', type=int, meta='INT', default=0, help='Train ID offset')
+
+    args = parser.parse_args()
+
     # get run
-    run = RunDirectory(run_dir)
+    run = open_run(proposal=args.proposal, run=args.run)
 
+    # ----------------Used in the first tests-------------------------
     # get train IDs and match them, so we are sure to have information from all needed sources
     # in this example, there is an offset of -2 in the SPEC train ID, so correct for it
     spec_offset = -2
     spec_name = 'SA3_XTD10_SPECT/MDL/FEL_BEAM_SPECTROMETER_SQS1:output'
     pes_name = 'SA3_XTD10_PES/ADC/1:network'
-    xgm_name = 'SA3_XTD10_XGM/XGM/DOOCS:output'
 
     spec_offset = 0
     spec_name = 'SA3_XTD10_SPECT/MDL/SPECTROMETER_SQS_NAVITAR:output'
     pes_name = 'SA3_XTD10_PES/ADC/1:network'
-    xgm_name = 'SA3_XTD10_XGM/XGM/DOOCS:output'
+    # -------------------End of test setup ----------------------------
+
+    spec_offset = args.offset
+    spec_name = args.spec
+    pes_name = args.pes
+    #xgm_name = args.xgm
 
-    spec_tid = spec_offset + run[spec_name, "data.trainId"].ndarray()
     pes_tid = run[pes_name, "digitizers.trainId"].ndarray()
-    xgm_tid = run[xgm_name, "data.trainId"].ndarray()
-    # these are the train ID intersection
-    # this could have been done by a select call in the RunDirectory, but it would not correct for the spec_offset
-    tids = matching_ids(spec_tid, pes_tid, xgm_tid)
+    #xgm_tid = run[xgm_name, "data.trainId"].ndarray()
+
+    if len(args.model) == 0:
+        spec_tid = spec_offset + run[spec_name, "data.trainId"].ndarray()
+        # these are the train ID intersection
+        # this could have been done by a select call in the RunDirectory, but it would not correct for the spec_offset
+        tids = matching_two_ids(spec_tid, pes_tid)
+
+        # read the spec photon energy and intensity
+        spec_raw_pe = run[spec_name, "data.photonEnergy"].select_trains(by_id[tids - spec_offset]).ndarray()
+        spec_raw_int = run[spec_name, "data.intensityDistribution"].select_trains(by_id[tids - spec_offset]).ndarray()
+
+    else: # when doing inference, no need to load SPEC data
+        tids = pes_tid
+
+    # reserve part of it for the test stage
     train_tids = tids[:-10]
     test_tids = tids[-10:]
 
-    # read the spec photon energy and intensity
-    spec_raw_pe = run[spec_name, "data.photonEnergy"].select_trains(by_id[tids - spec_offset]).ndarray()
-    spec_raw_int = run[spec_name, "data.intensityDistribution"].select_trains(by_id[tids - spec_offset]).ndarray()
-
     # read the PES data for each channel
     channels = [f"channel_{i}_{l}"
                 for i, l in product(range(1, 5), ["A", "B", "C", "D"])]
@@ -163,36 +194,34 @@ def main():
     t = list()
     t_names = list()
 
-    # these have been manually selected:
-    #useful_channels = ["channel_1_D",
-    #                  "channel_2_B",
-    #                  "channel_3_A",
-    #                  "channel_3_B",
-    #                  "channel_4_C",
-    #                  "channel_4_D"]
     model = Model()
 
     train_idx = np.isin(tids, train_tids)
 
-    model.debug_peak_finding(pes_raw, "test_peak_finding.png")
-    print("Fitting")
-    start = time_ns()
-    model.fit({k: v[train_idx, :]
-               for k, v in pes_raw.items()},
-              spec_raw_int[train_idx, :],
-              spec_raw_pe[train_idx, :])
-    t += [time_ns() - start]
-    t_names += ["Fit"]
-
-    print("Saving the model")
-    start = time_ns()
-    model.save("model.joblib")
-    t += [time_ns() - start]
-    t_names += ["Save"]
+    model.debug_peak_finding(pes_raw, os.path.join(args.directory, "test_peak_finding.png"))
+    if len(args.model) == 0:
+        print("Fitting")
+        start = time_ns()
+        model.fit({k: v[train_idx, :]
+                   for k, v in pes_raw.items()},
+                   spec_raw_int[train_idx, :],
+                   spec_raw_pe[train_idx, :])
+        t += [time_ns() - start]
+        t_names += ["Fit"]
+
+        print("Saving the model")
+        start = time_ns()
+        modelFilename = os.path.join(args.directory, "model.joblib")
+        model.save(modelFilename)
+        t += [time_ns() - start]
+        t_names += ["Save"]
+    else:
+        print("Model has been given, so I will just load it.")
+        modelFilename = args.model
 
     print("Loading the model")
     start = time_ns()
-    model = Model.load("model.joblib")
+    model = Model.load(modelFilename)
     t += [time_ns() - start]
     t_names += ["Load"]
 
@@ -218,7 +247,10 @@ def main():
     print(df_time)
 
     print("Plotting")
-    spec_smooth = model.preprocess_high_res(spec_raw_int)
+    showSpec = False
+    if len(args.model) == 0:
+        showSpec = True
+        spec_smooth = model.preprocess_high_res(spec_raw_int)
     first, last = model.get_low_resolution_range()
     first += 10
     last -= 100
@@ -226,19 +258,21 @@ def main():
     # plot
     for tid in test_tids:
         idx = np.where(tid==tids)[0][0]
-        plot_result(f"test_{tid}.png",
+        plot_result(os.path.join(args.directory, f"test_{tid}.png"),
                    {k: item[idx, ...] if k != "pca"
                        else item[0, ...]
                        for k, item in spec_pred.items()},
-                    spec_smooth[idx, :],
-                    spec_raw_pe[idx, :],
-                    spec_raw_int[idx, :],
+                    spec_smooth[idx, :] if showSpec else None,
+                    spec_raw_pe[idx, :] if showSpec else None,
+                    spec_raw_int[idx, :] if showSpec else None,
                     pes=-pes_raw[pes_to_show][idx, first:last],
                     pes_to_show=pes_to_show.replace('_', ' '),
                     pes_bin=np.arange(first, last)
                     )
         for ch in channels:
-            plot_pes(f"test_pes_{tid}_{ch}.png", pes_raw[ch][idx, first:last], first, last)
+            plot_pes(os.path.join(args.directory, f"test_pes_{tid}_{ch}.png"),
+                     pes_raw[ch][idx, first:last], first, last)
 
 if __name__ == '__main__':
     main()
+
-- 
GitLab