from collections.abc import Mapping
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Dict, Optional, Sequence, Union

import h5py
import pasha as psh
from calibration_client import CalibrationClient
from calibration_client.modules import CalibrationConstantVersion

from .calcat_interface import CalCatApi, CalCatError
from .tools import module_index_to_qm


class ModuleNameError(KeyError):
    """Raised when a requested module name is not one of a detector's modules."""

    def __init__(self, name):
        super().__init__(name)
        # Kept as an attribute so handlers can inspect the offending name.
        self.name = name

    def __str__(self):
        return f"No module named {self.name!r}"


# Shared CalibrationClient; created lazily by get_client() or replaced by
# an explicit setup_client() call.
global_client = None


def get_client():
    """Return the shared CalibrationClient, creating a default one on first use.

    The default client targets http://exflcalproxy:8080/ with no OAuth2
    credentials (client id, secret and user email are all None).
    """
    global global_client
    if global_client is not None:
        return global_client
    setup_client("http://exflcalproxy:8080/", None, None, None)
    return global_client


def setup_client(base_url, client_id, client_secret, user_email, **kwargs):
    """Create the module-wide CalibrationClient that get_client() returns.

    Args:
        base_url: Root URL of the CalCat server. The API, token, refresh and
            authorize endpoints are derived from it; trailing slashes are
            ignored.
        client_id: OAuth2 client id. OAuth2 is disabled when this is None.
        client_secret: OAuth2 client secret.
        user_email: Passed through to CalibrationClient.
        **kwargs: Passed through unchanged to CalibrationClient.
    """
    global global_client
    # get_client() passes a base URL ending in "/". Strip it so the derived
    # endpoints don't contain a doubled slash (".../8080//api/").
    base_url = base_url.rstrip("/")
    global_client = CalibrationClient(
        use_oauth2=(client_id is not None),
        client_id=client_id,
        client_secret=client_secret,
        user_email=user_email,
        base_api_url=f"{base_url}/api/",
        token_url=f"{base_url}/oauth/token",
        refresh_url=f"{base_url}/oauth/token",
        auth_url=f"{base_url}/oauth/authorize",
        scope="",
        **kwargs,
    )


# Cache for _get_default_caldb_root(): Ellipsis means "not looked up yet",
# None means no default directory was found.
_default_caldb_root = ...


def _get_default_caldb_root():
    """Return the default calibration DB root directory, or None.

    The first existing directory among the two known locations is chosen.
    The result (including None) is cached in a module global after the
    first call.
    """
    global _default_caldb_root
    if _default_caldb_root is ...:
        candidates = (
            Path("/common/cal/caldb_store"),  # "onc" location
            Path("/gpfs/exfel/d/cal/caldb_store"),  # Maxwell location
        )
        # Lazily checked in order: stops at the first existing directory.
        _default_caldb_root = next(
            (root for root in candidates if root.is_dir()), None
        )
    return _default_caldb_root


@dataclass
class SingleConstantVersion:
    """A Calibration Constant Version for 1 detector module"""

    id: int
    version_name: str
    constant_id: int
    constant_name: str
    condition_id: int
    path: Path  # File path relative to the caldb root
    dataset: str  # HDF5 group in that file; the array is its "data" member
    # NOTE: from_response stores the raw API values for the two validity
    # fields below without parsing them into datetime objects.
    begin_validity_at: datetime
    end_validity_at: datetime
    raw_data_location: str
    physical_name: str  # PDU name

    @classmethod
    def from_response(cls, ccv: dict) -> "SingleConstantVersion":
        """Build an instance from one CCV record of a CalCat API response."""
        const = ccv["calibration_constant"]
        return cls(
            id=ccv["id"],
            version_name=ccv["name"],
            constant_id=const["id"],
            constant_name=const["name"],
            condition_id=const["condition_id"],
            path=Path(ccv["path_to_file"]) / ccv["file_name"],
            dataset=ccv["data_set_name"],
            begin_validity_at=ccv["begin_validity_at"],
            end_validity_at=ccv["end_validity_at"],
            raw_data_location=ccv["raw_data_location"],
            physical_name=ccv["physical_detector_unit"]["physical_name"],
        )

    def _resolve_path(self, caldb_root=None) -> Path:
        """Return the absolute path of this constant's HDF5 file.

        Raises:
            RuntimeError: if caldb_root is None and no default root exists.
        """
        if caldb_root is not None:
            root = Path(caldb_root)
        else:
            root = _get_default_caldb_root()
            if root is None:
                raise RuntimeError(
                    "No caldb_root given and no default calibration "
                    "database directory was found"
                )
        return root / self.path

    def dataset_obj(self, caldb_root=None):
        """Open the constant's HDF5 file and return its "data" dataset.

        The file stays open so the returned dataset remains usable; it is
        closed only when the h5py objects are garbage-collected.
        """
        f = h5py.File(self._resolve_path(caldb_root), "r")
        return f[self.dataset]["data"]

    def ndarray(self, caldb_root=None):
        """Read the constant's full array into memory and close the file."""
        with h5py.File(self._resolve_path(caldb_root), "r") as f:
            return f[self.dataset]["data"][:]


