import binascii
import time
from dataclasses import asdict, dataclass

from datetime import datetime, timezone
from hashlib import md5
from pathlib import Path
from shutil import copyfile
from struct import pack, unpack
from typing import Tuple, Union

import h5py
import numpy as np

from cal_tools.calcat_interface2 import (
    CalCatAPIClient,
    CalCatAPIError,
    get_client,
    get_default_caldb_root,
)

# Maximum number of characters in a condition name; the result of
# generate_unique_condition_name is truncated to this length.
CONDITION_NAME_MAX_LENGTH = 60


class InjectionError(Exception):
    """Base class for errors of the CCV injection workflow in this module."""


class CCVAlreadyInjectedError(InjectionError):
    """CalCat refused a new calibration constant version with HTTP 422.

    ``inject_ccv`` raises this when posting the CCV is answered with status
    422, which it treats as the CCV having been injected already.
    """


@dataclass
class ParameterConditionAttribute:
    """One operating-condition parameter destined for CalCat.

    ``create_condition`` turns instances into payload entries with
    ``dataclasses.asdict`` (adding a ``parameter_name`` key), so the field
    names below are also the keys sent to CalCat.
    """
    # Parameter value, expected to be already converted to a string.
    value: str
    # Deviation limits below / above ``value``.
    lower_deviation_value: float = 0
    upper_deviation_value: float = 0
    # Remaining CalCat bookkeeping fields.
    flg_available: bool = True
    description: str = ''


def generate_unique_condition_name(
    detector_type: str,
    pdu_name: str,
    pdu_uuid: float,
    cond_params: dict,
):
    """Build a deterministic, length-limited name for a CalCat condition.

    The name is the part of ``detector_type`` before ``-Type``, followed by
    `` Def`` and the base64-encoded MD5 digest of the PDU name, the PDU UUID
    and every (parameter name, value) pair of ``cond_params`` in iteration
    order. Identical inputs therefore always give the identical name.

    Args:
        detector_type (str): Detector type name; must contain ``-Type``.
        pdu_name (str): Physical detector unit db name.
        pdu_uuid (float): Physical detector unit db id. It is passed through
            ``int()`` and hashed as 8 little-endian unsigned bytes.
        cond_params (dict): Keys DB names, values ParameterConditionAttribute
            (only their ``value`` is hashed).

    Returns:
        str: The name, cut to CONDITION_NAME_MAX_LENGTH characters.

    Raises:
        ValueError: If ``detector_type`` does not contain ``-Type``.
        OverflowError: If ``int(pdu_uuid)`` is negative or needs more than
            64 bits.
    """
    prefix = detector_type[:detector_type.index('-Type')] + ' Def'

    hasher = md5(pdu_name.encode())
    uuid_bytes = int(pdu_uuid).to_bytes(
        length=8, byteorder='little', signed=False)
    hasher.update(uuid_bytes)
    for param_name, param_attrs in cond_params.items():
        hasher.update(param_name.encode())
        hasher.update(str(param_attrs.value).encode())

    # NOTE(review): b2a_base64 appends '\n' by default, so short names end in
    # a newline (it is only cut off by the truncation for long prefixes).
    # Changing this would change every generated condition name.
    encoded_digest = binascii.b2a_base64(hasher.digest()).decode()
    full_name = prefix + encoded_digest
    return full_name[:CONDITION_NAME_MAX_LENGTH]


def create_unique_cc_name(det_type, calibration, condition_name):
    """Build a calibration constant (CC) name.

    The ``<det_type>_<calibration>`` part is cut to at most 40 characters;
    ``condition_name`` is then appended after a ``_`` without truncation.
    """
    # NOTE(review): the 40-character cap presumably keeps the full name
    # within a database length limit — confirm against CalCat.
    max_prefix_len = 40
    prefix = f'{det_type}_{calibration}'[:max_prefix_len]
    return '{}_{}'.format(prefix, condition_name)


