From 27c9380a851334f3710be44012ea3ef50b5f0e45 Mon Sep 17 00:00:00 2001
From: Laurent Mercadier <laurent.mercadier@xfel.eu>
Date: Sun, 6 Oct 2024 11:54:54 +0200
Subject: [PATCH] update get_pes_tof(), add saving and loading average traces
 functions

---
 src/toolbox_scs/detectors/pes.py | 290 ++++++++++++++++---------------
 1 file changed, 150 insertions(+), 140 deletions(-)

diff --git a/src/toolbox_scs/detectors/pes.py b/src/toolbox_scs/detectors/pes.py
index be5ee30..33c1e30 100644
--- a/src/toolbox_scs/detectors/pes.py
+++ b/src/toolbox_scs/detectors/pes.py
@@ -11,29 +11,36 @@ import numpy as np
 import xarray as xr
 import extra_data as ed
 import re
+import os
+from pathlib import Path
+from multiprocessing import Pool
 from extra.components import XrayPulses, AdqRawChannel
 
 from ..misc.bunch_pattern_external import is_sase_3
 from ..mnemonics_machinery import (mnemonics_to_process,
                                    mnemonics_for_run)
+from extra_data.read_machinery import find_proposal
 from ..constants import mnemonics as _mnemonics
 
 __all__ = [
     'get_pes_params',
     'get_pes_tof',
-    'extract_pes_spectra',
+    'save_pes_avg_traces',
+    'load_pes_avg_traces'
 ]
 
 
 log = logging.getLogger(__name__)
 
 
-def extract_pes_spectra(proposal, runNB, mnemonic,
-                        start=0, origin=None, width=None):
+def get_pes_tof(proposal, runNB, mnemonic, start=0, origin=None,
+                        width=None, subtract_baseline=False,
+                        baseStart=None, baseWidth=40,merge_with=None):
     """
     Extracts time-of-flight spectra from raw digitizer traces. The spectra
-    are aligned by pulse Id using the SASE 3 bunch pattern, and have time
-    coordinates in nanoseconds.
+    are aligned by pulse Id using the SASE 3 bunch pattern. If origin is
+    not None, a time coordinate in nanoseconds 'time_ns' is computed and
+    added to the DataArray.
 
     Parameters
     ----------
@@ -46,12 +53,24 @@ def extract_pes_spectra(proposal, runNB, mnemonic,
     start: int
         starting sample of the first spectrum in the raw trace.
     origin: int
-        sample of the raw trace that corresponds to time-of-flight origin.
-        Used to compute the 'time_ns' coordinates.
-        If None, computation of 'time_ns' is skipped.
+        sample of the spectrum that corresponds to time-of-flight origin.
+        This is relative to the start, so the position in the raw trace
+        corresponds to start + origin. Used to compute the 'time_ns'
+        coordinates. If None, computation of 'time_ns' is skipped.
     width: int
         number of samples per spectra. If None, the number of samples
         for 4.5 MHz repetition rate is used.
+    subtract_baseline: bool
+        If True, subtract baseline defined by baseStart and baseWidth from each
+        spectrum.
+    baseStart: int
+        starting sample of the baseline. If None, 0 is used.
+    baseWidth: int
+        number of samples to average (starting from baseStart) for baseline
+        calculation.
+    merge_with: xarray Dataset
+        If provided, the resulting Dataset will be merged with this
+        one.
 
     Returns
     -------
@@ -63,7 +82,8 @@ def extract_pes_spectra(proposal, runNB, mnemonic,
     >>> import toolbox_scs as tb
     >>> import toolbox_scs.detectors as tbdet
     >>> proposal, runNB = 900447, 12
-    >>> pes = tbdet.get_pes_tof(proposal, runNB, 'PES_2Araw')
+    >>> pes = tbdet.get_pes_tof(proposal, runNB, 'PES_2Araw',
+    ...                         start=2557, origin=76)
     """
     run = ed.open_run(proposal, runNB)
     all_mnemonics = mnemonics_for_run(run)
