From 1cb028d8c22179d27d3305444ecef5bf36b195bb Mon Sep 17 00:00:00 2001
From: Laurent Mercadier <laurent.mercadier@xfel.eu>
Date: Mon, 26 Aug 2024 14:05:19 +0200
Subject: [PATCH] replaced source, key and digitizer arguments by mnemonic in
 get_peaks()

---
 src/toolbox_scs/detectors/digitizers.py | 80 +++++++++----------------
 1 file changed, 27 insertions(+), 53 deletions(-)

diff --git a/src/toolbox_scs/detectors/digitizers.py b/src/toolbox_scs/detectors/digitizers.py
index 7fe7c3c..8367242 100644
--- a/src/toolbox_scs/detectors/digitizers.py
+++ b/src/toolbox_scs/detectors/digitizers.py
@@ -168,10 +168,8 @@ def peaks_from_apd(array, params, digitizer, bpt, bunchPattern):
 
 
 def get_peaks(run,
-              data=None,
-              source=None,
-              key=None,
-              digitizer='ADQ412',
+              data,
+              mnemonic,
               useRaw=True,
               autoFind=True,
               integParams=None,
@@ -191,15 +189,10 @@ def get_peaks(run,
         array containing the raw traces or peak-integrated values from the
         digitizer. If str, must be one of the ToolBox mnemonics. If None,
         the data is loaded via the source and key arguments.
-    source: str
-        Name of digitizer source, e.g. 'SCS_UTC1_ADQ/ADC/1:network'. Only
-        required if data is a DataArray or None.
-    key: str
-        Key for digitizer data, e.g. 'digitizers.channel_1_A.raw.samples'.
-        Only required if data is a DataArray or is None.
-    digitizer: string
-        name of digitizer, e.g. 'FastADC' or 'ADQ412'. Used to determine
-        the sampling rate.
+    mnemonic: str or dict
+        ToolBox mnemonic or dict with source and key as in 
+        {'source': 'SCS_UTC1_ADQ/ADC/1:network',
+        'key': 'digitizers.channel_1_A.raw.samples'}
     useRaw: bool
         If True, extract peaks from raw traces. If False, uses the APD (or
         peaks) data from the digitizer.
@@ -211,12 +204,10 @@ def get_peaks(run,
         integration: 'pulseStart', 'pulseStop', 'baseStart', 'baseStop',
         'period', 'npulses'. Not used if autoFind is True. All keys are
         required when bunch pattern is missing.
-    bunchPattern: string or dict
+    bunchPattern: str
         match the peaks to the bunch pattern: 'sase1', 'sase3', 'scs_ppl'.
         This will dictate the name of the pulse ID coordinates: 'sa1_pId',
-        'sa3_pId' or 'scs_ppl'. Alternatively, a dict with source, key and
-        pattern can be provided, e.g. {'source':'SCS_RR_UTC/TSYS/TIMESERVER',
-        'key':'bunchPatternTable.value', 'pattern':'sase3'}
+        'sa3_pId' or 'scs_ppl'.
     bpt: xarray DataArray
         bunch pattern table
     extra_dim: str
@@ -230,43 +221,25 @@ def get_peaks(run,
     -------
     xarray.DataArray containing digitizer peaks with pulse coordinates
     """
-    if data is None and (source is None or key is None):
-        raise ValueError('At least data or source + key arguments '
-                         'are required.')
-    # Load data
-    arr = None
-    run_mnemonics = mnemonics_for_run(run)
-    if data is None:
-        log.debug(f'Loading array from DataCollection with {source}, {key}')
-        arr = run.get_array(source, key)
-    if isinstance(data, str):
-        log.debug(f'Loading array from mnemonic {data}')
-        m = run_mnemonics[data]
-        source = m['source']
-        key = m['key']
-        arr = run.get_array(source, key, m['dim'])
-    if arr is None:
-        log.debug('Using array provided in data argument.')
-        if source is None or key is None:
-            raise ValueError('source and/or key arguments missing.')
-        arr = data
+    arr = data
     dim = [d for d in arr.dims if d != 'trainId'][0]
 
     # Load bunch pattern table
+    run_mnemonics = mnemonics_for_run(run)
     if bpt is None and bunchPattern != 'None':
-        if isinstance(bunchPattern, dict):
-            bpt = run.get_array(bunchPattern['source'], bunchPattern['key'],
-                                extra_dims=['pulse_slot'])
-            pattern = bunchPattern['pattern']
         if 'bunchPatternTable' in run_mnemonics:
             m = run_mnemonics['bunchPatternTable']
             bpt = run.get_array(m['source'], m['key'], m['dim'])
-            pattern = bunchPattern
+        pattern = bunchPattern
     else:
         pattern = bunchPattern
     if bunchPattern == 'None':
         bpt = None
 
+    # Find digitizer type
+    m = mnemonic if isinstance(mnemonic, dict) else run_mnemonics[mnemonic]
+    digitizer = digitizer_type(run, m.get('source'))
+
     # 1. Peak-integrated data from digitizer
     if useRaw is False:
         # 1.1 No bunch pattern provided
@@ -280,7 +253,11 @@ def get_peaks(run,
             return arr.isel({dim: indices}).rename({dim: extra_dim})
 
         # 1.2 Bunch pattern is provided
-        peak_params = channel_peak_params(run, source, key)
+        if isinstance(mnemonic, dict):
+            peak_params = channel_peak_params(run, mnemonic.get('source'),
+                                              mnemonic.get('key'))
+        else:
+            peak_params = channel_peak_params(run, mnemonic)
         log.debug(f'Digitizer peak integration parameters: {peak_params}')
         return peaks_from_apd(arr, peak_params, digitizer, bpt, bunchPattern)
 
@@ -727,7 +704,7 @@ def check_peak_params(run, mnemonic, raw_trace=None, ntrains=200, params=None,
         log.warning('The digitizer did not record peak-integrated data.')
     if not plot:
         return params
-    digitizer = digitizer_type(run, run_mnemonics[mnemonic]['source'].split(':')[0])
+    digitizer = digitizer_type(run, run_mnemonics[mnemonic]['source'])
     min_distance = 24 if digitizer == "FastADC" else 440
     if 'bunchPatternTable' in run_mnemonics and bunchPattern != 'None':
         sel = run.select_trains(np.s_[:ntrains])
@@ -847,6 +824,7 @@ def digitizer_type(run, source):
                      'PyADCChannelLegacy': 'FastADC'
                     }
         try:
+            source = source.split(':')[0]
             classId = run.get_run_value(source, 'classId.value')
             ret = digi_dict.get(classId)
         except Exception as e:
@@ -940,7 +918,7 @@ def get_laser_peaks(run, mnemonic=None, merge_with=None,
 
 def get_digitizer_peaks(run, mnemonic, merge_with=None,
                         bunchPattern='sase3', integParams=None,
-                        digitizer=None, keepAllSase=False):
+                        keepAllSase=False):
     """
     Automatically computes digitizer peaks. A source can be loaded on the
     fly via the mnemonic argument, or processed from an existing data set
@@ -953,10 +931,10 @@ def get_digitizer_peaks(run, mnemonic, merge_with=None,
         DataCollection containing the digitizer data.
     mnemonic: str
         mnemonic for FastADC or ADQ412, e.g. "I0_ILHraw" or "MCP3apd".
+        The data is either loaded from the DataCollection or taken from
+        merge_with.
     merge_with: xarray Dataset
-        If provided, the resulting Dataset will be merged with this
-        one. The FastADC variables of merge_with (if any) will also be
-        computed and merged.
+        If provided, the resulting Dataset will be merged with this one.
     bunchPattern: str or dict
         'sase1' or 'sase3' or 'scs_ppl', 'None': bunch pattern
     integParams: dict
@@ -1007,16 +985,12 @@ def get_digitizer_peaks(run, mnemonic, merge_with=None,
     # iterate over mnemonics and merge arrays in dataset
     autoFind = True if integParams is None else False
     m = run_mnemonics[mnemonic]
-    digitizer = digitizer_type(run, m['source'].split(':')[0])
     useRaw = True if 'raw' in mnemonic else False
     if bool(merge_with) and mnemonic in merge_with:
         data = merge_with[mnemonic]
     else:
         data = run.get_array(m['source'], m['key'], m['dim'])
-    peaks = get_peaks(run, data,
-                      source=m['source'],
-                      key=m['key'],
-                      digitizer=digitizer,
+    peaks = get_peaks(run, data, mnemonic,
                       useRaw=useRaw,
                       autoFind=autoFind,
                       integParams=integParams,
-- 
GitLab