def create_unique_ccv_name(start_idx):
    """Return a CCV name built from the current UTC time and ``start_idx``.

    Format: ``YYYYmmdd_HHMMSS_sIdx=<start_idx>``. The timestamp has
    one-second resolution, so two calls within the same second with the same
    ``start_idx`` return the same name.
    """
    timestamp = datetime.now(timezone.utc).strftime('%Y%m%d_%H%M%S')
    return f'{timestamp}_sIdx={start_idx}'


def extract_parameter_conditions(
    ccv_group: "h5py.Group", pdu_uuid: int) -> dict:
    """Read the operating conditions stored in a CCV group.

    Args:
        ccv_group (h5py.Group): CCV group as laid out by ``write_ccv``. Its
            ``operating_condition`` subgroup holds one dataset per parameter
            with ``database_name``, ``lower_deviation`` and
            ``upper_deviation`` attributes.
        pdu_uuid (int): PDU UUID, added as the ``Detector UUID`` parameter.

    Returns:
        dict: Parameter database name -> ParameterConditionAttribute, in
        file iteration order, with ``Detector UUID`` added last.
    """
    def _to_string(value):
        """Send only accepted value types to CALCAT."""
        # Booleans become '1.0' / '0.0' rather than 'True' / 'False'.
        if isinstance(value, bool):
            value = float(value)
        return str(value)

    cond_params = {}
    # The group is 'operating_condition' (singular), while the result is a
    # collection of conditions.
    condition_group = ccv_group['operating_condition']
    for parameter in condition_group:
        param_dset = condition_group[parameter]
        param_name = param_dset.attrs['database_name']
        cond_params[param_name] = ParameterConditionAttribute(
            value=_to_string(param_dset[()]),
            lower_deviation_value=param_dset.attrs['lower_deviation'],
            upper_deviation_value=param_dset.attrs['upper_deviation'],
        )

    # Reinterpret the int64 bit pattern of the UUID as a float64 instead of
    # converting it numerically. CalCat condition values are floats or
    # strings, and a bit-for-bit reinterpretation keeps all 64 bits.
    uuid_as_float = unpack('d', pack('q', pdu_uuid))[0]
    cond_params['Detector UUID'] = ParameterConditionAttribute(
        value=_to_string(uuid_as_float),
    )
    return cond_params