@@ -106,6 +126,11 @@ def extract_pes_spectra(proposal, runNB, mnemonic,
                            coords={'trainId': kd.train_id_coordinates(),
                                    'sa3_pId': pulse_ids[:npulses_trace],
                                    'sampleId': np.arange(period)})
+    if subtract_baseline:
+        if baseStart is None:
+            baseStart = 0
+        spectra = spectra - spectra.isel(
+            sampleId=slice(baseStart, baseStart + baseWidth)).mean('sampleId')
     if width is None:
         width = PULSE_PERIOD
     spectra = spectra.isel(sampleId=slice(0, width), drop=True)
@@ -122,149 +147,27 @@ def extract_pes_spectra(proposal, runNB, mnemonic,
                    - origin)/sample_rate * 1e9
         spectra = spectra.assign_coords(time_ns=('sampleId', time_ns))
         spectra.attrs['origin'] = origin
+    spectra = spectra.rename(mnemonic.replace('raw', 'spectrum'))
+    if merge_with is not None:
+        return merge_with.merge(spectra.to_dataset(promote_attrs=True),
+                                join='inner')
     return spectra.rename(mnemonic.replace('raw', 'spectrum'))
 
 
-def get_pes_tof(run, mnemonics=None, merge_with=None,
-                start=31390, width=300, origin=None, width_ns=None,
-                subtract_baseline=True,
-                baseStart=None, baseWidth=80,
-                sample_rate=2e9):
-    """
-    Extracts time-of-flight spectra from raw digitizer traces. The
-    tracesvare either loaded via ToolBox mnemonics or those in the
-    optionally provided merge_with dataset. The spectra are aligned
-    by pulse Id using the SASE 3 bunch pattern, and have time coordinates
-    in nanoseconds.
-
-    Parameters
-    ----------
-    run: extra_data.DataCollection
-        DataCollection containing the digitizer data
-    mnemonics: str or list of str
-        mnemonics for PES, e.g. "PES_W_raw" or ["PES_W_raw", "PES_ENE_raw"].
-        If None and no merge_with dataset is provided, defaults to "PES_W_raw".
-    merge_with: xarray Dataset
-        If provided, the resulting Dataset will be merged with this
-        one. The PES variables of merge_with (if any) will also be
-        computed and merged.
-    start: int
-        starting sample of the first spectrum in the raw trace.
-    width: int
-        number of samples per spectra.
-    origin: int
-        sample of the raw trace that corresponds to time-of-flight origin.
-        If None, origin is equal to start.
-    width_ns: float
-        time window for one spectrum. If None, the time window is defined by
-        width / sample rate.
-    subtract_baseline: bool
-        If True, subtract baseline defined by baseStart and baseWidth to each
-        spectrum.
-    baseStart: int
-        starting sample of the baseline.
-    baseWidth: int
-        number of samples to average (starting from baseStart) for baseline
-        calculation.
-    sample_rate: float
-        sample rate of the digitizer.
-
-    Returns
-    -------
-    pes: xarray Dataset
-        Dataset containing the PES time-of-flight spectra (e.g. "PES_W_tof"),
-        merged with optionally provided merg_with dataset.
-
-    Example
-    -------
-    >>> import toolbox_scs as tb
-    >>> import toolbox_scs.detectors as tbdet
-    >>> run, ds = tb.load(2927, 100, "PES_W_raw")
-    >>> pes = tbdet.get_pes_tof(run, merge_with=ds)
-    """
-    def to_processed_name(name):
-        return name.replace('raw', 'tof')
-    to_process = mnemonics_to_process(mnemonics, merge_with,
-                                      'PES', to_processed_name)
-    run_mnemonics = mnemonics_for_run(run)
-    # check if bunch pattern table exists
-    if bool(merge_with) and 'bunchPatternTable' in merge_with:
-        bpt = merge_with['bunchPatternTable']
-    elif 'bunchPatternTable' in run_mnemonics:
-        bpt = run.get_array(*run_mnemonics['bunchPatternTable'].values())
-    elif 'bunchPatternTable_SA3' in run_mnemonics:
-        bpt = run.get_array(*run_mnemonics['bunchPatternTable_SA3'].values())
-    else:
-        bpt = None
-
-    mask = is_sase_3(bpt).assign_coords({'pulse_slot': np.arange(2700)})
-    mask_on = mask.where(mask, drop=True)
-    npulses = mask.sum(dim='pulse_slot')[0].values
-    if npulses > 1:
-        period = mask_on['pulse_slot'].diff(dim='pulse_slot')[0].values
-    else:
-        period = 0
-    if origin is None:
-        origin = start
-    if baseStart is None:
-        baseStart = start
-    if width_ns is not None:
-        width = int(sample_rate * width_ns * 1e-9)
-    time_ns = 1e9 * (np.arange(start, start + width) - origin) / sample_rate
-
-    ds = xr.Dataset()
-    for m in to_process:
-        if bool(merge_with) and m in merge_with:
-            arr = merge_with[m]
-        else:
-            arr = run.get_array(*run_mnemonics[m].values(), name=m)
-        if arr.sizes['PESsampleId'] < npulses*period*440 + start + width:
-            log.warning('Not all pulses were recorded. The number of samples '
-                        f'on the digitizer {arr.sizes["PESsampleId"]} is not '
-                        f'enough to cover the {npulses} spectra. Missing '
-                        'spectra will be filled with NaNs.')
-        spectra = []
-        for p in range(npulses):
-            begin = p*period*440 + start
-            end = begin + width
-            if end > arr.sizes['PESsampleId']:
-                break
-            pes = arr.isel(PESsampleId=slice(begin, end))
-            if subtract_baseline:
-                baseBegin = p*period*440 + baseStart
-                baseEnd = baseBegin + baseWidth
-                bl = arr.isel(
-                 PESsampleId=slice(baseBegin, baseEnd)).mean(dim='PESsampleId')
-                pes = pes - bl
-            spectra.append(pes)
-        spectra = xr.concat(spectra,
-                            dim='sa3_pId').rename(m.replace('raw', 'tof'))
-        ds = ds.merge(spectra)
-    if len(ds.variables) > 0:
-        ds = ds.assign_coords(
-             {'sa3_pId': mask_on['pulse_slot'][:ds.sizes['sa3_pId']].values})
-        ds = ds.rename({'PESsampleId': 'time_ns'})
-        ds = ds.assign_coords({'time_ns': time_ns})
-    if bool(merge_with):
-        ds = merge_with.drop(to_process,
-                             errors='ignore').merge(ds, join='left')
-
-    return ds
-
-
 def get_pes_params(run, channel=None):
     """
     Extract PES parameters for a given extra_data DataCollection.
-    Parameters are gas, binding energy, voltages of the MPOD.
+    Parameters are gas, binding energy, retardation voltages or all
+    voltages of the MPOD.
 
     Parameters
     ----------
     run: extra_data.DataCollection
         DataCollection containing the digitizer data
     channel: str
-        Channel name, e.g. '2A'. If None, or if the channel is not
-        found in the data, the retardation voltage for all channels is
-        retrieved.
+        Channel name or PES mnemonic, e.g. '2A' or 'PES_1Craw'.
+        If None, or if the channel is not found in the data,
+        the retardation voltage for all channels is retrieved.
     Returns
     -------
     params: dict
@@ -289,6 +192,8 @@ def get_pes_params(run, channel=None):
             channel = [f'{a//4 + 1}{b}' for a, b in enumerate(
                 ['A', 'B', 'C', 'D']*4)]
         else:
+            if 'raw' in channel:
+                channel = channel.split('raw')[0].split('_')[1]
             channel = [channel]
         for c in channel:
             rv = get_pes_rv(run, c, mpod_mapper)
@@ -342,3 +247,108 @@ def get_pes_voltages(run, device='SA3_XTD10_PES/MDL/DAQ_MPOD'):
         if len(a.findall(k)) == 1:
             voltages[k.split('.')[0]] = da[device][k]
     return voltages
+
+
+def calculate_average(run, name, mnemo):  # top-level: picklable for Pool
+    keydata = run[mnemo['source'], mnemo['key']]
+    return keydata.xarray(name=name, extra_dims=mnemo['dim']).mean('trainId')
+
+
+def save_pes_avg_traces(proposal, runNB, channels=None,
+                        subdir='usr/processed_runs'):
+    """
+    Save the train-averaged raw traces of PES channels into an h5 file.
+
+    Parameters
+    ----------
+    proposal: int
+        The proposal number.
+    runNB: int
+        The run number.
+    channels: str or list
+        The PES channels or mnemonics, e.g. '2A', ['2A', '3C'],
+        ['PES_1Araw', 'PES_4Draw', '3B']. If None, channels 1A to 4D.
+    subdir: str
+        subdirectory. The data is stored in
+        <proposal path>/<subdir>/r{runNB:04d}/r{runNB:04d}-pes-data.h5
+
+    Output
+    ------
+    xarray Dataset saved in a h5 file containing the PES average traces.
+    """
+    root = find_proposal(f'p{proposal:06d}')
+    path = os.path.join(root, subdir + f'/r{runNB:04d}/')
+    Path(path).mkdir(parents=True, exist_ok=True)
+    fname = path + f'r{runNB:04d}-pes-data.h5'
+    if channels is None:
+        channels = [f'{a//4 + 1}{b}' for a, b in enumerate(
+            ['A', 'B', 'C', 'D']*4)]
+    if isinstance(channels, str):
+        channels = [channels]
+    run = ed.open_run(proposal, runNB, parallelize=False)
+    all_mnemos = mnemonics_for_run(run)
+    # one worker per channel; channels unknown to this run are skipped
+    args = []
+    for c in channels:
+        m = c
+        if 'raw' not in c:
+            m = f'PES_{c}raw'
+        if m not in all_mnemos:
+            continue
+        args.append([run, m.replace('raw', 'avg'), all_mnemos[m]])
+    if not args:
+        log.warning('No pes average trace to save. Skipping')
+        return
+    with Pool(len(args)) as pool:
+        avg_traces = pool.starmap(calculate_average, args)
+    avg_traces = xr.merge(avg_traces)
+    ds = xr.Dataset()
+    if os.path.isfile(fname):
+        # replace stale averages: drop them by their saved '...avg' names
+        ds = xr.load_dataset(fname)
+        ds = ds.drop_vars(list(avg_traces.data_vars), errors='ignore')
+    ds = ds.merge(avg_traces)
+    ds.to_netcdf(fname, format='NETCDF4')
+
+
+def load_pes_avg_traces(proposal, runNB, channels=None,
+                        subdir='usr/processed_runs'):
+    """
+    Load existing PES average traces.
+
+    Parameters
+    ----------
+    proposal: int
+        The proposal number.
+    runNB: int
+        The run number.
+    channels: str or list
+        The PES channels or mnemonics, e.g. '2A', ['2A', '3C'],
+        ['PES_1Araw', 'PES_4Draw', '3B']. If None, channels 1A to 4D.
+    subdir: str
+        subdirectory. The data is stored in
+        <proposal path>/<subdir>/r{runNB:04d}/r{runNB:04d}-pes-data.h5
+
+    Output
+    ------
+    ds: xarray Dataset
+        dataset containing the PES average traces. None if no file exists.
+    """
+    root = find_proposal(f'p{proposal:06d}')
+    path = os.path.join(root, subdir + f'/r{runNB:04d}/')
+    fname = path + f'r{runNB:04d}-pes-data.h5'
+    if channels is None:
+        channels = [f'PES_{a//4 + 1}{b}avg' for a, b in
+                    enumerate(['A', 'B', 'C', 'D']*4)]
+    if isinstance(channels, str):
+        channels = [channels]
+    # map '2A' or 'PES_2Araw' to the saved name 'PES_2Aavg' (on a copy)
+    channels = [c.replace('raw', 'avg') if 'PES_' in c else f'PES_{c}avg'
+                for c in channels]
+    if not os.path.isfile(fname):
+        log.warning(f'{fname} is not a valid file.')
+        return None
+    ds = xr.load_dataset(fname)
+    # channels that were never saved are silently skipped
+    channels = [c for c in channels if c in ds]
+    return ds[channels]
-- 
GitLab