{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# DSSC Offline Correction #\n",
    "\n",
    "Author: European XFEL Detector Group, Version: 1.0\n",
    "\n",
    "Offline Calibration for the DSSC Detector"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-02-21T11:30:06.730220Z",
     "start_time": "2019-02-21T11:30:06.658286Z"
    }
   },
   "outputs": [],
   "source": [
    "cluster_profile = \"noDB\" # The ipcluster profile to use\n",
    "in_folder = \"/gpfs/exfel/exp/SCS/202031/p900170/raw\" # path to input data, required\n",
    "out_folder = \"/gpfs/exfel/data/scratch/samartse/test/DSSC\" # path to output to, required\n",
    "sequences = [-1] # sequence files to evaluate.\n",
    "modules = [-1] # modules to correct, set to -1 for all, range allowed\n",
    "run = 229 #runs to process, required\n",
    "\n",
    "karabo_id = \"SCS_DET_DSSC1M-1\" # karabo karabo_id\n",
    "karabo_da = ['-1']  # a list of data aggregators names, Default [-1] for selecting all data aggregators\n",
    "receiver_id = \"{}CH0\" # inset for receiver devices\n",
    "path_template = 'RAW-R{:04d}-{}-S{:05d}.h5' # the template to use to access data\n",
    "h5path = 'INSTRUMENT/{}/DET/{}:xtdf/image' # path in the HDF5 file to images\n",
    "h5path_idx = '/INDEX/{}/DET/{}:xtdf/image' # path in the HDF5 file to images\n",
    "slow_data_pattern = 'RAW-R{}-DA{}-S00000.h5'\n",
    "\n",
    "use_dir_creation_date = True # use the creation data of the input dir for database queries\n",
    "cal_db_interface = \"tcp://max-exfl016:8020#8025\" # the database interface to use\n",
    "cal_db_timeout = 300000 # in milli seconds\n",
    "\n",
    "mem_cells = 0 # number of memory cells used, set to 0 to automatically infer\n",
    "overwrite = True # set to True if existing data should be overwritten\n",
    "max_pulses = 800 # maximum number of pulses per train\n",
    "bias_voltage = 100 # detector bias voltage\n",
    "sequences_per_node = 1 # number of sequence files per cluster node if run as slurm job, set to 0 to not run SLURM parallel\n",
    "chunk_size_idim = 1  # chunking size of imaging dimension, adjust if user software is sensitive to this.\n",
    "mask_noisy_asic = 0.25 # set to a value other than 0 and below 1 to mask entire ADC if fraction of noisy pixels is above\n",
    "mask_cold_asic = 0.25 # mask cold ASICS if number of pixels with negligable standard deviation is larger than this fraction\n",
    "noisy_pix_threshold = 1. # threshold above which ap pixel is considered noisy.\n",
    "geo_file = \"/gpfs/exfel/data/scratch/xcal/dssc_geo_june19.h5\" # detector geometry file\n",
    "dinstance = \"DSSC1M1\"\n",
    "slow_data_aggregators = [1,2,3,4] #quadrant/aggregator\n",
    "\n",
    "def balance_sequences(in_folder, run, sequences, sequences_per_node, karabo_da):\n",
    "    from xfel_calibrate.calibrate import balance_sequences as bs\n",
    "    return bs(in_folder, run, sequences, sequences_per_node, karabo_da)\n",
    "    "
   ]
  },
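  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The templates above are expanded with the run number, the data-aggregator name and the sequence index before any file is opened. A minimal sketch of the expansion, using an illustrative module (`DSSC00`) and sequence (`0`) rather than values taken from real data:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustration only: module DSSC00 and sequence 0 are assumptions, not read from disk.\n",
    "print(path_template.format(run, \"DSSC00\", 0))  # -> RAW-R0229-DSSC00-S00000.h5\n",
    "\n",
    "# h5path is filled in two steps: first karabo_id and the receiver inset,\n",
    "# then the channel number into the remaining {} of receiver_id.\n",
    "print(h5path.format(karabo_id, receiver_id).format(0))"
   ]
  },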
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-02-21T11:30:07.086286Z",
     "start_time": "2019-02-21T11:30:06.929722Z"
    }
   },
   "outputs": [],
   "source": [
    "# make sure a cluster is running with ipcluster start --n=32, give it a while to start\n",
    "import os\n",
    "import sys\n",
    "from collections import OrderedDict\n",
    "\n",
    "import h5py\n",
    "import matplotlib\n",
    "import numpy as np\n",
    "\n",
    "matplotlib.use(\"agg\")\n",
    "import matplotlib.pyplot as plt\n",
    "from ipyparallel import Client\n",
    "from IPython.display import Latex, Markdown, display\n",
    "\n",
    "print(f\"Connecting to profile {cluster_profile}\")\n",
    "view = Client(profile=cluster_profile)[:]\n",
    "view.use_dill()\n",
    "\n",
    "from datetime import timedelta\n",
    "\n",
    "from cal_tools.dssclib import get_dssc_ctrl_data, get_pulseid_checksum\n",
    "from cal_tools.tools import (\n",
    "    get_constant_from_db,\n",
    "    get_dir_creation_date,\n",
    "    get_notebook_name,\n",
    "    map_modules_from_folder,\n",
    "    parse_runs,\n",
    "    run_prop_seq_from_path,\n",
    ")\n",
    "from dateutil import parser\n",
    "from iCalibrationDB import Conditions, ConstantMetaData, Constants, Detectors, Versions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "creation_time = None\n",
    "if use_dir_creation_date:\n",
    "    creation_time = get_dir_creation_date(in_folder, run)\n",
    "    print(f\"Using {creation_time} as creation time\")\n",
    "\n",
    "if sequences[0] == -1:\n",
    "    sequences = None\n",
    "    \n",
    "h5path = h5path.format(karabo_id, receiver_id)\n",
    "h5path_idx = h5path_idx.format(karabo_id, receiver_id)\n",
    "\n",
    "\n",
    "if karabo_da[0] == '-1':\n",
    "    if modules[0] == -1:\n",
    "        modules = list(range(16))\n",
    "    karabo_da = [\"DSSC{:02d}\".format(i) for i in modules]\n",
    "else:\n",
    "    modules = [int(x[-2:]) for x in karabo_da]\n",
    "print(\"Process modules: \", \n",
    "      ', '.join([f\"Q{x // 4 + 1}M{x % 4 + 1}\" for x in modules]))\n",
    "\n",
    "CHUNK_SIZE = 512\n",
    "MAX_PAR = 32\n",
    "\n",
    "if in_folder[-1] == \"/\":\n",
    "    in_folder = in_folder[:-1]\n",
    "print(f\"Outputting to {out_folder}\")\n",
    "\n",
    "if not os.path.exists(out_folder):\n",
    "    os.makedirs(out_folder)\n",
    "elif not overwrite:\n",
    "    raise AttributeError(\"Output path exists! Exiting\")\n",
    "\n",
    "import warnings\n",
    "\n",
    "warnings.filterwarnings('ignore')\n",
    "\n",
    "print(f\"Detector in use is {karabo_id}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-02-21T11:30:07.974174Z",
     "start_time": "2019-02-21T11:30:07.914832Z"
    }
   },
   "outputs": [],
   "source": [
    "# set everything up filewise\n",
    "mmf = map_modules_from_folder(in_folder, run, path_template, karabo_da, sequences)\n",
    "mapped_files, mod_ids, total_sequences, sequences_qm, file_size = mmf\n",
    "MAX_PAR = min(MAX_PAR, total_sequences)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Processed Files ##"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-02-21T11:30:08.870802Z",
     "start_time": "2019-02-21T11:30:08.826285Z"
    }
   },
   "outputs": [],
   "source": [
    "import copy\n",
    "\n",
    "import tabulate\n",
    "from IPython.display import HTML, Latex, Markdown, display\n",
    "\n",
    "print(f\"Processing a total of {total_sequences} sequence files in chunks of {MAX_PAR}\")\n",
    "table = []\n",
    "mfc = copy.copy(mapped_files)\n",
    "ti = 0\n",
    "for k, files in mfc.items():\n",
    "    i = 0\n",
    "    while not files.empty():\n",
    "        f = files.get()\n",
    "        if i == 0:\n",
    "            table.append((ti, k, i, f))\n",
    "        else:\n",
    "            table.append((ti, \"\", i,  f))\n",
    "        i += 1\n",
    "        ti += 1\n",
    "if len(table):\n",
    "    md = display(Latex(tabulate.tabulate(table, tablefmt='latex', headers=[\"#\", \"module\", \"# module\", \"file\"])))      \n",
    "# restore the queue\n",
    "mmf = map_modules_from_folder(in_folder, run, path_template, karabo_da, sequences)\n",
    "mapped_files, mod_ids, total_sequences, sequences_qm, file_size = mmf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-02-21T11:30:16.057429Z",
     "start_time": "2019-02-21T11:30:10.082114Z"
    }
   },
   "outputs": [],
   "source": [
    "import copy\n",
    "from functools import partial\n",
    "\n",
    "\n",
    "def correct_module(total_sequences, sequences_qm, karabo_id, dinstance, mask_noisy_asic, \n",
    "                   mask_cold_asic, noisy_pix_threshold, chunksize, mem_cells, bias_voltage,\n",
    "                   cal_db_timeout, creation_time, cal_db_interface, h5path, h5path_idx, inp):\n",
    "   \n",
    "    import binascii\n",
    "    import copy\n",
    "    import struct\n",
    "    from hashlib import blake2b\n",
    "\n",
    "    import h5py\n",
    "    import numpy as np\n",
    "    from cal_tools.dssclib import get_dssc_ctrl_data, get_pulseid_checksum\n",
    "    from cal_tools.enums import BadPixels\n",
    "    from cal_tools.tools import get_constant_from_db_and_time\n",
    "    from iCalibrationDB import (\n",
    "        Conditions,\n",
    "        ConstantMetaData,\n",
    "        Constants,\n",
    "        Detectors,\n",
    "        Versions,\n",
    "    )\n",
    "    \n",
    "    filename, filename_out, channel, karabo_da, qm, conditions = inp\n",
    "    \n",
    "    # DSSC correction requires path without the leading \"/\"\n",
    "    if h5path[0] == '/':\n",
    "        h5path = h5path[1:]\n",
    "    if h5path_idx[0] == '/':\n",
    "        h5path_idx = h5path_idx[1:]\n",
    "\n",
    "    h5path = h5path.format(channel)\n",
    "    h5path_idx = h5path_idx.format(channel)\n",
    "    \n",
    "    low_edges = None\n",
    "    hists_signal_low = None\n",
    "    high_edges = None\n",
    "    hists_signal_high = None\n",
    "    pulse_edges = None\n",
    "    err = None\n",
    "    offset_not_found = False\n",
    "    def get_num_cells(fname, h5path):\n",
    "        with h5py.File(fname, \"r\") as f:\n",
    "\n",
    "            cells = f[f\"{h5path}/cellId\"][()]\n",
    "            maxcell = np.max(cells)\n",
    "            options = [100, 200, 400, 500, 600, 700, 800]\n",
    "            dists = np.array([(o-maxcell) for o in options])\n",
    "            dists[dists<0] = 10000 # assure to always go higher\n",
    "            return options[np.argmin(dists)]\n",
    "        \n",
    "    if mem_cells == 0:\n",
    "        mem_cells = get_num_cells(filename, h5path)\n",
    "        \n",
    "    pulseid_checksum = get_pulseid_checksum(filename, h5path, h5path_idx)\n",
    "        \n",
    "    print(f\"Memcells: {mem_cells}\")\n",
    "    \n",
    "    condition =  Conditions.Dark.DSSC(bias_voltage=bias_voltage, memory_cells=mem_cells,\\\n",
    "                                      pulseid_checksum=pulseid_checksum,\\\n",
    "                                      acquisition_rate=conditions['acquisition_rate'],\\\n",
    "                                      target_gain=conditions['target_gain'],\\\n",
    "                                      encoded_gain=conditions['encoded_gain'])\n",
    "    \n",
    "    detinst = getattr(Detectors, dinstance)\n",
    "    device = getattr(detinst, qm)\n",
    "    with h5py.File(filename, \"r\", driver=\"core\") as infile:\n",
    "        y = infile[f\"{h5path}/data\"].shape[2]\n",
    "        x = infile[f\"{h5path}/data\"].shape[3]\n",
    "    offset, when = get_constant_from_db_and_time(karabo_id, karabo_da,\n",
    "                                                 Constants.DSSC.Offset(),\n",
    "                                                 condition,\n",
    "                                                 None,\n",
    "                                                 cal_db_interface,\n",
    "                                                 creation_time=creation_time,\n",
    "                                                 timeout=cal_db_timeout)\n",
    "    if offset is not None:\n",
    "        offset = np.moveaxis(np.moveaxis(offset[...], 2, 0), 2, 1)\n",
    "    else:\n",
    "        offset_not_found = True\n",
    "        print(\"No offset found in the database\")\n",
    "    \n",
    "    def copy_and_sanitize_non_cal_data(infile, outfile):\n",
    "        # these are touched in the correct function, do not copy them here\n",
    "        dont_copy = [\"data\"]\n",
    "        dont_copy = [h5path + \"/{}\".format(do)\n",
    "                     for do in dont_copy]\n",
    "\n",
    "        # a visitor to copy everything else\n",
    "        def visitor(k, item):\n",
    "            if k not in dont_copy:\n",
    "\n",
    "                if isinstance(item, h5py.Group):\n",
    "                    outfile.create_group(k)\n",
    "                elif isinstance(item, h5py.Dataset):\n",
    "                    group = str(k).split(\"/\")\n",
    "                    group = \"/\".join(group[:-1])\n",
    "                    infile.copy(k, outfile[group])\n",
    "\n",
    "        infile.visititems(visitor)\n",
    "\n",
    "    try:\n",
    "        with h5py.File(filename, \"r\", driver=\"core\") as infile:\n",
    "            with h5py.File(filename_out, \"w\") as outfile:\n",
    "                copy_and_sanitize_non_cal_data(infile, outfile)\n",
    "                # get indices of last images in each train\n",
    "                first_arr = np.squeeze(infile[f\"{h5path_idx}/first\"]).astype(np.int)\n",
    "                last_arr = np.concatenate((first_arr[1:], np.array([-1,]))).astype(np.int)\n",
    "                assert first_arr.size == last_arr.size\n",
    "                oshape = list(infile[f\"{h5path}/data\"].shape)\n",
    "                if len(oshape) == 4:\n",
    "                    oshape = [oshape[0],]+oshape[2:]\n",
    "                chunks = (chunksize, oshape[1], oshape[2])\n",
    "                ddset = outfile.create_dataset(f\"{h5path}/data\",\n",
    "                                               oshape, chunks=chunks,\n",
    "                                               dtype=np.float32,\n",
    "                                               fletcher32=True)\n",
    "\n",
    "                mdset = outfile.create_dataset(f\"{h5path}/mask\",\n",
    "                                               oshape, chunks=chunks,\n",
    "                                               dtype=np.uint32,\n",
    "                                               compression=\"gzip\",\n",
    "                                               compression_opts=1,\n",
    "                                               shuffle=True,\n",
    "                                               fletcher32=True)\n",
    "\n",
    "                for train in range(first_arr.size):\n",
    "                    first = first_arr[train]\n",
    "                    last = last_arr[train]\n",
    "                    if first == last:\n",
    "                        continue\n",
    "                    data = np.squeeze(infile[f\"{h5path}/data\"][first:last, ...].astype(np.float32))\n",
    "                    cellId = np.squeeze(infile[f\"{h5path}/cellId\"][first:last, ...])\n",
    "                    pulseId = np.squeeze(infile[f\"{h5path}/pulseId\"][first:last, ...])                   \n",
    "                    if not offset_not_found:\n",
    "                        data[...] -= offset[cellId,...]\n",
    "                        \n",
    "                    if hists_signal_low is None:\n",
    "                        pulseId = np.repeat(pulseId[:, None], data.shape[1], axis=1)\n",
    "                        pulseId = np.repeat(pulseId[:,:,None], data.shape[2], axis=2)\n",
    "                        bins = (55, int(pulseId.max()))\n",
    "                        rnge = [[-5, 50], [0, int(pulseId.max())]]\n",
    "                        hists_signal_low, low_edges, pulse_edges = np.histogram2d(data.flatten(),\n",
    "                                                                                  pulseId.flatten(),\n",
    "                                                                                  bins=bins,\n",
    "                                                                                  range=rnge)\n",
    "                        rnge = [[-5, 300], [0, pulseId.max()]]\n",
    "                        hists_signal_high, high_edges, _ = np.histogram2d(data.flatten(),\n",
    "                                                                          pulseId.flatten(),\n",
    "                                                                          bins=bins,\n",
    "                                                                          range=rnge)\n",
    "                    ddset[first:last, ...] = data\n",
    "                \n",
    "                # find static and noisy values in dark images\n",
    "                data = infile[f\"{h5path}/data\"][last, ...].astype(np.float32)\n",
    "                bpix = np.zeros(oshape[1:], np.uint32)\n",
    "                dark_std = np.std(data, axis=0)\n",
    "                bpix[dark_std > noisy_pix_threshold] = BadPixels.NOISE_OUT_OF_THRESHOLD.value\n",
    "\n",
    "                for i in range(8):\n",
    "                    for j in range(2):\n",
    "                        count_noise = np.count_nonzero(bpix[i*64:(i+1)*64, j*64:(j+1)*64])\n",
    "                        asic_std = np.std(data[:, i*64:(i+1)*64, j*64:(j+1)*64])\n",
    "                        if mask_noisy_asic:\n",
    "                            if count_noise/(64*64) > mask_noisy_asic:\n",
    "                                bpix[i*64:(i+1)*64, j*64:(j+1)*64] = BadPixels.NOISY_ADC.value\n",
    "                    \n",
    "                        if mask_cold_asic:\n",
    "                            count_cold = np.count_nonzero(asic_std < 0.5)\n",
    "                            if count_cold/(64*64) > mask_cold_asic:\n",
    "                                bpix[i*64:(i+1)*64, j*64:(j+1)*64] = BadPixels.ASIC_STD_BELOW_NOISE.value\n",
    "\n",
    "    except Exception as e:\n",
    "        print(e)\n",
    "        success = False\n",
    "        reason = \"Error\"\n",
    "        err = e\n",
    "   \n",
    "    if err is None and offset_not_found:\n",
    "        err = \"Offset not found in database!. No offset correction applied.\"\n",
    "        \n",
    "    return (hists_signal_low, hists_signal_high, low_edges, high_edges, pulse_edges, when, qm, err)\n",
    "    \n",
    "done = False\n",
    "first_files = {}\n",
    "inp = []\n",
    "left = total_sequences\n",
    "\n",
    "hists_signal_low = 0\n",
    "hists_signal_high = 0 \n",
    "\n",
    "low_edges, high_edges, pulse_edges = None, None, None\n",
    "\n",
    "tGain, encodedGain, operatingFreq = get_dssc_ctrl_data(in_folder\\\n",
    "                                + \"/r{:04d}/\".format(run),\\\n",
    "                                slow_data_pattern,slow_data_aggregators, run)\n",
    "\n",
    "whens = []\n",
    "qms = []\n",
    "Errors = []\n",
    "while not done:\n",
    "    dones = []\n",
    "    for i, k_da in zip(modules, karabo_da):\n",
    "        qm = \"Q{}M{}\".format(i//4 +1, i % 4 + 1)\n",
    "\n",
    "        if qm in mapped_files:\n",
    "            if not mapped_files[qm].empty():\n",
    "                fname_in = str(mapped_files[qm].get())\n",
    "                dones.append(mapped_files[qm].empty())\n",
    "            else:\n",
    "                print(f\"{qm} file is missing\")\n",
    "                continue\n",
    "        else:\n",
    "            print(f\"Skipping {qm}\")\n",
    "            continue\n",
    "        fout = os.path.abspath(\"{}/{}\".format(out_folder, (os.path.split(fname_in)[-1]).replace(\"RAW\", \"CORR\")))\n",
    "        \n",
    "        first_files[i] = (fname_in, fout)\n",
    "        conditions = {}\n",
    "        conditions['acquisition_rate'] = operatingFreq[qm]\n",
    "        conditions['target_gain'] = tGain[qm]\n",
    "        conditions['encoded_gain'] = encodedGain[qm]\n",
    "        inp.append((fname_in, fout, i, k_da, qm, conditions))\n",
    "        \n",
    "    if len(inp) >= min(MAX_PAR, left):\n",
    "        print(f\"Running {len(inp)} tasks parallel\")\n",
    "        p = partial(correct_module, total_sequences, sequences_qm,\n",
    "                    karabo_id, dinstance, mask_noisy_asic, mask_cold_asic,\n",
    "                    noisy_pix_threshold, chunk_size_idim, mem_cells,\n",
    "                    bias_voltage, cal_db_timeout, creation_time, cal_db_interface,\n",
    "                    h5path, h5path_idx)\n",
    "\n",
    "        r = view.map_sync(p, inp)\n",
    "        #r = list(map(p, inp))\n",
    "\n",
    "        inp = []\n",
    "        left -= MAX_PAR\n",
    "        \n",
    "        for rr in r:\n",
    "            if rr is not None:\n",
    "                hl, hh, low_edges, high_edges, pulse_edges, when, qm, err = rr\n",
    "                whens.append(when)\n",
    "                qms.append(qm)\n",
    "                Errors.append(err)\n",
    "                if hl is not None:  # any one being None will also make the others None\n",
    "                    hists_signal_low += hl.astype(np.float64)\n",
    "                    hists_signal_high += hh.astype(np.float64)                \n",
    "    \n",
    "    done = all(dones)\n",
    "\n",
    "whens = [x for _,x in sorted(zip(qms,whens))]\n",
    "qms = sorted(qms)\n",
    "for i, qm in enumerate(qms):\n",
    "    try:\n",
    "        when = whens[i].isoformat()\n",
    "    except:\n",
    "        when = whens[i]\n",
    "    if Errors[i] is not None:\n",
    "\n",
    "        # Avoid writing wrong injection date if cons. not found.\n",
    "        if \"not found\" in str(Errors[i]):\n",
    "            print(f\"ERROR! {qm}: {Errors[i]}\")\n",
    "        else:\n",
    "            print(f\"Offset for {qm} was injected on {when}, ERROR!: {Errors[i]}\")\n",
    "    else:\n",
    "        print(f\"Offset for {qm} was injected on {when}\")"
   ]
  },
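  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "When `mem_cells == 0`, `correct_module` infers the number of memory cells by rounding the maximum `cellId` found in the file up to the nearest known DSSC configuration. A self-contained sketch of that rounding logic, using a made-up maximum cell id:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same rounding as in get_num_cells above, on a hypothetical maximum cell id.\n",
    "maxcell = 397  # illustrative value, not read from a file\n",
    "options = [100, 200, 400, 500, 600, 700, 800]\n",
    "dists = np.array([o - maxcell for o in options])\n",
    "dists[dists < 0] = 10000  # ensure we always round up\n",
    "print(options[np.argmin(dists)])  # -> 400"
   ]
  },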
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-02-18T17:28:51.765030Z",
     "start_time": "2019-02-18T17:28:51.714783Z"
    }
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from matplotlib import cm\n",
    "from matplotlib.ticker import FormatStrFormatter, LinearLocator\n",
    "from mpl_toolkits.mplot3d import Axes3D\n",
    "\n",
    "%matplotlib inline\n",
    "def do_3d_plot(data, edges, x_axis, y_axis):\n",
    "    fig = plt.figure(figsize=(10,10))\n",
    "    ax = fig.gca(projection='3d')\n",
    "\n",
    "    # Make data.\n",
    "    X = edges[0][:-1]\n",
    "    Y = edges[1][:-1]\n",
    "    X, Y = np.meshgrid(X, Y)\n",
    "\n",
    "    Z = data.T\n",
    "\n",
    "    # Plot the surface.\n",
    "    surf = ax.plot_surface(X, Y, Z, cmap=cm.coolwarm,\n",
    "                           linewidth=0, antialiased=False)\n",
    "    ax.set_xlabel(x_axis)\n",
    "    ax.set_ylabel(y_axis)\n",
    "    ax.set_zlabel(\"Counts\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-02-18T17:28:53.690522Z",
     "start_time": "2019-02-18T17:28:52.860143Z"
    }
   },
   "outputs": [],
   "source": [
    "def do_2d_plot(data, edges, y_axis, x_axis):\n",
    "    from matplotlib.colors import LogNorm\n",
    "    fig = plt.figure(figsize=(10,10))\n",
    "    ax = fig.add_subplot(111)\n",
    "    extent = [np.min(edges[1]), np.max(edges[1]),np.min(edges[0]), np.max(edges[0])]\n",
    "    im = ax.imshow(data[::-1,:], extent=extent, aspect=\"auto\", norm=LogNorm(vmin=1, vmax=np.max(data)))\n",
    "    ax.set_xlabel(x_axis)\n",
    "    ax.set_ylabel(y_axis)\n",
    "    cb = fig.colorbar(im)\n",
    "    cb.set_label(\"Counts\")\n",
    "    \n",
    "    "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Mean Intensity per Pulse ##\n",
    "\n",
    "The following plots show the mean signal for each pulse in a detailed and expanded intensity region."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-02-18T17:28:57.327702Z",
     "start_time": "2019-02-18T17:28:54.377061Z"
    },
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "do_3d_plot(hists_signal_low, [low_edges, pulse_edges], \"Signal (ADU)\", \"Pulse id\")\n",
    "do_2d_plot(hists_signal_low, [low_edges, pulse_edges], \"Signal (ADU)\", \"Pulse id\")\n",
    "do_3d_plot(hists_signal_high, [high_edges, pulse_edges], \"Signal (ADU)\", \"Pulse id\")\n",
    "do_2d_plot(hists_signal_high, [high_edges, pulse_edges], \"Signal (ADU)\", \"Pulse id\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-02-18T17:29:20.634480Z",
     "start_time": "2019-02-18T17:28:57.329231Z"
    }
   },
   "outputs": [],
   "source": [
    "corrected = []\n",
    "raw = []\n",
    "mask = []\n",
    "pulse_ids = []\n",
    "train_ids = [] \n",
    "for channel, ff in first_files.items():\n",
    "    try:\n",
    "        raw_file, corr_file = ff\n",
    "        data_path = h5path.format(channel)\n",
    "        index_path = h5path_idx.format(channel)\n",
    "        try:\n",
    "            infile = h5py.File(raw_file, \"r\")\n",
    "            first_idx = int(np.array(infile[f\"{index_path}/first\"])[0])\n",
    "            \n",
    "            raw_d = np.array(infile[f\"{data_path}/data\"])\n",
    "            # Use first 128 images for plotting\n",
    "            if raw_d.shape[0] >= 128:\n",
    "                # random number for plotting\n",
    "                plt_im = 128 \n",
    "            else:\n",
    "                plt_im = d.shape[0]\n",
    "            last_idx = first_idx + plt_im\n",
    "            raw.append((channel,raw_d[first_idx:last_idx,0,...]))\n",
    "        finally:\n",
    "            infile.close()\n",
    "        \n",
    "        infile = h5py.File(corr_file, \"r\")\n",
    "        try:\n",
    "            corrected.append((channel, np.array(infile[f\"{data_path}/data\"][first_idx:last_idx,...])))\n",
    "            mask.append((channel, np.array(infile[f\"{data_path}/mask\"][first_idx:last_idx,...])))\n",
    "            pulse_ids.append((channel, np.squeeze(infile[f\"{data_path}/pulseId\"][first_idx:last_idx,...])))\n",
    "            train_ids.append((channel, np.squeeze(infile[f\"{data_path}/trainId\"][first_idx:last_idx,...])))\n",
    "        finally:\n",
    "            infile.close()\n",
    "        \n",
    "    except Exception as e:\n",
    "        print(e)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def combine_stack(d, sdim):\n",
    "    combined = np.zeros((sdim, 1300,1300), np.float32)\n",
    "    combined[...] = 0\n",
    "    \n",
    "    dy = 0\n",
    "    quad_pos = [\n",
    "        (0, 145),\n",
    "        (130, 140),\n",
    "        (125, 15),\n",
    "        (0, 15),\n",
    "        \n",
    "    ]\n",
    "    \n",
    "    px = 0.236\n",
    "    py = 0.204\n",
    "    with h5py.File(geo_file, \"r\") as gf:\n",
    "        # TODO: refactor to -> for ch, f in d:\n",
    "        for i in range(len(d)):\n",
    "            \n",
    "            ch = d[i][0]\n",
    "          \n",
    "            mi = 3-(ch%4)\n",
    "            mp = gf[\"Q{}/M{}/Position\".format(ch//4+1, mi%4+1)][()]\n",
    "            t1 = gf[\"Q{}/M{}/T01/Position\".format(ch//4+1, ch%4+1)][()]\n",
    "            t2 = gf[\"Q{}/M{}/T02/Position\".format(ch//4+1, ch%4+1)][()]\n",
    "            if ch//4 < 2:\n",
    "                t1, t2 = t2, t1\n",
    "            \n",
    "            if ch // 4 == 0 or ch // 4 == 1:\n",
    "                td = d[i][1][:,::-1,:]\n",
    "            else:\n",
    "                td = d[i][1][:,:,::-1]\n",
    "            \n",
    "            t1d = td[:,:,:256]\n",
    "            t2d = td[:,:,256:]\n",
    "            \n",
    "            x0t1 = int((t1[0]+mp[0])/px)\n",
    "            y0t1 = int((t1[1]+mp[1])/py)\n",
    "            x0t2 = int((t2[0]+mp[0])/px)\n",
    "            y0t2 = int((t2[1]+mp[1])/py)\n",
    "            \n",
    "            x0t1 += int(quad_pos[i//4][1]/px)\n",
    "            x0t2 += int(quad_pos[i//4][1]/px)\n",
    "            y0t1 += int(quad_pos[i//4][0]/py)+combined.shape[1]//16\n",
    "            y0t2 += int(quad_pos[i//4][0]/py)+combined.shape[1]//16\n",
    "            combined[:,y0t1:y0t1+128,x0t1:x0t1+256] = t1d\n",
    "            combined[:,y0t2:y0t2+128,x0t2:x0t2+256] = t2d\n",
    "\n",
    "    return combined"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-02-18T17:29:27.025667Z",
     "start_time": "2019-02-18T17:29:20.642029Z"
    }
   },
   "outputs": [],
   "source": [
    "combined = combine_stack(corrected, last_idx-first_idx)\n",
    "combined_raw = combine_stack(raw, last_idx-first_idx)\n",
    "combined_mask = combine_stack(mask, last_idx-first_idx)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Mean RAW Preview ###\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "display(Markdown(\"The per pixel mean of the first {} images of the RAW data\".format(plt_im)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-02-18T17:29:33.226396Z",
     "start_time": "2019-02-18T17:29:27.027758Z"
    }
   },
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "fig = plt.figure(figsize=(20,10))\n",
    "ax = fig.add_subplot(111)\n",
    "im = ax.imshow(np.mean(combined_raw[:,...],axis=0),\n",
    "               vmin=min(0.75*np.median(combined_raw[combined_raw > 0]), -5),\n",
    "               vmax=max(1.5*np.median(combined_raw[combined_raw > 0]), 50), cmap=\"jet\")\n",
    "cb = fig.colorbar(im, ax=ax)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Single Shot Preview ###\n",
    "\n",
    "A single shot image from cell 2 of the first train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-02-18T17:29:33.761015Z",
     "start_time": "2019-02-18T17:29:33.227922Z"
    }
   },
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize=(20,10))\n",
    "ax = fig.add_subplot(111)\n",
    "dim = combined[2,...]\n",
    "\n",
    "im = ax.imshow(dim, vmin=-0, vmax=max(1.5*np.median(dim[dim > 0]), 50), cmap=\"jet\", interpolation=\"nearest\")\n",
    "cb = fig.colorbar(im, ax=ax)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-02-18T17:29:35.903487Z",
     "start_time": "2019-02-18T17:29:33.762568Z"
    }
   },
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize=(20,10))\n",
    "ax = fig.add_subplot(111)\n",
    "h = ax.hist(dim.flatten(), bins=100, range=(0, 100))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Mean CORRECTED Preview ###"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "display(Markdown(\"The per pixel mean of the first {} images of the CORRECTED data\".format(plt_im)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-02-18T17:29:39.369686Z",
     "start_time": "2019-02-18T17:29:35.905152Z"
    }
   },
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize=(20,10))\n",
    "ax = fig.add_subplot(111)\n",
    "im = ax.imshow(np.mean(combined[:,...], axis=0), vmin=0,\n",
    "               vmax=max(1.5*np.median(combined[combined > 0]), 10), cmap=\"jet\", interpolation=\"nearest\")\n",
    "cb = fig.colorbar(im, ax=ax)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Max CORRECTED Preview ###\n",
    "\n",
    "The per pixel maximum of the first 128 images of the CORRECTED data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize=(20,10))\n",
    "ax = fig.add_subplot(111)\n",
    "im = ax.imshow(np.max(combined[:,...], axis=0), vmin=0,\n",
    "               vmax=max(100*np.median(combined[combined > 0]), 20), cmap=\"jet\", interpolation=\"nearest\")\n",
    "cb = fig.colorbar(im, ax=ax)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-02-18T17:29:49.217848Z",
     "start_time": "2019-02-18T17:29:39.371232Z"
    }
   },
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize=(20,10))\n",
    "ax = fig.add_subplot(111)\n",
    "combined[combined <= 0] = 0\n",
    "h = ax.hist(combined.flatten(), bins=100, range=(-5, 100), log=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "## Bad Pixels ##\n",
    "The mask contains dedicated entries for all pixels and memory cells as well as all three gains stages. Each mask entry is encoded in 32 bits as:"
   ]
  },
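  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because each bad-pixel type is a single bit, a mask entry can be tested for a given type with a bitwise AND. A short sketch on a hypothetical mask value combining two flags:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from cal_tools.enums import BadPixels\n",
    "\n",
    "# Hypothetical mask entry, for illustration only.\n",
    "entry = BadPixels.NOISE_OUT_OF_THRESHOLD.value | BadPixels.NOISY_ADC.value\n",
    "for flag in (BadPixels.NOISE_OUT_OF_THRESHOLD, BadPixels.NOISY_ADC,\n",
    "             BadPixels.ASIC_STD_BELOW_NOISE):\n",
    "    print(f\"{flag.name}: {bool(entry & flag.value)}\")"
   ]
  },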
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-02-18T17:29:49.651913Z",
     "start_time": "2019-02-18T17:29:49.643556Z"
    }
   },
   "outputs": [],
   "source": [
    "import tabulate\n",
    "from cal_tools.enums import BadPixels\n",
    "from IPython.display import HTML, Latex, Markdown, display\n",
    "\n",
    "table = []\n",
    "for item in BadPixels:\n",
    "    table.append((item.name, \"{:016b}\".format(item.value)))\n",
    "md = display(Latex(tabulate.tabulate(table, tablefmt='latex', headers=[\"Bad pixel type\", \"Bit mask\"])))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "### Full Train Bad Pixels ###"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-02-18T17:29:51.686562Z",
     "start_time": "2019-02-18T17:29:50.088883Z"
    }
   },
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize=(20,10))\n",
    "ax = fig.add_subplot(111)\n",
    "im = ax.imshow(np.log2(np.max(combined_mask[:,...], axis=0)), vmin=0,\n",
    "               vmax=32, cmap=\"jet\")\n",
    "cb = fig.colorbar(im, ax=ax)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Full Train Bad Pixels - Only Dark Char. Related ###"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-02-18T17:29:53.662423Z",
     "start_time": "2019-02-18T17:29:51.688376Z"
    }
   },
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize=(20,10))\n",
    "ax = fig.add_subplot(111)\n",
    "im = ax.imshow(np.max((combined_mask.astype(np.uint32)[:,...] & BadPixels.NOISY_ADC.value) != 0, axis=0), vmin=0,\n",
    "               vmax=1, cmap=\"jet\")\n",
    "cb = fig.colorbar(im, ax=ax)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}