diff --git a/src/toolbox_scs/detectors/digitizers.py b/src/toolbox_scs/detectors/digitizers.py
index 5d917bb65ff7a5aee52ce0e67f9eadb6a76fc462..30e963ece4fa1bfe3c0fe33f9c22c0f9036250de 100644
--- a/src/toolbox_scs/detectors/digitizers.py
+++ b/src/toolbox_scs/detectors/digitizers.py
@@ -164,7 +164,7 @@ def get_peaks(run,
               indices=None,
               ):
     """
-    Extract peaks from digitizer data.
+    Extract peaks from one source (channel) of a digitizer.
 
     Parameters
     ----------
@@ -530,12 +530,13 @@ def get_peak_params(run, key, raw_trace=None, plot=False):
     return params
 
 
-def digitizer_peaks_from_keys(run, key, digitizer='ADQ412',
-                              bunchPattern='sase3', integParams=None,
-                              keepAllSase=False):
+def get_peaks_from_list(run, dataList, digitizer='ADQ412',
+                        bunchPattern='sase3', integParams=None,
+                        keepAllSase=False):
     """
     Automatically loads and computes digitizer peaks from ToolBox
-    mnemonics. Sources can be raw traces (e.g. "MCP3raw") or peak-
+    mnemonics and/or data arrays (raw traces or peak-integrated values).
+    Sources can be raw traces (e.g. "MCP3raw") or peak-
     integrated data (e.g. "MCP2apd"). The bunch pattern table is
     used to assign the pulse id coordinates.
 
