import os
import logging
import argparse
import h5py

import numpy as np
import extra_data as ed

import toolbox_scs as tb
import toolbox_scs.detectors as tbdet

# Emit log records of level INFO and above.
logging.basicConfig(level=logging.INFO)
# Logger for this script (named after the module; despite the name it is not
# the root logger).
log_root = logging.getLogger(__name__)


# -----------------------------------------------------------------------------
# user input: run-type specific
# -----------------------------------------------------------------------------
# Proposal that the processed runs belong to.
proposal_nb = 2599
# Base output directory; output for run N goes to f'{output_filepath}r_{N}/'.
output_filepath = "../processed_runs/"

# The run type is not configured here: the calling shell script passes it on
# the command line (e.g. "--runtype static"). Accepted values:
#   'static', 'static_IR', 'energyscan', 'energyscan_pumped', 'delayscan',
#   'timescan'

# Script name, recorded in the 'run_description' attribute of every output file.
scriptname = os.path.basename(__file__)

# If True, the binned XGM data is also stored in every module file
# (group 'xgm_binned').
save_xgm_binned = True

# Optional pre-binning settings for the DSSC data.
# Frame stride passed to DSSCBinner as `dssc_coords_stride` and to
# process_data as `normevery`: 2 if intra-train darks are used, 1 otherwise.
normevery = 2
# If True, DSSCBinner.create_pulsemask('xgm', xgm_threshold) is applied so that
# DSSC frames are dropped according to the XGM threshold.
xgm_mask = True
# (lower, upper) XGM limits used for the pulse mask.
# NOTE(review): confirm these limits select the good pulses, not the bad ones.
xgm_threshold = (1000, np.inf)
# Dark-image file read with tbdet.load_xarray (its 'data' variable is passed
# to process_data as `dark_image`), or None to pass no dark image.
filename_dark = None  # 200 -- NOTE(review): meaning unclear, perhaps a dark run number
# Forwarded to DSSCBinner.process_data as `xgm_normalization`.
xgm_normalization = False

# -----------------------------------------------------------------------------


# Per-frame bucket labels, repeated over each train. In both patterns every
# second frame is an intra-train dark ('...intradark').
_IMAGE_PATTERN = ['image', 'intradark']
_PUMP_PROBE_PATTERN = ['unpumped', 'unpumped_intradark',
                       'pumped', 'pumped_intradark']

# Supported run types and the pulse pattern each one uses.
_PULSE_PATTERNS = {
    'static': _IMAGE_PATTERN,
    'energyscan': _IMAGE_PATTERN,
    'static_IR': _PUMP_PROBE_PATTERN,
    'energyscan_pumped': _PUMP_PROBE_PATTERN,
    'delayscan': _PUMP_PROBE_PATTERN,
    'timescan': _PUMP_PROBE_PATTERN,
}


def _unknown_runtype(runtype):
    """Build the error raised for a run type missing from _PULSE_PATTERNS."""
    return ValueError(f"unknown runtype {runtype!r}; expected one of "
                      f"{sorted(_PULSE_PATTERNS)}")


def _train_buckets(run_obj, runtype, n_trains):
    """Return the value each train of the run is binned by, per run type."""
    if runtype in ('static', 'static_IR'):
        # no scan variable: every train goes into the same bin
        return np.zeros(n_trains)
    if runtype in ('energyscan', 'energyscan_pumped'):
        return tb.get_array(run_obj, 'nrj', 0.1).values
    if runtype == 'delayscan':
        return tb.get_array(run_obj, 'PP800_DelayLine', 0.03).values
    if runtype == 'timescan':
        # 10 s bins; the timeserver timestamp is in ns, the result in s
        bin_nsec = 10 * 1e9
        tstamp = run_obj.get_array('SCS_RR_UTC/TSYS/TIMESERVER',
                                   'id.timestamp')
        return (bin_nsec * np.round(tstamp / bin_nsec) - tstamp.min()) / 1e9
    raise _unknown_runtype(runtype)


def _pulse_buckets(runtype, fpt):
    """Return the bucket label of each of the `fpt` frames in a train.

    Raises ValueError when `fpt` is not a whole number of pattern
    repetitions, because the labels must cover every frame of the train.
    """
    pattern = _PULSE_PATTERNS[runtype]
    n_rep, rest = divmod(fpt, len(pattern))
    if rest:
        raise ValueError(
            f"frames per train ({fpt}) is not a multiple of the "
            f"{len(pattern)}-frame pulse pattern of runtype {runtype!r}")
    return pattern * n_rep