@dataclass
class ModulesConstantVersions:
    """A group of similar CCVs for several modules of one detector

    ``constants`` maps each module's aggregator name to its CCV.
    """

    constants: Dict[str, SingleConstantVersion]  # Keys e.g. 'LPD00'

    def select_modules(self, aggregators) -> "ModulesConstantVersions":
        """Return a new group holding only the modules listed in aggregators."""
        kept = {}
        for aggr_name, version in self.constants.items():
            if aggr_name in aggregators:
                kept[aggr_name] = version
        return ModulesConstantVersions(kept)

    @property
    def aggregators(self):
        """Aggregator names of the modules in this group, sorted."""
        names = list(self.constants)
        names.sort()
        return names

    @property
    def module_nums(self):
        """Module numbers parsed from the last two characters of each aggregator."""
        nums = []
        for aggr_name in self.aggregators:
            nums.append(int(aggr_name[-2:]))
        return nums

    @property
    def qm_names(self):
        """module_index_to_qm() applied to each module number, in order."""
        return list(map(module_index_to_qm, self.module_nums))


class CalibrationData(Mapping):
    """Collected constants for a given detector

    Acts as a Mapping from calibration type name (e.g. 'Offset') to a
    ModulesConstantVersions with that type's CCVs, keyed by module aggregator
    name. ``aggregators`` lists the modules this object covers.
    """

    def __init__(self, constant_groups, aggregators):
        """
        Args:
            constant_groups: {cal_type: {aggregator: SingleConstantVersion}}
            aggregators: module aggregator names (e.g. 'LPD00').
        """
        self.constant_groups = {
            const_type: ModulesConstantVersions(d)
            for const_type, d in constant_groups.items()
        }
        self.aggregators = aggregators

    @classmethod
    def from_condition(
        cls,
        condition: "ConditionsBase",
        detector_name,
        modules: Optional[Sequence[str]] = None,
        calibrations=None,
        client=None,
        event_at=None,
        pdu_snapshot_at=None,
    ):
        """Query CalCat for the CCVs matching an operating condition.

        Args:
            condition: Conditions object. Its ``calibration_types`` gives the
                parameter names each calibration type is matched on.
            detector_name: CalCat detector name.
            modules: Aggregator names to include. Defaults to every module
                reported for the detector at ``pdu_snapshot_at``.
            calibrations: Calibration type names to look up. Defaults to all
                keys of ``condition.calibration_types``.
            client: CalibrationClient to use. Defaults to ``get_client()``.
            event_at: Time passed to the closest-CCV-by-time query.
            pdu_snapshot_at: Time used to resolve the detector's modules.
                Defaults to ``event_at``.

        Raises:
            ModuleNameError: if a name in ``modules`` is not one of the
                detector's modules.
        """
        if calibrations is None:
            calibrations = set(condition.calibration_types)
        if pdu_snapshot_at is None:
            pdu_snapshot_at = event_at

        # Calibration types matched on the same parameter list share one query.
        cal_types_by_params_used = {}
        for cal_type, params in condition.calibration_types.items():
            if cal_type in calibrations:
                cal_types_by_params_used.setdefault(tuple(params), []).append(cal_type)

        api = CalCatApi(client or get_client())

        detector_id = api.detector(detector_name)["id"]
        all_modules = api.physical_detector_units(detector_id, pdu_snapshot_at)
        if modules is None:
            modules = sorted(all_modules)
        else:
            modules = sorted(modules)
            for m in modules:
                if m not in all_modules:
                    raise ModuleNameError(m)

        d = {}

        for params, cal_types in cal_types_by_params_used.items():
            condition_dict = condition.make_dict(params)

            cal_id_map = {
                api.calibration_id(calibration): calibration
                for calibration in cal_types
            }
            calibration_ids = list(cal_id_map.keys())

            query_res = api._closest_ccv_by_time_by_condition(
                detector_name,
                calibration_ids,
                condition_dict,
                # Only a single module is passed as a filter. Otherwise "" is
                # sent (presumably "no module filter") and the result is
                # narrowed with select_modules below.
                modules[0] if len(modules) == 1 else "",
                event_at,
                pdu_snapshot_at or event_at,
            )

            for ccv in query_res:
                aggr = ccv["physical_detector_unit"]["karabo_da"]
                cal_type = cal_id_map[ccv["calibration_constant"]["calibration_id"]]

                d.setdefault(cal_type, {})[aggr] = SingleConstantVersion.from_response(
                    ccv
                )

        res = cls(d, modules)
        if modules:
            res = res.select_modules(modules)
        return res

    @classmethod
    def from_report(
        cls,
        report_id_or_path: Union[int, str],
        client=None,
    ):
        """Load the CCVs referenced by a calibration report.

        Args:
            report_id_or_path: Report id (int) or report path (str).
            client: CalibrationClient to use. Defaults to ``get_client()``.

        Raises:
            CalCatError: if the API response is not successful.
        """
        client = client or get_client()
        api = CalCatApi(client)

        if isinstance(report_id_or_path, int):
            resp = CalibrationConstantVersion.get_by_report_id(
                client, report_id_or_path
            )
        else:
            resp = CalibrationConstantVersion.get_by_report_path(
                client, report_id_or_path
            )

        if not resp["success"]:
            raise CalCatError(resp)

        d = {}
        aggregators = set()

        for ccv in resp["data"]:
            aggr = ccv["physical_detector_unit"]["karabo_da"]
            aggregators.add(aggr)
            cal_type = api.calibration_name(
                ccv["calibration_constant"]["calibration_id"]
            )
            d.setdefault(cal_type, {})[aggr] = SingleConstantVersion.from_response(ccv)

        return cls(d, sorted(aggregators))

    def __getitem__(self, key) -> ModulesConstantVersions:
        return self.constant_groups[key]

    def __iter__(self):
        return iter(self.constant_groups)

    def __len__(self):
        return len(self.constant_groups)

    @property
    def module_nums(self):
        """Module numbers parsed from the last two characters of each aggregator."""
        return [int(da[-2:]) for da in self.aggregators]

    @property
    def qm_names(self):
        """module_index_to_qm() applied to each module number, in order."""
        return [module_index_to_qm(n) for n in self.module_nums]

    def require_calibrations(self, calibrations):
        """Drop any modules missing the specified constant types

        Raises:
            KeyError: if a requested calibration type is absent altogether.
        """
        mods = set(self.aggregators)
        for cal_type in calibrations:
            mods.intersection_update(self[cal_type].constants)
        return self.select_modules(mods)

    def select_calibrations(self, calibrations, require_all=True):
        """Return a new object with only the given calibration types.

        Raises:
            KeyError: if require_all is true and any requested type is absent.
        """
        if require_all:
            missing = set(calibrations) - set(self.constant_groups)
            if missing:
                raise KeyError(f"Missing calibrations: {', '.join(sorted(missing))}")

        d = {
            cal_type: mcv.constants
            for (cal_type, mcv) in self.constant_groups.items()
            if cal_type in calibrations
        }
        # TODO: missing for some modules?
        return type(self)(d, self.aggregators)

    def select_modules(self, aggregators):
        """Return a new object restricted to the given module aggregators."""
        return type(self)(
            {
                cal_type: mcv.select_modules(aggregators).constants
                for (cal_type, mcv) in self.constant_groups.items()
            },
            sorted(aggregators),
        )

    def merge(self, *others: "CalibrationData") -> "CalibrationData":
        """Combine this object with others into a new CalibrationData.

        The result contains every calibration type found in any input. When
        several inputs have a constant for the same type and module, the
        later input (self first, then others in order) wins. The aggregator
        list is the sorted union of all inputs' aggregators.
        """
        d = {}
        for cal_data in (self, *others):
            for cal_type, mcv in cal_data.items():
                # setdefault gives a fresh dict per type, so the inputs'
                # constant dicts are never mutated.
                d.setdefault(cal_type, {}).update(mcv.constants)

        aggregators = set(self.aggregators)
        for other in others:
            aggregators.update(other.aggregators)

        return type(self)(d, sorted(aggregators))

    def load_all(self, caldb_root=None, num_workers=24):
        """Read every constant array into memory using parallel workers.

        Args:
            caldb_root: Root directory of the calibration files. When None,
                SingleConstantVersion.dataset_obj falls back to its default.
            num_workers: Number of pasha worker processes used for reading.

        Returns:
            {cal_type: {aggregator: array}}, where each array was allocated
            through the pasha ProcessContext and filled by a worker.
        """
        res = {}

        const_load_mp = psh.ProcessContext(num_workers=num_workers)
        keys = []
        for cal_type, mcv in self.constant_groups.items():
            res[cal_type] = {}
            for module in mcv.aggregators:
                # Opened here only to learn shape/dtype for the preallocation.
                dset = mcv.constants[module].dataset_obj(caldb_root)
                res[cal_type][module] = const_load_mp.alloc(
                    shape=dset.shape, dtype=dset.dtype
                )
                keys.append((cal_type, module))

        def _load_constant_dataset(wid, index, key):
            # Runs in a worker: read directly into the preallocated array.
            cal_type, mod = key
            dset = self[cal_type].constants[mod].dataset_obj(caldb_root)
            dset.read_direct(res[cal_type][mod])

        const_load_mp.map(_load_constant_dataset, keys)

        return res