def write_ccv(
    const_path,
    pdu_name, pdu_uuid, detector_type,
    calibration, conditions, created_at, proposal, runs,
    data, dims, key='0', deviations=None,
):
    """Write CCV data file.

    Opens (or creates) the HDF5 file in append mode and adds the group
    ``/<pdu_name>/<calibration>/<key>``. That group holds the CCV metadata
    as attributes, an ``operating_condition`` subgroup with one float64
    dataset per condition and the constant itself in a ``data`` dataset.

    Args:
        const_path (os.PathLike): Path to CCV file to write
        pdu_name (str): Physical detector unit name
        pdu_uuid (int): Physical detector unit UUID
        detector_type (str): Detector type name
        calibration (str): Calibration name
        conditions (ConditionsBase): Detector operating conditions; its
            ``make_dict(conditions.calibration_types[calibration])`` result
            (database name -> value) is what gets stored.
        created_at (datetime): Validity start for calibration
        proposal (int): Raw data proposal the calibration data is
            generated from
        runs (Iterable of int): Raw data runs the calibration data is
            generated from
        data (ndarray): Calibration constant data
        dims (Iterable of str): Dimension names for the constant data.
        key (str, optional): Name of the CCV group created below the
            calibration group. If None, the number of entries already in
            the calibration group is used (as a string). A group with this
            name must not exist yet. Defaults to '0'.
        deviations (dict, optional): Deviation values for operating
            conditions, keyed by the condition's database name lower-cased
            with spaces replaced by underscores. Each value is either a
            (lower, upper) pair or a single number for symmetric
            deviations. Conditions without an entry get (0.0, 0.0).
            Defaults to None (no deviations).
            e.g. {"integration_time": 0.025}

    Returns:
        (str) CCV HDF group name.

    Raises:
        ValueError: If ``data.ndim`` differs from ``len(dims)``.
    """
    # None sentinel instead of a shared mutable ``{}`` default.
    if deviations is None:
        deviations = {}

    if data.ndim != len(dims):
        raise ValueError('data.ndims != len(dims)')

    with h5py.File(const_path, 'a') as const_file:
        const_file.attrs['version'] = 0

        pdu_group = const_file.require_group(pdu_name)
        pdu_group.attrs['uuid'] = pdu_uuid
        pdu_group.attrs['detector_type'] = detector_type

        calibration_group = pdu_group.require_group(calibration)

        if key is None:
            # Next free integer index among the existing entries.
            key = str(len(calibration_group))

        ccv_group = calibration_group.create_group(key)
        ccv_group.attrs['begin_at'] = created_at.isoformat()
        ccv_group.attrs['proposal'] = proposal
        ccv_group.attrs['runs'] = np.array(runs, dtype=np.int32)
        ccv_group_name = ccv_group.name

        opcond_group = ccv_group.create_group('operating_condition')
        opcond_dict = conditions.make_dict(
            conditions.calibration_types[calibration])
        for db_name, value in opcond_dict.items():
            cond_name = db_name.lower().replace(' ', '_')
            dset = opcond_group.create_dataset(
                cond_name, data=value, dtype=np.float64)

            deviation = deviations.get(cond_name, (0.0, 0.0))
            # numpy integer scalars are not ``int`` subclasses, so accept
            # them explicitly as symmetric deviations too.
            if isinstance(deviation, (int, float, np.integer, np.floating)):
                lower_dev = upper_dev = deviation
            else:
                lower_dev, upper_dev = deviation

            dset.attrs['lower_deviation'] = lower_dev
            dset.attrs['upper_deviation'] = upper_dev
            # Original (un-normalised) name, read back by
            # extract_parameter_conditions.
            dset.attrs['database_name'] = db_name

        dset = ccv_group.create_dataset('data', data=data)
        dset.attrs['dims'] = dims

    return ccv_group_name


def get_condition_dict(
    name: str,
    value: Union[float, str, int, bool],
    lower_deviation: float = 0.0,
    upper_deviation: float = 0.0,
):
    """Build a single CalCat parameter-condition payload entry.

    Args:
        name (str): Parameter database name.
        value: Parameter value. Anything ``float()`` accepts (numbers,
            booleans, numeric strings) is sent as a float. Anything else is
            sent as its ``str()``.
        lower_deviation (float): Lower deviation value. Defaults to 0.0.
        upper_deviation (float): Upper deviation value. Defaults to 0.0.

    Returns:
        dict: Keys ``parameter_name``, ``value``, ``lower_deviation_value``,
        ``upper_deviation_value`` and ``flg_available`` (always True).
    """

    def to_float_or_string(value):
        """CALCAT expects data to either be float or a string."""
        try:  # Any digit or boolean
            return float(value)
        except (TypeError, ValueError, OverflowError):
            # Only conversion failures fall back to str(); a bare except
            # would also swallow KeyboardInterrupt / SystemExit.
            return str(value)

    return {
        'parameter_name': name,
        'value': to_float_or_string(value),
        'lower_deviation_value': lower_deviation,
        'upper_deviation_value': upper_deviation,
        'flg_available': True
    }


