import sys
import warnings
import os
import h5py
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import datetime
import dateutil.parser

import seaborn
from iCalibrationDB import ConstantMetaData, Constants, Conditions, Detectors, Versions
from cal_tools.cal_tools import gain_map_files, parse_runs, run_prop_seq_from_path, get_notebook_name, get_dir_creation_date, get_constant_from_db
from cal_tools.influx import InfluxLogger
from cal_tools.enums import BadPixels
from cal_tools.plotting import show_overview, plot_badpix_3d, create_constant_overview

# Suppress every warning for the remainder of the run. This is installed
# before the backend switch below, so warnings raised there are hidden too.
warnings.simplefilter("ignore")
# Agg is a non-interactive backend: figures are rendered to files only.
matplotlib.use("agg")


# ---- run parameters ------------------------------------------------------
start_date = "2018-06-25"  # date to start investigation interval from
end_date = "now"  # date to end investigation interval at, can be "now"
interval = 3  # interval for evaluation in days
detector = "LPD1M1"  # detector to investigate
constants = ["Offset", "Noise"]  # constants to plot
cal_db_interface = "tcp://max-exfl016:8015"  # the database interface to use
bias_voltage = 500  # bias voltage of the detector condition to query (presumably volts)
max_cells = 128  # memory cells of the detector condition; also sizes the stat arrays
modules = [3]  # module indices to evaluate; [-1] means all 16 modules
out_folder = "/gpfs/exfel/data/scratch/karnem/test/"  # where PDFs go; "" disables saving
all_ready_printed = {}  # constant names whose injection date was already reported

# ---- derived settings ----------------------------------------------------
detector = detector.upper()
# Detector class used to pick constant/condition families.
if "AGIPD" in detector:
    dclass = "AGIPD"
else:
    dclass = "LPD"

# A leading -1 is shorthand for every module of the detector.
modules = list(range(16)) if modules[0] == -1 else modules


def get_constant_from_db(device, constant, condition, empty_constant,
                         cal_db_interface, creation_time=None,
                         print_once=True):
    """Retrieve a calibration constant from the calibration database.

    This local definition shadows the ``get_constant_from_db`` imported from
    ``cal_tools.cal_tools`` at the top of the file.

    :param device: detector module to query for; a falsy value skips the
        query entirely.
    :param constant: constant object to retrieve; its ``name`` is used for
        reporting and its ``data`` is returned on success.
    :param condition: detector condition the constant must match.
    :param empty_constant: value returned when nothing could be retrieved.
    :param cal_db_interface: address of the calibration DB interface.
    :param creation_time: ``datetime`` to query the constant for, or ``None``
        to query the currently valid version (``Versions.Now``).
    :param print_once: if True, report the injection date of each constant
        name only the first time (tracked in module-level
        ``all_ready_printed``); if False, report it on every call.
    :return: ``constant.data`` on success, otherwise ``empty_constant``.
    """
    if not device:
        return empty_constant

    metadata = ConstantMetaData()
    metadata.calibration_constant = constant
    metadata.detector_condition = condition
    if creation_time is None:
        metadata.calibration_constant_version = Versions.Now(device=device)
        # No timestamp to send; calling .isoformat() on None would raise and
        # make this branch always fall back to empty_constant.
        when = None
    else:
        metadata.calibration_constant_version = \
            Versions.Timespan(device=device, start=creation_time)
        when = creation_time.isoformat()

    try:
        print("Send requiests")
        metadata.retrieve(cal_db_interface, when=when)
    except Exception as e:
        # Best effort: a missing constant is not fatal for the overview,
        # but report why it was not found instead of hiding the error.
        print("Could not retrieve {}: {}".format(constant.name, e))
        return empty_constant

    if constant.name not in all_ready_printed or not print_once:
        all_ready_printed[constant.name] = True
        print("{} was injected on: {}".format(
            constant.name, metadata.calibration_constant_version.begin_at))
    return constant.data


def compute_constant_stats(cdata, max_cells, n_gains=3):
    """Summarise one constant array for the time-development plots.

    ``cdata[..., g]`` selects gain ``g`` and must be 3-D, so ``cdata`` is 4-D.
    The per-pixel medians below are written into arrays of shape
    ``(cdata.shape[0], max_cells)``, which assumes the first two axes have
    equal length and the third axis has ``max_cells`` entries — TODO confirm
    against the constants stored in the database.

    :param cdata: constant array, gain stage on the last axis.
    :param max_cells: size of the third ("cell") axis of the output arrays.
    :param n_gains: number of gain stages evaluated.
    :return: ``(carr, carr_glob, carr_px)``:
        ``carr[s, :, g]`` is statistic ``s`` over axes 0 and 1,
        ``carr_glob[s, g]`` is statistic ``s`` over the whole gain slice,
        ``carr_px[..., g, 0]`` / ``carr_px[..., g, 1]`` are NaN-medians over
        axis 0 / axis 1. Statistic order is mean, median, min, max, std.
    """
    stat_funcs = (np.nanmean, np.nanmedian, np.nanmin, np.nanmax, np.nanstd)
    carr = np.zeros((len(stat_funcs), max_cells, n_gains))
    carr_glob = np.zeros((len(stat_funcs), n_gains))
    carr_px = np.zeros((cdata.shape[0], max_cells, n_gains, 2))
    for g in range(n_gains):
        gain_data = cdata[..., g]
        for s, func in enumerate(stat_funcs):
            per_cell = func(gain_data, axis=(0, 1))
            # Fewer cells than max_cells leave the tail of the row at zero.
            carr[s, :per_cell.shape[0], g] = per_cell
            # Each statistic gets its own row (std no longer overwrites max).
            carr_glob[s, g] = func(gain_data)
        carr_px[..., g, 0] = np.nanmedian(gain_data, axis=0)
        carr_px[..., g, 1] = np.nanmedian(gain_data, axis=1)
    return carr, carr_glob, carr_px