class ConditionsBase:
    """Base class for detector operating conditions used in CalCat queries."""

    # Overridden by subclasses: {calibration: [parameter names]}
    calibration_types = {}

    def make_dict(self, parameters) -> dict:
        """Map parameter names to the values of the matching attributes.

        A name such as "Sensor Bias Voltage" is looked up as the attribute
        ``sensor_bias_voltage``. Parameters whose value is None are omitted.
        """
        pairs = (
            (name, getattr(self, name.lower().replace(" ", "_")))
            for name in parameters
        )
        return {name: value for name, value in pairs if value is not None}


@dataclass
class AGIPDConditions(ConditionsBase):
    """Operating conditions of an AGIPD detector for CalCat lookups.

    ConditionsBase.make_dict maps each parameter name to a field by
    lower-casing it and replacing spaces with underscores.
    """

    sensor_bias_voltage: float
    memory_cells: int
    acquisition_rate: float
    gain_setting: Optional[int]
    gain_mode: Optional[int]
    source_energy: float
    integration_time: int = 12
    pixels_x: int = 512
    pixels_y: int = 128

    # The parameter lists are unannotated class attributes, so the dataclass
    # does not treat them as fields.
    _gain_parameters = [
        "Sensor Bias Voltage",
        "Pixels X",
        "Pixels Y",
        "Memory cells",
        "Acquisition rate",
        "Gain setting",
        "Integration time",
    ]
    _other_dark_parameters = [*_gain_parameters, "Gain mode"]
    _illuminated_parameters = [*_gain_parameters, "Source energy"]

    calibration_types = {
        **dict.fromkeys(
            ("Offset", "Noise", "ThresholdsDark", "BadPixelsDark"),
            _other_dark_parameters,
        ),
        **dict.fromkeys(("BadPixelsPC", "SlopesPC"), _gain_parameters),
        **dict.fromkeys(("BadPixelsFF", "SlopesFF"), _illuminated_parameters),
    }

    def make_dict(self, parameters):
        """Like ConditionsBase.make_dict, minus two parameters at fixed values.

        "Gain mode" is dropped when it equals 0, and "Integration time" when
        it equals 12. The original code labels this a fix-up for database
        quirks.
        """
        cond = super().make_dict(parameters)

        for param, dropped_value in (("Gain mode", 0), ("Integration time", 12)):
            if int(cond.get(param, -1)) == dropped_value:
                del cond[param]

        return cond