def get_or_create_calibration_constant(
    client: CalCatAPIClient,
    calibration: str,
    detector_type: str,
    condition_id: int,
    condition_name: str,
):
    """Return the id of the matching calibration constant, creating it if needed.

    Args:
        client (CalCatAPIClient): Client for CalCat API.
        calibration (str): Calibration name.
        detector_type (str): Detector type name.
        condition_id (int): Id of the condition the constant belongs to.
        condition_name (str): Name of that condition, used to build the
            constant name via create_unique_cc_name.

    Returns:
        The calibration constant id.

    Raises:
        CalCatAPIError: If the lookup fails with a status other than 404,
            or if creating the constant fails.
    """
    # Prepare some parameters to set Calibration Constant.
    cal_id = client.calibration_by_name(calibration)['id']
    det_type_id = client.detector_type_by_name(detector_type)['id']
    cc_name = create_unique_cc_name(detector_type, calibration, condition_name)

    calibration_constant = dict(
        name=cc_name,
        calibration_id=cal_id,
        condition_id=condition_id,
        detector_type_id=det_type_id,
        flg_auto_approve=True,
        flg_available=True,
        description="",
    )
    try:
        existing = client.get(
            f"calibrations/{cal_id}/get_calibration_constant",
            calibration_constant
        )
    except CalCatAPIError as e:
        if e.status_code != 404:
            raise
        # 404: no such constant yet, so create it.
        return client.post(
            "calibration_constants", calibration_constant)['id']

    # The creation branch returns the record's 'id'; return an id here too
    # when the lookup yields the whole record.
    # NOTE(review): confirm the GET endpoint's response shape against CalCat.
    if isinstance(existing, dict):
        return existing['id']
    return existing


def create_condition(
    client: CalCatAPIClient,
    detector_type: str,
    pdu_name: str,
    pdu_uuid: float,
    cond_params: dict,
) -> Tuple[int, str]:
    """Register the condition for ``cond_params`` in CalCat.

    The condition name comes from generate_unique_condition_name, and the
    payload is posted to ``conditions/set_expected_condition``. That
    endpoint is meant to create the condition only if it is not already
    available.

    Args:
        client (CalCatAPIClient): Client for CalCat API.
        detector_type (str): Detector type name.
        pdu_name (str): Physical detector unit db name.
        pdu_uuid (float): Physical detector unit db id.
        cond_params (dict): Keys DB names, values ParameterConditionAttribute.

    Returns:
        Tuple[int, str]: The ``id`` from the CalCat response and the
        condition name.
    """
    cond_name = generate_unique_condition_name(
        detector_type, pdu_name, pdu_uuid, cond_params)

    param_entries = []
    for db_name, param_attrs in cond_params.items():
        param_entries.append(asdict(param_attrs) | {"parameter_name": db_name})

    payload = {
        "condition": {
            "name": cond_name,
            "parameters_conditions_attributes": param_entries,
            "flg_available": True,
            "description": '',
        }
    }
    response = client.post("conditions/set_expected_condition", payload)
    return response['id'], cond_name


def get_raw_data_location(proposal: str, runs: list):
    """Describe the raw data a constant was generated from.

    Returns ``'proposal:<proposal> runs: <run> <run> ...'``, or an empty
    string when there is no (truthy) proposal or no runs. The empty string
    covers constants that are not based on specific runs.
    """
    if not proposal or len(runs) == 0:
        return ""  # Fallback for non-run based constants
    run_numbers = " ".join(str(run) for run in runs)
    return f'proposal:{proposal} runs: {run_numbers}'


def get_ccv_info_from_file(
    cfile: Union[str, Path], pdu: str, ccv_root: str):
    """Read the metadata needed to inject a CCV from its HDF5 file.

    Args:
        cfile (str, Path): The CalibrationConstantVersion file path.
        pdu (str): Name of the physical detector unit group in the file.
        ccv_root (str): Path of the CCV group inside the file.

    Returns:
        tuple: ``(cond_params, begin_at, pdu_uuid, detector_type,
        raw_data_location)``. The elements are:

        - ``cond_params``: the dict from extract_parameter_conditions
          (database name -> ParameterConditionAttribute).
        - ``begin_at``: the CCV group's ``begin_at`` attribute (an ISO
          string for files written by write_ccv).
        - ``pdu_uuid`` and ``detector_type``: the PDU group's ``uuid`` and
          ``detector_type`` attributes.
        - ``raw_data_location``: the string from get_raw_data_location.
    """
    with h5py.File(cfile, 'r') as h5file:
        pdu_attrs = h5file[pdu].attrs
        pdu_uuid = pdu_attrs['uuid']
        detector_type = pdu_attrs['detector_type']

        ccv_group = h5file[ccv_root]
        ccv_attrs = ccv_group.attrs
        raw_data_location = get_raw_data_location(
            ccv_attrs['proposal'], ccv_attrs['runs'])
        begin_at = ccv_attrs['begin_at']
        # Must run while the file is still open.
        cond_params = extract_parameter_conditions(ccv_group, pdu_uuid)

    return cond_params, begin_at, pdu_uuid, detector_type, raw_data_location


