diff --git a/src/toolbox_scs/detectors/digitizers.py b/src/toolbox_scs/detectors/digitizers.py
index 30e963ece4fa1bfe3c0fe33f9c22c0f9036250de..3168aa3d9ac3874d4754660f2aa545131b4ba0c3 100644
--- a/src/toolbox_scs/detectors/digitizers.py
+++ b/src/toolbox_scs/detectors/digitizers.py
@@ -318,6 +318,74 @@ def get_peaks(run,
     return peaks.assign_coords({extra_dim: pid})
 
 
+def get_peaks_from_list(run, dataList, digitizer='ADQ412',
+                        bunchPattern='sase3', integParams=None,
+                        keepAllSase=False):
+    """
+    Automatically loads and computes digitizer peaks from ToolBox
+    mnemonics and/or data arrays (raw traces or peak-integrated values).
+    Sources can be raw traces (e.g. "MCP3raw") or peak-
+    integrated data (e.g. "MCP2apd"). The bunch pattern table is
+    used to assign the pulse id coordinates.
+
+    Parameters
+    ----------
+    run: extra_data.DataCollection
+        DataCollection containing the digitizer data. If provided,
+        data must be a str or list of str. If None, data must be
+        an xarray Dataset containing actual data.
+    key: list of str and/or DataArray
+        A mnemonics for TIM, e.g. "MCP2apd" or "FastADC5raw".
+    bunchPattern: str
+        'sase1' or 'sase3' or 'scs_ppl', bunch pattern
+        used to extract peaks. The pulse ID dimension will be named
+        'sa1_pId', 'sa3_pId' or 'ol_pId', respectively.
+    integParams: dict
+        dictionnary for raw trace integration, e.g.
+        {'intstart':100, 'intstop':200, 'bkgstart':50,
+        'bkgstop':99, 'period':440, 'npulses':500}.
+        If None, integration parameters are computed automatically.
+    keepAllSase: bool
+        Only relevant in case of sase-dedicated trains. If True, all
+        trains are kept, else only those of the bunchPattern are kept.
+
+    Returns
+    -------
+    xarray Dataset with all TIM variables substituted by
+    the peak caclulated values (e.g. "MCP2raw" becomes
+    "MCP2peaks").
+    """
+    autoFind = True
+    if integParams is not None:
+        autoFind = False
+    keys = []
+    vals = []
+    for d in dataList:
+        if isinstance(d, str):
+            key = d
+            log.debug(f'Loading peaks from {key}...')
+        if isinstance(d, xr.DataArray):
+            key = d.name
+            log.debug(f'Retrieving peaks from {key} array...')
+        useRaw = True if 'raw' in key else False
+        m = _mnemonics[key]
+        peaks = get_peaks(run, d,
+                          source=m['source'],
+                          key=m['key'],
+                          digitizer=digitizer,
+                          useRaw=useRaw,
+                          autoFind=autoFind,
+                          integParams=integParams,
+                          bunchPattern=bunchPattern)
+        key = key.replace('raw', 'peaks').replace('apd', 'peaks')
+        keys.append(key)
+        vals.append(peaks)
+    join = 'outer' if keepAllSase else 'inner'
+    aligned_vals = xr.align(*vals, join=join)
+    ds = xr.Dataset(dict(zip(keys, aligned_vals)))
+    return ds
+
+
 def channel_peak_params(run, source, key=None, digitizer=None,
                         channel=None, board=None):
     """
@@ -458,26 +526,6 @@ def find_integ_params(trace, min_distance=1, height=None, width=1):
     return result
 
 
-def plotPeakIntegrationWindow(raw_trace, params):
-    xmin = np.max([0, params['baseStart']-100])
-    xmax = np.min([params['pulseStop']+100, raw_trace.size])
-    fig, ax = plt.subplots(figsize=(5, 3))
-    ax.axvline(params['baseStart'], ls='--', color='k')
-    ax.axvline(params['baseStop'], ls='--', color='k')
-    ax.axvspan(params['baseStart'], params['baseStop'],
-               alpha=0.5, color='grey', label='baseline')
-    ax.axvline(params['pulseStart'], ls='--', color='r')
-    ax.axvline(params['pulseStop'], ls='--', color='r')
-    ax.axvspan(params['pulseStart'], params['pulseStop'],
-               alpha=0.2, color='r', label='peak')
-    ax.plot(np.arange(xmin, xmax), raw_trace[xmin:xmax],
-            label='raw trace')
-    ax.set_xlabel('digitizer samples')
-    ax.legend(fontsize=8)
-    fig.tight_layout()
-    return ax
-
-
 def get_peak_params(run, key, raw_trace=None, plot=False):
     """
     Get the peak region and baseline region of a raw
@@ -530,72 +578,24 @@ def get_peak_params(run, key, raw_trace=None, plot=False):
     return params
 
 
-def get_peaks_from_list(run, dataList, digitizer='ADQ412',
-                        bunchPattern='sase3', integParams=None,
-                        keepAllSase=False):
-    """
-    Automatically loads and computes digitizer peaks from ToolBox
-    mnemonics and/or data arrays (raw traces or peak-integrated values).
-    Sources can be raw traces (e.g. "MCP3raw") or peak-
-    integrated data (e.g. "MCP2apd"). The bunch pattern table is
-    used to assign the pulse id coordinates.
-
-    Parameters
-    ----------
-    run: extra_data.DataCollection
-        DataCollection containing the digitizer data. If provided,
-        data must be a str or list of str. If None, data must be
-        an xarray Dataset containing actual data.
-    key: list of str and/or DataArray
-        A mnemonics for TIM, e.g. "MCP2apd" or "FastADC5raw".
-    bunchPattern: str
-        'sase1' or 'sase3' or 'scs_ppl', bunch pattern
-        used to extract peaks. The pulse ID dimension will be named
-        'sa1_pId', 'sa3_pId' or 'ol_pId', respectively.
-    integParams: dict
-        dictionnary for raw trace integration, e.g.
-        {'intstart':100, 'intstop':200, 'bkgstart':50,
-        'bkgstop':99, 'period':440, 'npulses':500}.
-        If None, integration parameters are computed automatically.
-    keepAllSase: bool
-        Only relevant in case of sase-dedicated trains. If True, all
-        trains are kept, else only those of the bunchPattern are kept.
-
-    Returns
-    -------
-    xarray Dataset with all TIM variables substituted by
-    the peak caclulated values (e.g. "MCP2raw" becomes
-    "MCP2peaks").
-    """
-    autoFind = True
-    if integParams is not None:
-        autoFind = False
-    keys = []
-    vals = []
-    for d in dataList:
-        if isinstance(d, str):
-            key = d
-            log.debug(f'Loading peaks from {key}...')
-        if isinstance(d, xr.DataArray):
-            key = d.name
-            log.debug(f'Retrieving peaks from {key} array...')
-        useRaw = True if 'raw' in key else False
-        m = _mnemonics[key]
-        peaks = get_peaks(run, d,
-                          source=m['source'],
-                          key=m['key'],
-                          digitizer=digitizer,
-                          useRaw=useRaw,
-                          autoFind=autoFind,
-                          integParams=integParams,
-                          bunchPattern=bunchPattern)
-        key = key.replace('raw', 'peaks').replace('apd', 'peaks')
-        keys.append(key)
-        vals.append(peaks)
-    join = 'outer' if keepAllSase else 'inner'
-    aligned_vals = xr.align(*vals, join=join)
-    ds = xr.Dataset(dict(zip(keys, aligned_vals)))
-    return ds
+def plotPeakIntegrationWindow(raw_trace, params):
+    xmin = np.max([0, params['baseStart']-100])
+    xmax = np.min([params['pulseStop']+100, raw_trace.size])
+    fig, ax = plt.subplots(figsize=(5, 3))
+    ax.axvline(params['baseStart'], ls='--', color='k')
+    ax.axvline(params['baseStop'], ls='--', color='k')
+    ax.axvspan(params['baseStart'], params['baseStop'],
+               alpha=0.5, color='grey', label='baseline')
+    ax.axvline(params['pulseStart'], ls='--', color='r')
+    ax.axvline(params['pulseStop'], ls='--', color='r')
+    ax.axvspan(params['pulseStart'], params['pulseStop'],
+               alpha=0.2, color='r', label='peak')
+    ax.plot(np.arange(xmin, xmax), raw_trace[xmin:xmax],
+            label='raw trace')
+    ax.set_xlabel('digitizer samples')
+    ax.legend(fontsize=8)
+    fig.tight_layout()
+    return ax
 
 
 def get_tim_peaks(run, key=None, merge_with=None,
diff --git a/src/toolbox_scs/detectors/xgm.py b/src/toolbox_scs/detectors/xgm.py
index 309f4a9f7020a7e4440dbf0df22cef6664570187..cde1950450494f29cd4bd4dd0e7d1acd2a10ec8c 100644
--- a/src/toolbox_scs/detectors/xgm.py
+++ b/src/toolbox_scs/detectors/xgm.py
@@ -1,8 +1,8 @@
 """ XGM related sub-routines
 
     Copyright (2019) SCS Team.
-    
-    (contributions preferrably comply with pep8 code structure 
+
+    (contributions preferrably comply with pep8 code structure
     guidelines.)
 """
 
@@ -15,7 +15,7 @@ from scipy.signal import find_peaks
 import extra_data as ed
 
 from ..constants import mnemonics as _mnemonics
-from ..misc.bunch_pattern_external import is_sase_1, is_sase_3, is_ppl
+from ..misc.bunch_pattern_external import is_sase_1, is_sase_3
 
 
 log = logging.getLogger(__name__)
@@ -56,7 +56,7 @@ def load_xgm(run, xgm_mnemonic='SCS_SA3'):
 
 
 def get_xgm(run, key=None, merge_with=None, keepAllSase=False,
-            indices=slice(0,None)):
+            indices=slice(0, None)):
     """
     Load and/or computes XGM data. Sources can be loaded on the
     fly via the key argument, or processed from an existing data set
@@ -84,14 +84,14 @@ def get_xgm(run, key=None, merge_with=None, keepAllSase=False,
     -------
     xarray Dataset with pulse-resolved XGM variables aligned,
      merged with Dataset *merge_with* if provided.
-    
+
     Example
     -------
     >>> import extra_data as ed
     >>> import toolbox_scs.detectors as tbdet
     >>> run = ed.open_run(2212, 213)
     >>> xgm = tbdet.get_xgm(run)
-    
+
     """
     # check if bunch pattern table exists
     if bool(merge_with) and 'bunchPatternTable' in merge_with:
@@ -102,42 +102,121 @@ def get_xgm(run, key=None, merge_with=None, keepAllSase=False,
         log.debug('Loaded bpt from extra_data run.')
     else:
         bpt = None
+
     if key is None and merge_with is None:
-        key = 'SCS_SA3'
+        keys = ['SCS_SA3']
     if key is None:
         keys = []
     else:
         keys = key if isinstance(key, list) else [key]
+
+    # create a list of arrays to sort from the provided keys and merge_with
+    dataList = []
+    for k in keys:
+        xgm = run.get_array(*_mnemonics[k].values())
+        dataList.append(xgm.rename(k))
     if merge_with is None:
-        vals = []
-        for key in keys:
-            xgm = run.get_array(*_mnemonics[key].values())
-            vals.append(xgm)
-        aligned_vals = xr.align(*vals, join='inner')
-        ds = xr.Dataset(dict(zip(keys, aligned_vals)))
-        return cleanXGMdata(ds, bpt=bpt, keepAllSase=keepAllSase, 
-                            indices=indices)
-    mw_xgm_keys = [k for k in merge_with 
-                   if 'XGMbunchId' in merge_with[k].dims]
-    if len(mw_xgm_keys) == 0 and len(keys) == 0:
-        keys = ['SCS_SA3']
-    keys = [key for key in keys if key not in merge_with]
-    if len(keys) == 0:
-        return cleanXGMdata(merge_with, bpt=bpt,
-                            keepAllSase=keepAllSase, indices=indices)
-    vals = []
-    for key in keys:
-        xgm = run.get_array(*_mnemonics[key].values())
-        vals.append(xgm)
-    aligned_vals = xr.align(*vals, join='inner')
-    ds = xr.Dataset(dict(zip(keys, aligned_vals)))
-    ds = ds.merge(merge_with, join='inner')
-    return cleanXGMdata(ds, bpt=bpt, keepAllSase=keepAllSase, 
-                        indices=indices)
-
-
-def cleanXGMdata(data, bpt=None, keepAllSase=False, 
-                 indices=slice(0,None)):
+        mw_ds = xr.Dataset()
+    else:
+        mw_xgm_keys = [k for k in merge_with
+                       if 'XGMbunchId' in merge_with[k].dims]
+        dataList += [merge_with[k] for k in mw_xgm_keys]
+        mw_ds = merge_with.drop(mw_xgm_keys)
+
+    # assign pulse ID to each array in dataList and merge
+    ds = mw_ds
+    for d in dataList:
+        if bpt is not None:
+            arr = align_xgm_array(d, bpt)
+        else:
+            arr = d.where(d != 1., drop=True).sel(XGMbunchId=indices)
+        ds = ds.merge(arr, join='inner')
+    return ds
+
+
+def align_xgm_array(xgm_arr, bpt):
+    """
+    Assigns pulse ID coordinates to a pulse-resolved XGM array, according to
+    the bunch pattern table. If the arrays contains both SASE 1 and SASE 3
+    data, it is split in two arrays.
+
+    Parameters
+    ----------
+    xgm_arr: xarray DataArray
+        array containing pulse-resolved XGM data, with dims ['trainId',
+        'XGMbunchId']
+    bpt: xarray DataArray
+        bunch pattern table
+
+    Returns
+    -------
+    xarray Dataset with pulse ID coordinates. For SASE 1 data, the coordinates
+    name is sa1_pId, for SASE 3 data, the coordinates name is sa3_pId.
+    """
+    key = xgm_arr.name
+    compute_sa1 = False
+    compute_sa3 = False
+    # get the relevant masks for SASE 1 and/or SASE3
+    if "SA1" in key or "SA3" in key:
+        if "SA1" in key:
+            mask = is_sase_1(bpt.sel(trainId=xgm_arr.trainId))
+            compute_sa1 = True
+        else:
+            mask = is_sase_3(bpt.sel(trainId=xgm_arr.trainId))
+            compute_sa3 = True
+        tid = mask.where(mask.sum(dim='pulse_slot') > 0, drop=True).trainId
+        mask = mask.sel(trainId=tid)
+        mask_sa1 = mask.rename({'pulse_slot': 'sa1_pId'})
+        mask_sa3 = mask.rename({'pulse_slot': 'sa3_pId'})
+    if "XGM" in key:
+        compute_sa1 = True
+        compute_sa3 = True
+        mask_sa1 = is_sase_1(bpt.sel(trainId=xgm_arr.trainId))
+        mask_sa3 = is_sase_3(bpt.sel(trainId=xgm_arr.trainId))
+        mask = xr.ufuncs.logical_or(mask_sa1, mask_sa3)
+        tid = mask.where(mask.sum(dim='pulse_slot') > 0,
+                         drop=True).trainId
+        mask_sa1 = mask_sa1.sel(trainId=tid).rename({'pulse_slot': 'sa1_pId'})
+        mask_sa3 = mask_sa3.sel(trainId=tid).rename({'pulse_slot': 'sa3_pId'})
+        mask = mask.sel(trainId=tid)
+
+    npulses_max = mask.sum(dim='pulse_slot').max().values
+    xgm_arr = xgm_arr.sel(trainId=tid).isel(
+                XGMbunchId=slice(0, npulses_max))
+    # pad the xgm array to match the bpt dims, flatten and
+    # reorder xgm array to match the indices of the mask
+    xgm_flat = np.hstack((xgm_arr.fillna(1.),
+                          np.ones((xgm_arr.sizes['trainId'],
+                                   2700-npulses_max)))).flatten()
+    xgm_flat_arg = np.argwhere(xgm_flat != 1.0)
+    mask_flat = mask.values.flatten()
+    mask_flat_arg = np.argwhere(mask_flat)
+    if(xgm_flat_arg.shape != mask_flat_arg.shape):
+        log.warning(f'{key}: XGM data and bunch pattern do not match.')
+    new_xgm_flat = np.ones(xgm_flat.shape)
+    new_xgm_flat[mask_flat_arg] = xgm_flat[xgm_flat_arg]
+    new_xgm = new_xgm_flat.reshape((xgm_arr.sizes['trainId'], 2700))
+
+    # create a dataset with new_xgm array masked by SASE 1 or SASE 3
+    xgm_dict = {}
+    if compute_sa1:
+        sa1_xgm = xr.DataArray(new_xgm, dims=['trainId', 'sa1_pId'],
+                               coords={'trainId': xgm_arr.trainId,
+                                       'sa1_pId': np.arange(2700)},
+                               name=key.replace('XGM', 'SA1'))
+        xgm_dict[sa1_xgm.name] = sa1_xgm.where(mask_sa1, drop=True)
+    if compute_sa3:
+        sa3_xgm = xr.DataArray(new_xgm, dims=['trainId', 'sa3_pId'],
+                               coords={'trainId': xgm_arr.trainId,
+                                       'sa3_pId': np.arange(2700)},
+                               name=key.replace('XGM', 'SA3'))
+        xgm_dict[sa3_xgm.name] = sa3_xgm.where(mask_sa3, drop=True)
+    ds = xr.Dataset(xgm_dict)
+    return ds
+
+
+def cleanXGMdata(data, bpt=None, keepAllSase=False,
+                 indices=slice(0, None)):
     """ 
     Cleans the XGM data arrays obtained from load() function.
     The XGM "TD" data arrays have arbitrary size of 1000 and default value 1.0