@dataclass
class LPDConditions(ConditionsBase):
    """Operating conditions of an LPD detector for CalCat lookups.

    ConditionsBase.make_dict maps each parameter name to a field by
    lower-casing it and replacing spaces with underscores.
    """

    sensor_bias_voltage: float
    memory_cells: int
    memory_cell_order: Optional[str] = None
    feedback_capacitor: float = 5.0
    source_energy: float = 9.2
    category: int = 1
    pixels_x: int = 256
    pixels_y: int = 256

    # Unannotated class attributes, so the dataclass does not make them fields.
    _base_params = [
        "Sensor Bias Voltage",
        "Memory cells",
        "Pixels X",
        "Pixels Y",
        "Feedback capacitor",
    ]
    _dark_parameters = [*_base_params, "Memory cell order"]
    # NOTE(review): casing differs between names ("Source Energy",
    # "category"); presumably it mirrors CalCat's spelling, so verify before
    # normalizing.
    _illuminated_parameters = [*_base_params, "Source Energy", "category"]

    calibration_types = {
        **dict.fromkeys(("Offset", "Noise", "BadPixelsDark"), _dark_parameters),
        **dict.fromkeys(
            ("RelativeGain", "GainAmpMap", "FFMap", "BadPixelsFF"),
            _illuminated_parameters,
        ),
    }