From dc01485b6f9d8bbe3f2458ab1eac8fcbee29c3ed Mon Sep 17 00:00:00 2001 From: Laurent Mercadier <laurent.mercadier@xfel.eu> Date: Thu, 14 Nov 2024 20:51:26 +0100 Subject: [PATCH] Improve extract_GH2() --- src/toolbox_scs/detectors/gotthard2.py | 37 +++++++++++++++++--------- 1 file changed, 25 insertions(+), 12 deletions(-) diff --git a/src/toolbox_scs/detectors/gotthard2.py b/src/toolbox_scs/detectors/gotthard2.py index b420aac..bf2988c 100644 --- a/src/toolbox_scs/detectors/gotthard2.py +++ b/src/toolbox_scs/detectors/gotthard2.py @@ -15,9 +15,13 @@ __all__ = [ log = logging.getLogger(__name__) -def extract_GH2(ds, run, firstFrame=0, bunchPattern='scs_ppl'): + +def extract_GH2(ds, run, firstFrame=0, bunchPattern='scs_ppl', + gh2_dim='gh2_pId'): ''' - Extract the frames of the Gotthard-II that have been exposed to light. + Select and align the frames of the Gotthard-II that have been exposed + to light. + Parameters ------ ds: xarray.Dataset @@ -30,33 +34,42 @@ def extract_GH2(ds, run, firstFrame=0, bunchPattern='scs_ppl'): the bunch pattern used to align data. For 'scs_ppl', the gh2_pId dimension in renamed 'ol_pId', and for 'sase3' gh2_pId is renamed 'sa3_pId'. + gh2_dim: str + The name of the dimension that corresponds to the Gotthard-II frames. Returns ------- nds: xarray Dataset - The aligned and reduced dataset with only-data-containing GH2 variables. + The aligned and reduced dataset with only-data-containing GH2 + variables. ''' + if gh2_dim not in ds.dims: + log.warning(f'gh2_dim "{gh2_dim}" not in dataset. Skipping.') + return ds if bunchPattern == 'scs_ppl': pattern = OpticalLaserPulses(run) - dim='ol_pId' + dim = 'ol_pId' else: pattern = XrayPulses(run) - dim='sa3_pId' + dim = 'sa3_pId' + others = [var for var in ds if dim in ds[var].coords] + nds = ds.drop_dims(dim) if pattern.is_constant_pattern(): pulse_ids = pattern.peek_pulse_ids(labelled=False) - nds = ds.isel(gh2_pId=pulse_ids + firstFrame) - nds = nds.assign_coords(gh2_pId=pulse_ids) - nds = nds.rename(gh2_pId=dim) + nds = nds.isel({gh2_dim: pulse_ids + firstFrame}) + nds = nds.assign_coords({gh2_dim: pulse_ids}) + nds = nds.rename({gh2_dim: dim}) else: log.warning('The number of pulses has changed during the run.') pulse_ids = np.unique(pattern.pulse_ids(labelled=False, copy=False)) - nds = ds.isel(gh2_pId=pulse_ids + firstFrame) - nds = nds.assign_coords(gh2_pId=pulse_ids) - nds = nds.rename(gh2_pId=dim) + nds = nds.isel({gh2_dim: pulse_ids + firstFrame}) + nds = nds.assign_coords({gh2_dim: pulse_ids}) + nds = nds.rename({gh2_dim: dim}) mask = pattern.pulse_mask(labelled=False) mask = xr.DataArray(mask, dims=['trainId', dim], coords={'trainId': run.train_ids, dim: np.arange(mask.shape[1])}) mask = mask.sel({dim: pulse_ids}) nds = nds.where(mask, drop=True) - return nds + ret = ds[others].merge(nds, join='inner') + return ret -- GitLab