"""Process AGIPD1M detector data from a EuXFEL run on CPU.

Depending on --type, computes per-train statistics or two-time correlation
functions (dense, sparse, or traditional) for each detector module, reading
the raw HDF5 files in parallel and writing the results to one output HDF5
file per run.
"""
import os
from time import time
from pathlib import Path
from datetime import datetime
from argparse import ArgumentParser, Namespace
from itertools import repeat
from concurrent.futures import ProcessPoolExecutor

import numpy as np
from h5py import File
from extra_data import open_run

from .algos import WelfordAlg, do_stats, do_dense, do_sparse, do_traditional


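# Short tag per run type, used in the output filename:
# output_{runtypes[type]}_{type}.h5 (see run_analysis below).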
runtypes = { 'stats'            : '00',
             'denseTTC'         : '1a',
             'denseTTCSymmNorm' : '1b',
             'sparseTTC'        : '2a',
             'sparseTTCSymmNorm': '2b',
             'traditional'      : '3',
           }

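# Maps each run type to (processing function, names of the datasets it returns).
# Consumed in parallel_read as: do_math, keys = trainMath[args.type];
# res = do_math(train_iterator, mask). Unimplemented types are commented out.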
trainMath = { 'stats'            : (do_stats, ['tids', 'mean', 'var', 'count']),
              'denseTTC'         : (do_dense, ['tids', 'ttc_on', 'ttc_off', 's_on', 's_cnt']),
#             'denseTTCSymmNorm' : not implemented yet
              'sparseTTC'        : (do_sparse, ['tids', 'ttc_on', 'ttc_off', 's_on', 's_off', 's_cnt']),
#             'sparseTTCSymmNorm': not implemented yet
              'traditional'      : (do_traditional, ['tids', 'ttc_on', 'ttc_off', 's_on', 's_off', 's_cnt']),
            }


pars = ArgumentParser(prog='processOnCPU')

pars.add_argument('--user', default='braussef',
                  help='user name under the beegfs home directory')
pars.add_argument('--proposal', '-p', dest='proposal', default=2853, type=int,
                  help='proposal number')
pars.add_argument('--runs', '-r', dest='runs', default=[70], type=int, nargs='*',
                  help='run numbers to process')
pars.add_argument('--type', '-t', dest='type', default='stats',
                  choices=list(trainMath),  # only types with an implemented function
                  help='which analysis to run (see trainMath above)')

pars.add_argument('--cal_mask', '-cmsk', dest='cal_mask', action='store_true',
                  help='load/use the mask from the calibration pipeline')
pars.add_argument('--photonize', '-phot', dest='photonize', action='store_true',
                  help='convert ADU values to photon counts')
pars.add_argument('--debug', '-dbg', dest='debug', action='store_true',
                  help='process files serially for debugging')
pars.add_argument('--aggregate-modules', '-agg', dest='agg', action='store_true',
                  help='sum results over detector modules')

pars.add_argument('--energy', '-e', dest='eng', default=9.0, type=float,
                  help='photon energy in keV')
pars.add_argument('--adu_per_keV', '-adu', dest='adu_per_keV', default=1.0, type=float,
                  help='detector gain in ADU per keV')

pars.add_argument('--skip_after', '-sm', dest='skip_after', default=256, type=int,
                  help='skip trains after this index within each file')

pars.add_argument('--filter_trains', '-ft', dest='filter_trains', action='store_true',
                  help='process only the trainIds listed in good_trains.npy')

pars.add_argument('--out_dir', '-d', dest='out_dir', default=None, type=str,
                  help='output directory (default: derived from user/proposal/run)')

pars.add_argument('--pulse_slice', dest='pulse_slice', default=[0, 352, 1],
                  type=int, nargs=3,
                  help='start, stop, step of the pulse/cell slice within a train')

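# Example invocation (the module uses a relative import, so run it with -m
# from its package; the package name below is a placeholder):
#   python -m <package>.processOnCPU -p 2853 -r 70 -t sparseTTC -phot -agg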


def getModNo(filename):
    """Extract the AGIPD module number from an EuXFEL-style filename,
    e.g. '...-AGIPD05-...' -> 5."""
    return int([frag for frag in filename.split('-') if 'AGIPD' in frag][0][-2:])

def getFileNo(filename):
    """Extract the sequence-file number, e.g. '...-S00002.h5' -> 2.
    Note: matching on 'S000' assumes sequence numbers below 100."""
    return int([frag.split('.')[0][-4:] for frag in filename.split('-') if 'S000' in frag][0])


def sliceTrains(hdf_file, module_number, args, good_trains):
    """Yield (trainId, data, mask) per good train from one module's file.

    data is the pulse-sliced image block as int32 (optionally photonized);
    mask is a boolean block from the calibration pipeline (all False when
    --cal_mask is not given).
    """
    i    = 0
    tIds = np.unique(hdf_file['INSTRUMENT/MID_DET_AGIPD1M-1/DET/%iCH0:xtdf/image/trainId' % module_number][()])
    tIds = tIds[tIds > 0]  # drop padding entries with trainId 0
    det  = hdf_file['INSTRUMENT/MID_DET_AGIPD1M-1/DET/%iCH0:xtdf/image/data' % module_number]

    if args.cal_mask:
        cmsk = hdf_file['INSTRUMENT/MID_DET_AGIPD1M-1/DET/%iCH0:xtdf/image/mask' % module_number]

    nTrains = len(tIds)
    nCells  = len(det) // nTrains  # assumes the same number of memory cells per train

    # per-train read buffers: raw frames (a) and calibration mask (b)
    a = np.zeros((nCells, 512, 128), dtype=det.dtype)
    b = np.zeros((nCells, 512, 128), dtype=det.dtype)

    # stop at the end of the file or after args.skip_after trains
    while i < nTrains and i <= args.skip_after:
        tid = tIds[i]

        if tid in good_trains:
            det.read_direct(a, source_sel=np.s_[i*nCells : (i+1)*nCells, :, :])

            if args.photonize:
                # convert ADU to photon counts: ADU / (keV * ADU-per-keV)
                _a  = a / args.eng
                _a /= args.adu_per_keV
                _a  = np.rint(_a)  # same as np.round to 0 decimals, but faster
                _a[_a < 0] = 0     # clip negative (noise) values to zero photons
                _a = _a[slice(*args.pulse_slice)].astype('int32').copy()
            else:
                _a = a[slice(*args.pulse_slice)].astype('int32').copy()

            if args.cal_mask:
                cmsk.read_direct(b, source_sel=np.s_[i*nCells : (i+1)*nCells, :, :])

            # without --cal_mask, b stays zero and _b is all False (nothing masked)
            _b = (b > 0)[slice(*args.pulse_slice)].astype(bool).copy()
            
            yield tid, _a, _b
        
        i += 1
        
        

def parallel_read(hdf_file, args, masks, good_trains):
    """Worker entry point: process one module file and return its results."""
    print(hdf_file)
    modNo = getModNo(hdf_file)
    fNo   = getFileNo(hdf_file)

    tic = time()
    # the traditional correlator expects a boolean mask, the others int8
    mtype = bool if args.type == 'traditional' else np.int8
    m = masks[modNo].astype(mtype).copy()

    do_math, _ = trainMath[args.type]

    with File(hdf_file, 'r') as f:
        it = sliceTrains(f, modNo, args, good_trains)

        if np.any(m):
            try:
                res = do_math(it, m)
            except StopIteration:
                res = None
        else:
            # fully masked module: nothing to compute
            res = None

    toc = time()
    print(f'Done with {hdf_file:s} after {toc - tic:.1f} secs')

    return (modNo, fNo), res, toc - tic
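# Note: in parallel mode, ProcessPoolExecutor.map pickles args, masks, and
# good_trains for every file task; keep these objects picklable and small.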


def run_analysis(args: Namespace):
    home_dir = '/beegfs/desy/user'

    for r in args.runs:

        print(f'cal_mask={args.cal_mask}, filter_trains={args.filter_trains}')

        if args.out_dir is None:
            # zero-padded; ':4d' would space-pad small proposal numbers
            proc_dir = f'xpcs/p{args.proposal:04d}/r{r:04d}/'
            path = '/'.join([home_dir, args.user, proc_dir])
        else:
            path = args.out_dir

        path = Path(path)
        path.mkdir(parents=True, exist_ok=True)
        print(path)

        _, keys = trainMath[args.type]
        if args.type == 'stats':
            # plain statistics need no user mask: pass all-ones for all 16 modules
            masks = np.ones((16, 1, 512, 128))
        else:
            # per-module user masks prepared beforehand in the output directory
            masks = np.load(path / 'all_masks.npy').swapaxes(0, 1).copy()

        run = open_run(proposal=args.proposal, run=r, data='all')

        if args.filter_trains:
            good_trains_path = path / 'good_trains.npy'
            print(f'Loading good trains from {good_trains_path}')
            good_trains = np.load(good_trains_path)
        else:
            good_trains = run.train_ids

        # detector data files only (skip control/slow data)
        all_files = [f.filename for f in run.files if 'AGIPD' in f.filename and 'CTRL' not in f.filename]

        log = open(path / 'log.txt', 'w')
        log.write(f'Starting at {str(datetime.now()):s}\n')
        log.write(f"Running on {os.environ.get('HOSTNAME', 'unknown host')} with {os.cpu_count()} cores\n")
        log.write(f'{len(all_files)} files in total\n')



        walltime = time()

        print('debug is ' + str(args.debug))

        if args.debug:
            # serial execution for debugging; collect results so the
            # result handling below works in both modes
            res = [parallel_read(F, args, masks, good_trains) for F in all_files]
        else:
            with ProcessPoolExecutor(os.cpu_count()) as exc:
                # repeat() broadcasts the shared arguments to every file task
                res = list(exc.map(parallel_read, all_files,
                                   repeat(args), repeat(masks), repeat(good_trains)))

        walltime = time() - walltime
        log.write(f'WALL TIME is {walltime:6.1f} secs\n')
        log.write(f'Finished at {str(datetime.now()):s}\n')
        log.close()

        values  = dict()
        timings = dict()

        for k, v, t in res:  # k = (modNo, fNo), v = result tuple, t = elapsed time
            if v is not None:
                values[k] = v
                timings[k] = t

### sum over modules: one dataset per key, summed across modules and
### concatenated across sequence files ('tids' is kept, not summed)
        if args.agg:
            print(timings)
            all_data = dict()

            for k in keys:
                all_data[k] = dict()

            for k in sorted(values):
                mod, fNo = k
                mod, fNo = f'mod{mod:02d}', f'file{fNo:02d}'

                for j, v in zip(keys, values[k]):
                    if not isinstance(v, np.ndarray):
                        v = np.asarray(v, dtype=v[0].dtype)
                    if fNo not in all_data[j]:
                        all_data[j][fNo] = v
                    elif j != 'tids':
                        all_data[j][fNo] += v

            with File(path / f'output_{runtypes[args.type]:s}_{args.type:s}.h5', 'w') as f:
                for j in sorted(all_data):
                    dset = all_data[j]
                    v = np.concatenate(list(dset.values()))
                    d = f.create_dataset(j, shape=v.shape, dtype=v.dtype)
                    d[:] = v

### don't sum: keep one mod{NN}/file{NN} group per (module, sequence file) pair
        else:
            with File(path / f'output_{runtypes[args.type]:s}_{args.type:s}.h5', 'w') as f:
                for k in sorted(values):
                    mod, fNo = k
                    mod, fNo = f'mod{mod:02d}', f'file{fNo:02d}'

                    m = f.require_group(mod).require_group(fNo)

                    for j, v in zip(keys, values[k]):
                        if not isinstance(v, np.ndarray):
                            v = np.asarray(v, dtype=v[0].dtype)
                        d = m.create_dataset(j, shape=v.shape, dtype=v.dtype)
                        d[:] = v

                    # per-file processing time in seconds
                    d = m.create_dataset('tictoc', dtype=np.float32, shape=(1,))
                    d[:] = timings[k]




if __name__ == '__main__':
    args = pars.parse_args()

    run_analysis(args)