From 13def5407ff4ccd023c046f42ba906dce010f9c1 Mon Sep 17 00:00:00 2001
From: Danilo Ferreira de Lima <danilo.enoque.ferreira.de.lima@xfel.de>
Date: Wed, 14 Dec 2022 17:11:37 +0100
Subject: [PATCH] Started simplified setup and clean up.

---
 scripts/test_analysis.py | 52 ++++++++++++++++++++++++++++++++++++++++
 1 file changed, 52 insertions(+)
 create mode 100644 scripts/test_analysis.py

diff --git a/scripts/test_analysis.py b/scripts/test_analysis.py
new file mode 100644
index 0000000..51fc27b
--- /dev/null
+++ b/scripts/test_analysis.py
@@ -0,0 +1,52 @@
+#!/usr/bin/env python
+
+from extra_data import RunDirectory, by_id
+from pes_to_spec.model import Model, matching_ids
+
+import matplotlib
+matplotlib.use('Agg')
+
+import matplotlib.pyplot as plt
+
+def main():
+    """
+    Main entry point. Reads some data, trains and predicts.
+    """
+    run_dir = "/gpfs/exfel/exp/SA3/202121/p002935/raw"
+    run_id = "r0015"
+    # get run
+    run = RunDirectory(f"{run_dir}/{run_id}") 
+
+    # get train IDs and match them, so we are sure to have information from all needed sources
+    # in this example, there is an offset of -2 in the SPEC train ID, so correct for it
+    spec_offset = -2
+    spec_tid = spec_offset + run['SA3_XTD10_SPECT/MDL/FEL_BEAM_SPECTROMETER_SQS1:output', f"data.trainId"].ndarray()
+    pes_tid = run['SA3_XTD10_PES/ADC/1:network', f"digitizers.trainId"].ndarray()
+    xgm_tid = run['SA3_XTD10_XGM/XGM/DOOCS:output', f"data.trainId"].ndarray()
+    # these are the train ID intersection
+    # this could have been done by a select call in the RunDirectory, but it would not correct for the spec_offset
+    tids = matching_ids(spec_tid, pes_tid, xgm_tid)
+
+    # read the spec photon energy and intensity
+    spec_raw_pe = run['SA3_XTD10_SPECT/MDL/FEL_BEAM_SPECTROMETER_SQS1:output', f"data.photonEnergy"].select_trains(by_id[tids - spec_offset]).ndarray()
+    spec_raw_int = run['SA3_XTD10_SPECT/MDL/FEL_BEAM_SPECTROMETER_SQS1:output', f"data.intensityDistribution"].select_trains(by_id[tids - spec_offset]).ndarray()
+
+    # read the PES data for each channel
+    channels = [f"channel_{i}_{l}" for i, l in zip(range(1, 5), ["A", "B", "C", "D"])]
+    pes_raw = {ch: run['SA3_XTD10_PES/ADC/1:network', f"digitizers.{ch}.raw.samples"].select_trains(by_id[tids]).ndarray() for ch in channels}
+
+    # read the XGM information
+    xgm_pressure = run['SA3_XTD10_XGM/XGM/DOOCS', f"pressure.pressureFiltered.value"].select_trains(by_id[tids]).ndarray()
+    xgm_pe =  run['SA3_XTD10_XGM/XGM/DOOCS:output', f"data.intensitySa3TD"].select_trains(by_id[tids]).ndarray()
+
+    retvol_raw = run["SA3_XTD10_PES/MDL/DAQ_MPOD", "u212.value"].select_trains(by_id[tids]).ndarray()
+    retvol_raw_timestamp = run["SA3_XTD10_PES/MDL/DAQ_MPOD", "u212.timestamp"].select_trains(by_id[tids]).ndarray()
+
+    model = Model()
+    model.fit(pes_raw, spec_raw_int)
+
+    # test
+    model.predict(pes_raw)
+
+if __name__ == '__main__':
+    main()
-- 
GitLab