""" Gotthard-II detector related sub-routines

    Copyright (2024) SCS Team.

    (Contributions should preferably comply with the PEP 8 code style
    guidelines.)
"""
from extra.components import OpticalLaserPulses, XrayPulses
import numpy as np
import xarray as xr
import logging
# Names exported by ``from <module> import *``.
__all__ = ['extract_GH2']

# Logger named after this module.
log = logging.getLogger(__name__)


def extract_GH2(ds, run, firstFrame=0, bunchPattern='scs_ppl',
                gh2_dim='gh2_pId'):
    '''
    Select and align the frames of the Gotthard-II that have been exposed
    to light.

    GH2 frame ``pulse_id + firstFrame`` is taken as the frame of pulse
    ``pulse_id`` of the chosen bunch pattern. The selected frames are
    relabelled with the pulse ids, and the GH2 frame dimension is renamed to
    the pulse dimension of that bunch pattern. Variables of ``ds`` that
    already live along that pulse dimension are merged back with the
    selected GH2 data (inner join).

    Parameters
    ----------
    ds: xarray.Dataset
        The dataset containing GH2 data
    run: extra_data.DataCollection
        The run containing the bunch pattern source
    firstFrame: int
        The GH2 frame number corresponding to the first pulse of the train.
    bunchPattern: str in ['scs_ppl', 'sase3']
        the bunch pattern used to align data. For 'scs_ppl', the gh2_pId
        dimension is renamed 'ol_pId', and for 'sase3' gh2_pId is renamed
        'sa3_pId'. Any other value is treated as 'sase3' and a warning is
        logged.
    gh2_dim: str
        The name of the dimension that corresponds to the Gotthard-II frames.

    Returns
    -------
    nds: xarray Dataset
        The aligned and reduced dataset with only-data-containing GH2
        variables. If ``gh2_dim`` is not a dimension of ``ds``, ``ds`` is
        returned unchanged.
    '''
    if gh2_dim not in ds.dims:
        log.warning(f'gh2_dim "{gh2_dim}" not in dataset. Skipping.')
        return ds
    if bunchPattern == 'scs_ppl':
        pattern = OpticalLaserPulses(run)
        dim = 'ol_pId'
    else:
        if bunchPattern != 'sase3':
            # Keep the historical fallback to sase3, but make it visible.
            log.warning(f'Unknown bunchPattern "{bunchPattern}", '
                        'using sase3.')
        pattern = XrayPulses(run)
        dim = 'sa3_pId'

    # Variables already defined along the target pulse dimension. Test
    # `.dims` (not `.coords`) so this list matches exactly the variables that
    # drop_dims() removes below, even when `dim` has no coordinate variable.
    others = [var for var in ds if dim in ds[var].dims]
    nds = ds.drop_dims(dim, errors='ignore')

    constant = pattern.is_constant_pattern()
    if constant:
        pulse_ids = pattern.peek_pulse_ids(labelled=False)
    else:
        log.warning('The number of pulses has changed during the run.')
        # Union of the pulse ids occurring in any train.
        pulse_ids = np.unique(pattern.pulse_ids(labelled=False, copy=False))

    # GH2 frame `pulse_id + firstFrame` holds the data of pulse `pulse_id`.
    nds = nds.isel({gh2_dim: pulse_ids + firstFrame})
    nds = nds.assign_coords({gh2_dim: pulse_ids})
    nds = nds.rename({gh2_dim: dim})

    if not constant:
        mask = pattern.pulse_mask(labelled=False)
        # NOTE(review): assumes the unlabelled mask has one row per train of
        # `run.train_ids` and that column i corresponds to pulse id i —
        # confirm against the extra.components pulse_mask() documentation.
        mask = xr.DataArray(mask, dims=['trainId', dim],
                            coords={'trainId': run.train_ids,
                                    dim: np.arange(mask.shape[1])})
        mask = mask.sel({dim: pulse_ids})
        # Mask only the variables that carry the pulse dimension: calling
        # Dataset.where() on everything would broadcast unrelated variables
        # (e.g. per-train scalars) against the 2-D mask and give them a
        # spurious pulse dimension.
        gh2_vars = [var for var in nds if dim in nds[var].dims]
        masked = nds[gh2_vars].where(mask, drop=True)
        # Inner join propagates the trains/pulses dropped by where(drop=True)
        # to the remaining variables, as the whole-dataset where() did.
        nds = nds.drop_vars(gh2_vars).merge(masked, join='inner')

    return ds[others].merge(nds, join='inner')