from collections.abc import Mapping
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Dict, Optional, Sequence, Union

from calibration_client import CalibrationClient
from calibration_client.modules import CalibrationConstantVersion

from .calcat_interface import CalCatApi, CalCatError
from .tools import module_index_to_qm

# Module-wide CalibrationClient. Assigned by setup_client(); get_client()
# creates a default one on first use if nothing has been configured yet.
global_client = None


def get_client():
    """Return the module-wide CalibrationClient, creating it lazily.

    If no client has been configured via setup_client() yet, a default one is
    set up for the CalCat proxy at http://exflcalproxy:8080/ with no OAuth2
    client id/secret and no user e-mail.
    """
    global global_client
    if global_client is not None:
        return global_client
    # First use: fall back to the unauthenticated proxy connection.
    setup_client("http://exflcalproxy:8080/", None, None, None)
    return global_client


def setup_client(base_url, client_id, client_secret, user_email, **kwargs):
    """Create the module-wide CalibrationClient later returned by get_client().

    Any previously configured client is replaced.

    Args:
        base_url: Root URL of the CalCat server. A trailing slash is accepted
            and stripped, so the derived endpoint URLs (``/api/``,
            ``/oauth/token``, ``/oauth/authorize``) never contain ``//``.
        client_id: OAuth2 client id, or None to disable OAuth2.
        client_secret: OAuth2 client secret.
        user_email: User e-mail address passed to the client.
        **kwargs: Extra keyword arguments forwarded to CalibrationClient.
    """
    global global_client
    # get_client() passes a URL ending in "/"; without this every endpoint
    # would be built as e.g. "http://host:8080//api/".
    base_url = base_url.rstrip("/")
    global_client = CalibrationClient(
        use_oauth2=(client_id is not None),
        client_id=client_id,
        client_secret=client_secret,
        user_email=user_email,
        base_api_url=f"{base_url}/api/",
        token_url=f"{base_url}/oauth/token",
        refresh_url=f"{base_url}/oauth/token",
        auth_url=f"{base_url}/oauth/authorize",
        scope="",
        **kwargs,
    )


@dataclass
class SingleConstantVersion:
    """One Calibration Constant Version (CCV) for a single detector module."""

    id: int  # id of this CCV
    version_name: str  # "name" of this CCV
    constant_id: int  # id of the parent calibration constant
    constant_name: str  # name of the parent calibration constant
    condition_id: int  # condition_id of the parent calibration constant
    path: Path  # path_to_file joined with file_name
    dataset: str  # data_set_name from the response
    # NOTE(review): from_response stores the validity values exactly as the
    # API returns them (no parsing); they may be strings rather than
    # datetime objects despite the annotation — confirm with the client.
    begin_validity_at: datetime
    end_validity_at: datetime
    raw_data_location: str
    physical_name: str  # PDU name

    @classmethod
    def from_response(cls, ccv: dict) -> "SingleConstantVersion":
        """Build an instance from one CCV record of a CalCat API response.

        The record must contain a nested ``calibration_constant`` dict
        (``id``, ``name``, ``condition_id``) and a nested
        ``physical_detector_unit`` dict (``physical_name``); a missing key
        raises KeyError.
        """
        constant = ccv["calibration_constant"]
        # Fields are looked up in declaration order.
        fields = {
            "id": ccv["id"],
            "version_name": ccv["name"],
            "constant_id": constant["id"],
            "constant_name": constant["name"],
            "condition_id": constant["condition_id"],
            "path": Path(ccv["path_to_file"]) / ccv["file_name"],
            "dataset": ccv["data_set_name"],
            "begin_validity_at": ccv["begin_validity_at"],
            "end_validity_at": ccv["end_validity_at"],
            "raw_data_location": ccv["raw_data_location"],
            "physical_name": ccv["physical_detector_unit"]["physical_name"],
        }
        return cls(**fields)