def inject_ccv(const_src, ccv_root, report_to=None, client=None):
    """Inject new CCV into CalCat.

    The steps are:

    1. Copy the CCV file into the default CalDB root as
       ``xfel/cal/<detector_type>/<pdu_name>/cal.<time.time()>.h5``.
    2. Register the condition and the calibration constant.
    3. Register the report, if ``report_to`` is given.
    4. Register the CCV itself.

    If anything fails once the copy has started, the copied file is removed
    before the error propagates.

    Args:
        const_src (str or Path): Path to CCV data file.
        ccv_root (str): CCV HDF group name of the form
            ``/<pdu_name>/<calibration>/<key>``, as returned by write_ccv.
        report_to (str, optional): Report location; the absolute path with a
            ``.pdf`` suffix is registered as the CCV report.
        client (CalCatAPIClient, optional): Client for CalCat API. Defaults
            to ``get_client()``.

    Raises:
        ValueError: If ``ccv_root`` does not have exactly three components.
        CCVAlreadyInjectedError: If posting the CCV is rejected with HTTP
            422, which is treated as the CCV already being injected.
        CalCatAPIError: For other failed CalCat requests.
    """
    if client is None:
        client = get_client()

    pdu_name, calibration, _ = ccv_root.lstrip('/').split('/')

    (
        cond_params, begin_at, pdu_uuid, detector_type, raw_data_location
    ) = get_ccv_info_from_file(const_src, pdu_name, ccv_root)

    const_rel_path = f'xfel/cal/{detector_type.lower()}/{pdu_name.lower()}'
    const_filename = f'cal.{time.time()}.h5'
    const_dest = get_default_caldb_root() / const_rel_path / const_filename
    const_dest.parent.mkdir(parents=True, exist_ok=True)

    try:
        # Copy inside the try so a partially written copy is cleaned up too.
        copyfile(const_src, const_dest)

        condition_id, condition_name = create_condition(
            client, detector_type, pdu_name, pdu_uuid, cond_params)

        # Create Calibration Constant in database, if not available.
        cc_id = get_or_create_calibration_constant(
            client, calibration, detector_type, condition_id, condition_name)

        # Create report in database, if not available
        report_id = None
        if report_to:
            report_path = Path(report_to).absolute().with_suffix('.pdf')
            resp = client.post("reports/set", dict(
                name=report_path.name,
                file_path=str(report_path),
                flg_available=True,
                description="",
            ))
            report_id = resp['id']

        # Get PDU ID before creating new CCV.
        pdu_id = client.pdu_by_name(pdu_name)['id']

        # Prepare CCV data and inject it to CALCAT.
        start_idx = 0
        ccv = dict(
            name=create_unique_ccv_name(start_idx),
            file_name=const_filename,
            path_to_file=const_rel_path,
            data_set_name=ccv_root,
            calibration_constant_id=cc_id,
            physical_detector_unit_id=pdu_id,
            raw_data_location=raw_data_location,
            report_id=report_id,
            begin_validity_at=begin_at,
            end_validity_at='',
            begin_at=begin_at,
            start_idx=start_idx,
            end_idx=0,
            flg_deployed=True,
            flg_good_quality=True,
            description='',
        )
        try:
            client.post("calibration_constant_versions", ccv)
        except CalCatAPIError as e:
            if e.status_code == 422:
                raise CCVAlreadyInjectedError(
                    f'CalCat rejected CCV {ccv_root!r} (HTTP 422)') from e
            raise
    except Exception:
        # Delete the copied CCV file so no orphan is left in the CalDB root.
        # missing_ok: the copy may never have been created, and a cleanup
        # error must not hide the original one.
        const_dest.unlink(missing_ok=True)
        raise