dt = dateutil.parser.parse(start_date)
end = datetime.datetime.now() if end_date.upper() == "NOW" else dateutil.parser.parse(end_date)
step = datetime.timedelta(days=interval)

det = getattr(Detectors, detector)
dconstants = getattr(Constants, dclass)
# ret_constants[const][qm] -> list of (creation_time, (carr, carr_glob, carr_px));
# a module key exists (possibly with an empty list) for every queried module.
ret_constants = {}
while dt < end:
    creation_time = dt
    print("Retreiving data from: {}".format(creation_time.isoformat()))
    for const in constants:
        const_store = ret_constants.setdefault(const, {})
        # Offset, Noise and any "*dark*" constant use dark conditions; the
        # condition must be assigned for every constant, not only dark ones.
        if const in ["Offset", "Noise"] or "DARK" in const.upper():
            dcond = Conditions.Dark
        else:
            dcond = Conditions.Illuminated

        for i in modules:
            qm = "Q{}M{}".format(i // 4 + 1, i % 4 + 1)
            cdata = get_constant_from_db(getattr(det, qm),
                                         getattr(dconstants, const)(),
                                         getattr(dcond, dclass)(memory_cells=max_cells,
                                                                bias_voltage=bias_voltage),
                                         None,
                                         cal_db_interface,
                                         creation_time=creation_time)

            print("Found constant {}: {}".format(const, cdata is not None))

            module_store = const_store.setdefault(qm, [])
            if cdata is not None:
                module_store.append((creation_time,
                                     compute_constant_stats(cdata, max_cells)))

    dt += step


types = ["mean", "median", "min", "max", "std"]
colors = ["red", "green", "orange", "blue"]
skip = [False, False, False, False, False]
# Memory cell and gain stage shown in the per-pixel median plot.
px_cell = 5
px_gain = 0


def date_formatter(times):
    """Return a tick formatter labelling categorical tick ``i`` with ``times[i]``.

    Categorical seaborn plots put ticks at positions 0, 1, ...; matplotlib
    passes them as floats, so they are rounded before indexing. Positions
    outside ``times`` get an empty label instead of raising.
    """
    times = list(times)

    def format_date(x, pos=None):
        idx = int(round(x))
        if 0 <= idx < len(times):
            return times[idx].strftime('%d-%m')
        return ""

    return ticker.FuncFormatter(format_date)


# loop over constant type (``mod_data`` avoids clobbering the global ``modules``)
for const, mod_data in ret_constants.items():
    fig = plt.figure(figsize=(15, 7))
    print(const)
    active = [typ for typ in range(len(types)) if not skip[typ]]
    ctimes = ()

    # one subplot row per statistic that is not skipped
    for row, typ in enumerate(active):
        ax = plt.subplot2grid((len(active), 1), (row, 0))

        # loop over modules
        for mod, data in mod_data.items():
            if not data:
                # no constant was found for this module at any date
                continue
            ctimes, cd = zip(*data)
            pmm, _, _ = zip(*cd)
            dd = np.array(pmm)[:, typ, :, :]
            if np.allclose(dd, 0):
                continue
            # dd is (time, cell, gain); flatten in that order and build
            # matching x (time) and hue (gain) vectors.
            y = dd.flatten()
            x = np.repeat(np.array(ctimes)[:, None],
                          dd[0, :, :].size, axis=1).flatten()
            hue = np.repeat(np.array(['gain 0', 'gain 1', 'gain 2'])[:, None],
                            dd[:, :, 0].size, axis=1).swapaxes(0, 1).flatten()
            seaborn.violinplot(x=x, y=y, hue=hue, scale="width", dodge=False,
                               saturation=0.7, ax=ax)

        # only the bottom visible row carries the date axis
        if row != len(active) - 1:
            ax.get_xaxis().set_visible(False)
        else:
            ax.xaxis.set_major_formatter(date_formatter(ctimes))
            ax.set_xlabel("Date")
        ax.set_ylabel("{}".format(types[typ]))

    plt.subplots_adjust(wspace=0.2, hspace=0.2)

    if out_folder != "":
        fig.savefig(os.path.join(out_folder,
                                 "{}_time_development.pdf".format(const)),
                    bbox_inches='tight')

    fig = plt.figure(figsize=(15, 7))
    ax = plt.subplot2grid((1, 1), (0, 0))
    ctimes = ()

    # loop over modules
    for mod, data in mod_data.items():
        if not data:
            continue
        ctimes, cd = zip(*data)
        _, _, px = zip(*cd)
        # px[t, :, px_cell, px_gain, k]: median along axis k (0 -> "px", 1 -> "py")
        sel = np.array(px)[:, :, px_cell, px_gain, :]
        y = sel.flatten()
        x = np.repeat(np.array(ctimes)[:, None],
                      sel[0].size, axis=1).flatten()
        hue = np.repeat(np.array(['px', 'py'])[:, None], sel[..., 0].size,
                        axis=1).swapaxes(0, 1).flatten()
        seaborn.violinplot(x=x, y=y, hue=hue, palette="muted", split=True,
                           ax=ax)

    ax.xaxis.set_major_formatter(date_formatter(ctimes))
    ax.set_xlabel("Date")
    ax.set_ylabel("Median over pixels")
    plt.subplots_adjust(wspace=0.2, hspace=0.2)

    if out_folder != "":
        fig.savefig(os.path.join(out_folder,
                                 "{}_pxtime_development.pdf".format(const)),
                    bbox_inches='tight')