@dataclass
class ModulesConstantVersions:
    """A group of similar CCVs, one per module of a single detector."""

    constants: Dict[str, SingleConstantVersion]  # Keys e.g. 'LPD00'

    def select_modules(self, *aggregators) -> "ModulesConstantVersions":
        """Return a new group holding only the given aggregator names.

        Aggregators are passed as separate arguments; names not present in
        this group are ignored.
        """
        kept = {}
        for name, version in self.constants.items():
            if name in aggregators:
                kept[name] = version
        return ModulesConstantVersions(kept)

    @property
    def aggregators(self):
        """Aggregator names in this group, sorted."""
        names = list(self.constants)
        names.sort()
        return names

    @property
    def module_nums(self):
        """Module numbers parsed from the last two characters of each aggregator."""
        return [int(name[-2:]) for name in self.aggregators]

    @property
    def qm_names(self):
        """module_index_to_qm() applied to each module number."""
        return [module_index_to_qm(num) for num in self.module_nums]


class CalibrationData(Mapping):
    """Collected constants for a given detector.

    Behaves as a read-only mapping from calibration type name (e.g. 'Offset')
    to a ModulesConstantVersions, which in turn maps aggregator names
    (e.g. 'LPD00') to a SingleConstantVersion.
    """

    def __init__(self, constant_groups: Dict[str, Dict[str, SingleConstantVersion]]):
        self.constant_groups = {
            const_type: ModulesConstantVersions(d)
            for const_type, d in constant_groups.items()
        }

    @classmethod
    def from_condition(
        cls,
        condition: "ConditionsBase",
        detector_name,
        modules: Optional[Sequence[str]] = None,
        calibrations=None,
        client=None,
        event_at=None,
        pdu_snapshot_at=None,
    ):
        """Query CalCat for the constants matching *condition*.

        Args:
            condition: Operating conditions. Its ``calibration_types`` maps
                each calibration type to the database parameters it is
                matched on.
            detector_name: Detector identifier passed to the query.
            modules: Aggregator names to keep (e.g. ``['LPD00']``). None or
                empty keeps every module returned.
            calibrations: Calibration type names to retrieve. Defaults to all
                keys of ``condition.calibration_types``; names not present
                there are ignored.
            client: CalibrationClient to use; defaults to get_client().
            event_at: Reference time passed to the closest-by-time query.
            pdu_snapshot_at: Time passed for the PDU lookup; defaults to
                *event_at*.

        Returns:
            A new CalibrationData.
        """
        if calibrations is None:
            calibrations = set(condition.calibration_types)

        # Calibration types matched on the same parameter list share a query.
        cal_types_by_params_used = {}
        for cal_type, params in condition.calibration_types.items():
            if cal_type in calibrations:
                cal_types_by_params_used.setdefault(tuple(params), []).append(cal_type)

        api = CalCatApi(client or get_client())

        # Narrow the query server-side only when exactly one module is wanted;
        # otherwise "" is passed (presumably "all modules") and the result is
        # filtered by select_modules below. `modules` may be None.
        module_arg = modules[0] if modules and len(modules) == 1 else ""

        d = {}

        for params, cal_types in cal_types_by_params_used.items():
            condition_dict = condition.make_dict(params)

            # Only the calibration types that use *this* parameter set:
            # querying every requested type here would match some of them
            # against the wrong condition and overwrite correct results.
            cal_id_map = {
                api.calibration_id(calibration): calibration
                for calibration in cal_types
            }

            query_res = api._closest_ccv_by_time_by_condition(
                detector_name,
                list(cal_id_map),
                condition_dict,
                module_arg,
                event_at,
                pdu_snapshot_at or event_at,
            )

            for ccv in query_res:
                aggr = ccv["physical_detector_unit"]["karabo_da"]
                cal_type = cal_id_map[ccv["calibration_constant"]["calibration_id"]]

                d.setdefault(cal_type, {})[aggr] = SingleConstantVersion.from_response(
                    ccv
                )

        res = cls(d)
        if modules:
            # select_modules takes aggregator names as separate arguments.
            res = res.select_modules(*modules)
        return res

    @classmethod
    def from_report(
        cls,
        report_id_or_path: Union[int, str],
        client=None,
    ):
        """Load the constants recorded in a CalCat report.

        Args:
            report_id_or_path: An int is looked up as a report id; anything
                else is looked up as a report path.
            client: CalibrationClient to use; defaults to get_client().

        Raises:
            CalCatError: if the response is not marked successful.
        """
        client = client or get_client()
        api = CalCatApi(client)

        if isinstance(report_id_or_path, int):
            resp = CalibrationConstantVersion.get_by_report_id(
                client, report_id_or_path
            )
        else:
            resp = CalibrationConstantVersion.get_by_report_path(
                client, report_id_or_path
            )

        if not resp["success"]:
            raise CalCatError(resp)

        d = {}

        for ccv in resp["data"]:
            aggr = ccv["physical_detector_unit"]["karabo_da"]
            cal_type = api.calibration_name(
                ccv["calibration_constant"]["calibration_id"]
            )
            d.setdefault(cal_type, {})[aggr] = SingleConstantVersion.from_response(ccv)

        return cls(d)

    def __getitem__(self, key) -> ModulesConstantVersions:
        return self.constant_groups[key]

    def __iter__(self):
        return iter(self.constant_groups)

    def __len__(self):
        return len(self.constant_groups)

    @property
    def aggregators(self):
        """Sorted union of aggregator names across all calibration types."""
        names = set()
        for mcv in self.constant_groups.values():
            names.update(mcv.aggregators)
        return sorted(names)

    @property
    def module_nums(self):
        """Module numbers parsed from the last two characters of each aggregator."""
        return [int(da[-2:]) for da in self.aggregators]

    @property
    def qm_names(self):
        """module_index_to_qm() applied to each module number."""
        return [module_index_to_qm(n) for n in self.module_nums]

    def require_calibrations(self, *calibrations):
        """Drop any modules missing the specified constant types.

        Raises KeyError if one of *calibrations* is absent altogether.
        """
        mods = set(self.aggregators)
        for cal_type in calibrations:
            mods.intersection_update(self[cal_type].constants)
        # select_modules takes aggregator names as separate arguments.
        return self.select_modules(*mods)

    def select_calibrations(self, *calibrations, require_all=True):
        """Return a new CalibrationData holding only the given calibration types.

        Raises KeyError listing the missing types if *require_all* is true
        and any of *calibrations* is not present.
        """
        if require_all:
            missing = set(calibrations) - set(self.constant_groups)
            if missing:
                raise KeyError(f"Missing calibrations: {', '.join(sorted(missing))}")

        d = {
            cal_type: mcv.constants
            for (cal_type, mcv) in self.constant_groups.items()
            if cal_type in calibrations
        }
        # TODO: missing for some modules?
        return type(self)(d)

    def select_modules(self, *aggregators):
        """Return a new CalibrationData restricted to the given aggregator names."""
        return type(self)(
            {
                cal_type: mcv.select_modules(*aggregators).constants
                for (cal_type, mcv) in self.constant_groups.items()
            }
        )

    def merge(self, *others: "CalibrationData") -> "CalibrationData":
        """Combine this data with *others* into a new CalibrationData.

        The result holds every calibration type found in any input; the order
        is self's types first, then new ones in the order they appear. Where
        several inputs have a constant for the same type and module, the one
        from the later argument wins. Inputs are not modified.
        """
        d = {}
        for caldata in (self,) + others:
            for cal_type in caldata:
                # setdefault creates a fresh dict, so no input is mutated.
                d.setdefault(cal_type, {}).update(caldata[cal_type].constants)

        return type(self)(d)


