diff --git a/src/toolbox_scs/detectors/gotthard2.py b/src/toolbox_scs/detectors/gotthard2.py index b420aac960f6a7070483e0aab410dc0e5c9be8a6..bf2988cafbcaf050b7e17f98cbf8fe99d16a5545 100644 --- a/src/toolbox_scs/detectors/gotthard2.py +++ b/src/toolbox_scs/detectors/gotthard2.py @@ -15,9 +15,13 @@ __all__ = [ log = logging.getLogger(__name__) -def extract_GH2(ds, run, firstFrame=0, bunchPattern='scs_ppl'): + +def extract_GH2(ds, run, firstFrame=0, bunchPattern='scs_ppl', + gh2_dim='gh2_pId'): ''' - Extract the frames of the Gotthard-II that have been exposed to light. + Select and align the frames of the Gotthard-II that have been exposed + to light. + Parameters ------ ds: xarray.Dataset @@ -30,33 +34,42 @@ def extract_GH2(ds, run, firstFrame=0, bunchPattern='scs_ppl'): the bunch pattern used to align data. For 'scs_ppl', the gh2_pId dimension in renamed 'ol_pId', and for 'sase3' gh2_pId is renamed 'sa3_pId'. + gh2_dim: str + The name of the dimension that corresponds to the Gotthard-II frames. Returns ------- nds: xarray Dataset - The aligned and reduced dataset with only-data-containing GH2 variables. + The aligned and reduced dataset with only-data-containing GH2 + variables. ''' + if gh2_dim not in ds.dims: + log.warning(f'gh2_dim "{gh2_dim}" not in dataset. Skipping.') + return ds if bunchPattern == 'scs_ppl': pattern = OpticalLaserPulses(run) - dim='ol_pId' + dim = 'ol_pId' else: pattern = XrayPulses(run) - dim='sa3_pId' + dim = 'sa3_pId' + others = [var for var in ds if dim in ds[var].coords] + nds = ds.drop_dims(dim) if pattern.is_constant_pattern(): pulse_ids = pattern.peek_pulse_ids(labelled=False) - nds = ds.isel(gh2_pId=pulse_ids + firstFrame) - nds = nds.assign_coords(gh2_pId=pulse_ids) - nds = nds.rename(gh2_pId=dim) + nds = nds.isel({gh2_dim: pulse_ids + firstFrame}) + nds = nds.assign_coords({gh2_dim: pulse_ids}) + nds = nds.rename({gh2_dim: dim}) else: log.warning('The number of pulses has changed during the run.') pulse_ids = np.unique(pattern.pulse_ids(labelled=False, copy=False)) - nds = ds.isel(gh2_pId=pulse_ids + firstFrame) - nds = nds.assign_coords(gh2_pId=pulse_ids) - nds = nds.rename(gh2_pId=dim) + nds = nds.isel({gh2_dim: pulse_ids + firstFrame}) + nds = nds.assign_coords({gh2_dim: pulse_ids}) + nds = nds.rename({gh2_dim: dim}) mask = pattern.pulse_mask(labelled=False) mask = xr.DataArray(mask, dims=['trainId', dim], coords={'trainId': run.train_ids, dim: np.arange(mask.shape[1])}) mask = mask.sel({dim: pulse_ids}) nds = nds.where(mask, drop=True) - return nds + ret = ds[others].merge(nds, join='inner') + return ret