From 838e7402d1f165bec8373f885b8873a9c1124bf4 Mon Sep 17 00:00:00 2001
From: Laurent Mercadier <laurent.mercadier@xfel.eu>
Date: Tue, 27 Aug 2024 07:27:36 +0200
Subject: [PATCH] Update get_bam() to process average data

---
 src/toolbox_scs/detectors/bam_detectors.py | 58 ++++++++++++----------
 1 file changed, 33 insertions(+), 25 deletions(-)

diff --git a/src/toolbox_scs/detectors/bam_detectors.py b/src/toolbox_scs/detectors/bam_detectors.py
index 39f2984..70226f4 100644
--- a/src/toolbox_scs/detectors/bam_detectors.py
+++ b/src/toolbox_scs/detectors/bam_detectors.py
@@ -39,12 +39,12 @@ def get_bam(run, mnemonics=None, merge_with=None, bunchPattern='sase3',
         DataCollection containing the bam data.
     mnemonics: str or list of str
         mnemonics for BAM, e.g. "BAM1932M" or ["BAM414", "BAM1932M"].
-        If None, defaults to "BAM1932M" in case no merge_with dataset
-        is provided.
+        the arrays are either taken from merge_with or loaded from
+        the DataCollection run.
     merge_with: xarray Dataset
         If provided, the resulting Dataset will be merged with this
-        one. The BAM variables of merge_with (if any) will also be
-        selected, aligned and merged.
+        one. If merge_with contains variables in mnemonics, they will
+        be selected, aligned and merged.
     bunchPattern: str
         'sase1' or 'sase3' or 'scs_ppl', bunch pattern
         used to extract peaks. The pulse ID dimension will be named
@@ -68,20 +68,20 @@ def get_bam(run, mnemonics=None, merge_with=None, bunchPattern='sase3',
         mnemonics = [mnemonics] if isinstance(mnemonics, str) else mnemonics
         for m in mnemonics:
             if any([(k in m) for k in bam_mnemos]):
-                if merge_with is not None and m in merge_with:
-                    continue
+                #if merge_with is not None and m in merge_with:
+                #    continue
                 m2.append(m)
-    if merge_with is not None:
-        in_mw = []
-        for m, da in merge_with.items():
-            if any([(k in m) for k in bam_mnemos]) and 'BAMbunchId' in da.dims:
-                in_mw.append(m)
-        m2 += in_mw
-
-    if len(m2) == 0:
+    #if merge_with is not None:
+    #    in_mw = []
+    #    for m, da in merge_with.items():
+    #        if any([(k in m) for k in bam_mnemos]) and 'BAMbunchId' in da.dims:
+    #            in_mw.append(m)
+    #    m2 += in_mw
+
+    mnemonics = list(set(m2))
+    if len(mnemonics) == 0:
         log.info('no BAM mnemonics to process. Skipping.')
         return merge_with
-    mnemonics = list(set(m2))
     # Prepare the dataset of non-BAM data to merge with
     if bool(merge_with):
         ds_mw = merge_with.drop(mnemonics, errors='ignore')
@@ -101,16 +101,24 @@ def get_bam(run, mnemonics=None, merge_with=None, bunchPattern='sase3',
             da_bam = merge_with[m]
         else:
             da_bam = get_array(run, m)
-        da_bam = da_bam.sel(BAMbunchId=slice(0, None, 2))
-        # align the pulse Id
-        if bpt is not None:
-            n = mask.sizes[dim_names[bunchPattern]]
-            da_bam = da_bam.isel(BAMbunchId=slice(0, n))
-            da_bam = da_bam.assign_coords(BAMbunchId=np.arange(0, n))
-            da_bam = da_bam.rename(BAMbunchId=dim_names[bunchPattern])
-            da_bam = da_bam.where(mask, drop=True)
-        if run_mnemonics[m]['key'] != 'data.lowChargeArrivalTime':
-            da_bam *= 1e-3
+        if len(da_bam.dims) == 2:
+            da_bam = da_bam.sel(BAMbunchId=slice(0, None, 2))
+            # align the pulse Id
+            if bpt is not None:
+                n = mask.sizes[dim_names[bunchPattern]]
+                da_bam = da_bam.isel(BAMbunchId=slice(0, n))
+                da_bam = da_bam.assign_coords(BAMbunchId=np.arange(0, n))
+                da_bam = da_bam.rename(BAMbunchId=dim_names[bunchPattern])
+                da_bam = da_bam.where(mask, drop=True)
+            # make sure unit is picosecond
+            if run_mnemonics[m]['key'] != 'data.lowChargeArrivalTime':
+                da_bam *= 1e-3
+        else:
+            # The 1D values (mean, std dev...) are in fs, need to convert to ps
+            mnemo = run_mnemonics[m]
+            first_val = run[mnemo['source']][mnemo['key']].train_from_index(0)[1]
+            if first_val == da_bam[0]:
+                da_bam *= 1e-3
         ds = ds.merge(da_bam, join='inner')
     # merge with non-BAM dataset
     ds = ds_mw.merge(ds, join='inner')
-- 
GitLab