class ConditionsBase:
    """Common base for detector operating-condition classes."""

    # Subclasses map each calibration type name to the list of database
    # parameter names it is matched on.
    calibration_types = {}

    def make_dict(self, parameters) -> dict:
        """Build a {database parameter name: value} dict for a query.

        Each parameter name is turned into an attribute name by lower-casing
        it and replacing spaces with underscores (e.g. "Memory cells" ->
        ``memory_cells``). Parameters whose value is None are left out, and
        a name with no matching attribute raises AttributeError.
        """
        looked_up = (
            (name, getattr(self, name.lower().replace(" ", "_")))
            for name in parameters
        )
        return {name: value for name, value in looked_up if value is not None}


@dataclass
class AGIPDConditions(ConditionsBase):
    """Operating conditions for AGIPD, used to look up calibration constants."""

    sensor_bias_voltage: float
    memory_cells: int
    acquisition_rate: float
    gain_setting: Optional[int]
    gain_mode: Optional[int]
    source_energy: float
    integration_time: int = 12
    pixels_x: int = 512
    pixels_y: int = 128

    # Database parameter names. make_dict() maps each one to an attribute by
    # lower-casing it and replacing spaces with underscores.
    _gain_parameters = [
        "Sensor Bias Voltage",
        "Pixels X",
        "Pixels Y",
        "Memory cells",
        "Acquisition rate",
        "Gain setting",
        "Integration time",
    ]
    _other_dark_parameters = [*_gain_parameters, "Gain mode"]
    _illuminated_parameters = [*_gain_parameters, "Source energy"]

    calibration_types = dict(
        Offset=_other_dark_parameters,
        Noise=_other_dark_parameters,
        ThresholdsDark=_other_dark_parameters,
        BadPixelsDark=_other_dark_parameters,
        BadPixelsPC=_gain_parameters,
        SlopesPC=_gain_parameters,
        BadPixelsFF=_illuminated_parameters,
        SlopesFF=_illuminated_parameters,
    )

    def make_dict(self, parameters):
        """Like ConditionsBase.make_dict, then drop two default-valued entries.

        "Gain mode" is removed when it equals 0 and "Integration time" when it
        equals 12 (compared after int() conversion). The original code labels
        this a database quirk — presumably CalCat stores these defaults
        without the parameter; confirm before changing.
        """
        result = super().make_dict(parameters)

        for key, dropped_value in (("Gain mode", 0), ("Integration time", 12)):
            # -1 stands in for "absent" and never matches a dropped value.
            if int(result.get(key, -1)) == dropped_value:
                del result[key]

        return result