@@ -545,7 +546,7 @@ def digitizer_peaks_from_keys(run, key, digitizer='ADQ412',
         DataCollection containing the digitizer data. If provided,
         data must be a str or list of str. If None, data must be
         an xarray Dataset containing actual data.
-    key: str or list of str
+    dataList: list of str and/or DataArray
         A mnemonics for TIM, e.g. "MCP2apd" or "FastADC5raw".
     bunchPattern: str
         'sase1' or 'sase3' or 'scs_ppl', bunch pattern
@@ -566,16 +567,21 @@ def digitizer_peaks_from_keys(run, key, digitizer='ADQ412',
     the peak caclulated values (e.g. "MCP2raw" becomes
     "MCP2peaks").
     """
-    keys = key if isinstance(key, list) else [key]
     autoFind = True
     if integParams is not None:
         autoFind = False
-    peakDict = {}
-    for key in keys:
+    keys = []
+    vals = []
+    for d in dataList:
+        if isinstance(d, str):
+            key = d
+            log.debug(f'Loading peaks from {key}...')
+        if isinstance(d, xr.DataArray):
+            key = d.name
+            log.debug(f'Retrieving peaks from {key} array...')
         useRaw = True if 'raw' in key else False
-        log.debug(f'Retrieving peaks from {key}...')
         m = _mnemonics[key]
-        peaks = get_peaks(run, key,
+        peaks = get_peaks(run, d,
                           source=m['source'],
                           key=m['key'],
                           digitizer=digitizer,
@@ -583,11 +589,12 @@ def digitizer_peaks_from_keys(run, key, digitizer='ADQ412',
                           autoFind=autoFind,
                           integParams=integParams,
                           bunchPattern=bunchPattern)
-        newKey = key.replace('raw', 'peaks').replace('apd', 'peaks')
-        peakDict[newKey] = peaks
-    ds = xr.Dataset()
+        key = key.replace('raw', 'peaks').replace('apd', 'peaks')
+        keys.append(key)
+        vals.append(peaks)
     join = 'outer' if keepAllSase else 'inner'
-    ds = ds.merge(peakDict, join=join)
+    aligned_vals = xr.align(*vals, join=join)
+    ds = xr.Dataset(dict(zip(keys, aligned_vals)))
     return ds
 
 
@@ -734,13 +741,15 @@ def get_digitizer_peaks(run, key, digitizer, merge_with,
         dig_keys += [f'MCP{c}peaks' for c in range(1, 5)]
     if merge_with is None:
         if key is None:
-            key = default_key
+            keys = [default_key]
+        else:
+            keys = key if isinstance(key, list) else [key]
         log.info(f'Loading peaks from {key} data.')
-        return digitizer_peaks_from_keys(run, key,
-                                         digitizer=digitizer,
-                                         bunchPattern=bunchPattern,
-                                         integParams=integParams,
-                                         keepAllSase=keepAllSase)
+        return get_peaks_from_list(run, keys,
+                                   digitizer=digitizer,
+                                   bunchPattern=bunchPattern,
+                                   integParams=integParams,
+                                   keepAllSase=keepAllSase)
     mw_keys = [k for k in merge_with if k in dig_keys]
     mw_to_process = [k for k in mw_keys if len(merge_with[k].coords) != 2]
     mw_channels = [re.split(r'(\d+)', k)[1] for k in mw_keys]
@@ -748,7 +757,6 @@ def get_digitizer_peaks(run, key, digitizer, merge_with,
         key = default_key
     if key is None:
         keys = []
-        ds_dig = xr.Dataset()
     else:
         keys = key if isinstance(key, list) else [key]
         duplicate = [k for k in keys if
@@ -757,31 +765,15 @@ def get_digitizer_peaks(run, key, digitizer, merge_with,
             log.info(f'{duplicate} keys already in merge_with dataset. ' +
                      'Skipping.')
         keys = [k for k in keys if k not in duplicate]
-        ds_dig = digitizer_peaks_from_keys(run, keys,
-                                           digitizer=digitizer,
-                                           bunchPattern=bunchPattern,
-                                           integParams=integParams,
-                                           keepAllSase=keepAllSase)
     log.info(f'Loading peaks from {keys+mw_to_process}.')
-    autoFind = True
-    if integParams is not None:
-        autoFind = False
-    peakDict = {}
-    toRemove = []
-    for key in mw_to_process:
-        useRaw = True if 'raw' in key else False
-        log.debug(f'Retrieving peaks from {key}...')
-        m = _mnemonics[key]
-        peaks = get_peaks(run, merge_with[key], source=m['source'],
-                          key=m['key'], digitizer=digitizer,
-                          useRaw=useRaw, autoFind=autoFind,
-                          integParams=integParams,
-                          bunchPattern=bunchPattern)
-        newKey = key.replace('raw', 'peaks').replace('apd', 'peaks')
-        peakDict[newKey] = peaks
-        toRemove.append(key)
-    ds = merge_with.drop(toRemove)
+    dataList = keys
+    dataList += [merge_with[k] for k in mw_to_process]
+    ds_dig = get_peaks_from_list(run, dataList,
+                                 digitizer=digitizer,
+                                 bunchPattern=bunchPattern,
+                                 integParams=integParams,
+                                 keepAllSase=keepAllSase)
+    ds = merge_with.drop(mw_to_process)
     join = 'outer' if keepAllSase else 'inner'
-    ds = ds.merge(peakDict, join=join)
-    ds = ds_dig.merge(ds, join=join)
+    ds = ds.merge(ds_dig, join=join)
     return ds
diff --git a/src/toolbox_scs/detectors/xgm.py b/src/toolbox_scs/detectors/xgm.py
index 32f94f44f736db2d830469b9c174d3fb4ce71c87..309f4a9f7020a7e4440dbf0df22cef6664570187 100644
--- a/src/toolbox_scs/detectors/xgm.py
+++ b/src/toolbox_scs/detectors/xgm.py
@@ -94,8 +94,12 @@ def get_xgm(run, key=None, merge_with=None, keepAllSase=False,
     
     """
     # check if bunch pattern table exists
-    if _mnemonics['bunchPatternTable']['source'] in run.all_sources:
+    if bool(merge_with) and 'bunchPatternTable' in merge_with:
+        bpt = merge_with['bunchPatternTable']
+        log.debug('Using bpt from merge_with dataset.')
+    elif _mnemonics['bunchPatternTable']['source'] in run.all_sources:
         bpt = run.get_array(*_mnemonics['bunchPatternTable'].values())
+        log.debug('Loaded bpt from extra_data run.')
     else:
         bpt = None
     if key is None and merge_with is None:
@@ -307,12 +311,11 @@ def selectSASEinXGM(data, sase='sase3', xgm='SCS_XGM',
         return xgmData#.assign_coords(XGMbunchId=np.arange(npulses))
     
     #2. case where bunch pattern is provided
-    mask_sa1 = is_sase_1(bpt)
-    mask_sa3 = is_sase_3(bpt)
+    mask_sa1 = is_sase_1(bpt).sel(trainId=data.trainId)
+    mask_sa3 = is_sase_3(bpt).sel(trainId=data.trainId)
     mask_all = xr.ufuncs.logical_or(mask_sa1, mask_sa3)
     tid = mask_all.where(mask_all.sum(dim='pulse_slot')>0,
                                drop=True).trainId
-    tid = np.intersect1d(tid, data.trainId)
     mask_sa1 = mask_sa1.sel(trainId=tid)
     mask_sa3 = mask_sa3.sel(trainId=tid)
     mask_all = mask_all.sel(trainId=tid)