def process(run_nb, runtype, modules=None):
    """Bin the DSSC data of one run and write one HDF5 file per module.

    Trains are binned by a run-type specific quantity (see `_train_buckets`)
    and the frames of each train by the run type's pulse pattern (see
    `_PULSE_PATTERNS`). `DSSCBinner.process_data` is given the directory
    ``f'{output_filepath}r_{run_nb}/'`` as `filepath`. Afterwards each module
    file ``run_{run_nb}_module{m}.h5`` in that directory has the binned XGM
    data (if `save_xgm_binned`), both binners and metadata attributes
    appended.

    Parameters
    ----------
    run_nb : int or str
        Run number within `proposal_nb`; also used in output paths.
    runtype : str
        One of the keys of `_PULSE_PATTERNS`.
    modules : sequence of int, optional
        DSSC modules to process. None or empty (the default) processes all
        16 modules, 0-15.

    Raises
    ------
    ValueError
        If `runtype` is unknown, or the number of frames per train is not a
        multiple of the run type's pulse-pattern length.
    """
    # fail fast, before any (slow) run data is opened
    if runtype not in _PULSE_PATTERNS:
        raise _unknown_runtype(runtype)

    run_description = f'{runtype}; script {scriptname}'
    print(run_description)
    if modules is None or len(modules) == 0:
        mod_list = list(range(16))
    else:
        mod_list = modules

    path = f'{output_filepath}r_{run_nb}/'
    log_root.info("create run objects")
    run_info = tbdet.load_dssc_info(proposal_nb, run_nb)
    fpt = run_info['frames_per_train']
    n_trains = run_info['number_of_trains']
    trainIds = run_info['trainIds']
    run_obj = ed.open_run(proposal_nb, run_nb)

    buckets_train = _train_buckets(run_obj, runtype, n_trains)
    buckets_pulse = _pulse_buckets(runtype, fpt)

    # binners: trains by buckets_train, frames 0..fpt-1 by buckets_pulse
    binner1 = tbdet.create_dssc_bins("trainId", trainIds, buckets_train)
    binner2 = tbdet.create_dssc_bins("pulse", np.arange(fpt), buckets_pulse)
    binners = {'trainId': binner1, 'pulse': binner2}
    bin_obj = tbdet.DSSCBinner(proposal_nb, run_nb,
                               binners=binners,
                               dssc_coords_stride=normevery)

    if xgm_mask:
        bin_obj.create_pulsemask('xgm', xgm_threshold)

    dark = None
    if filename_dark:
        dark = tbdet.load_xarray(filename_dark)['data']

    bin_params = {'modules': mod_list,
                  'chunksize': 248,
                  'filepath': path,
                  'xgm_normalization': xgm_normalization,
                  'normevery': normevery,
                  'dark_image': dark}

    # the output directory (and any missing parent) must exist before
    # process_data is pointed at it, not only for the appends below
    os.makedirs(path, exist_ok=True)

    log_root.info("start binning routine")
    bin_obj.process_data(**bin_params)

    log_root.info("Add additional data to module files")
    if save_xgm_binned:
        bin_obj.load_xgm()
        xgm_binned = bin_obj.get_xgm_binned()

    for m in mod_list:
        fpath = path + f'run_{run_nb}_module{m}.h5'
        if save_xgm_binned:
            tbdet.save_xarray(fpath, xgm_binned, group='xgm_binned', mode='a')
        tbdet.save_xarray(fpath, binner1, group='binner1', mode='a')
        tbdet.save_xarray(fpath, binner2, group='binner2', mode='a')
        metadata = {'run_number': run_nb,
                    'module': m,
                    'run_description': run_description}
        tbdet.save_attributes_h5(fpath, metadata)


if __name__ == '__main__':
    # Run types handled by process(); keep this list in sync with it.
    runtypes = ['static', 'static_IR', 'energyscan', 'energyscan_pumped',
                'delayscan', 'timescan']

    parser = argparse.ArgumentParser()
    parser.add_argument('--run-number', metavar='S', required=True,
                        action='store',
                        help='the run to be processed')
    parser.add_argument('--module', metavar='S',
                        nargs='+', action='store',
                        help='modules to be processed')
    # nargs='+' is kept for compatibility with existing job scripts; only the
    # first value is used. The help must be a single str: argparse fails on
    # --help otherwise.
    parser.add_argument('--runtype', metavar='S', required=True,
                        nargs='+', action='store', choices=runtypes,
                        help='type of run (' + ', '.join(runtypes) + ')')
    args = parser.parse_args()

    runtype = args.runtype[0]
    run_nb = str(args.run_number)
    if args.module is None:
        # no --module given: process() falls back to all modules
        process(run_nb, runtype)
    else:
        module_args = args.module
        if len(module_args) == 1:
            # modules may arrive as one quoted, space-separated argument,
            # e.g. --module "0 1 2"; split() also tolerates repeated spaces
            module_args = module_args[0].split()
        process(run_nb, runtype, [int(m) for m in module_args])