@dataclass
class LPDConditions(ConditionsBase):
    """Operating conditions for LPD, used to look up calibration constants.

    Each list in ``calibration_types`` names the database parameters a
    calibration type is matched on. ConditionsBase.make_dict turns each name
    into an attribute name by lower-casing it and replacing spaces with
    underscores, so every listed name must correspond to a field below.
    """

    sensor_bias_voltage: float
    memory_cells: int
    memory_cell_order: Optional[str] = None  # make_dict omits it while None
    feedback_capacitor: float = 5.0
    source_energy: float = 9.2
    category: int = 1
    pixels_x: int = 256
    pixels_y: int = 256

    # Parameters shared by every LPD calibration type below.
    _base_params = [
        "Sensor Bias Voltage",
        "Memory cells",
        "Pixels X",
        "Pixels Y",
        "Feedback capacitor",
    ]
    # Dark constants additionally match on the memory cell order.
    _dark_parameters = _base_params + [
        "Memory cell order",
    ]
    # NOTE(review): "Source Energy" and "category" are capitalised differently
    # from the other names (and from AGIPD's "Source energy"). make_dict
    # lower-cases them, so they still resolve to source_energy / category;
    # presumably the exact spelling matches the CalCat parameter names — confirm.
    _illuminated_parameters = _base_params + ["Source Energy", "category"]

    calibration_types = {
        "Offset": _dark_parameters,
        "Noise": _dark_parameters,
        "BadPixelsDark": _dark_parameters,
        "RelativeGain": _illuminated_parameters,
        "GainAmpMap": _illuminated_parameters,
        "FFMap": _illuminated_parameters,
        "BadPixelsFF": _illuminated_parameters,
    }