{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# AGIPD Offline Correction #\n",
    "\n",
    "Author: European XFEL Detector Group, Version: 2.0\n",
    "\n",
    "Offline Calibration for the AGIPD Detector"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-02-21T11:30:06.730220Z",
     "start_time": "2019-02-21T11:30:06.658286Z"
    }
   },
   "outputs": [],
   "source": [
    "in_folder = \"/gpfs/exfel/exp/HED/202031/p900174/raw\" # the folder to read data from, required\n",
    "out_folder = \"/gpfs/exfel/data/scratch/ahmedk/test/hibef_agipd2\"  # the folder to output to, required\n",
    "sequences = [-1] # sequences to correct, set to -1 for all, range allowed\n",
    "modules = [-1] # modules to correct, set to -1 for all, range allowed\n",
    "run = 155 # run number to process, required\n",
    "\n",
    "karabo_id = \"HED_DET_AGIPD500K2G\" # karabo id of the detector\n",
    "karabo_da = ['-1']  # a list of data aggregator names, default ['-1'] selects all data aggregators\n",
    "receiver_id = \"{}CH0\" # receiver device name template, inserted into h5path and h5path_idx\n",
    "path_template = 'RAW-R{:04d}-{}-S{:05d}.h5' # the template to use to access data\n",
    "h5path = 'INSTRUMENT/{}/DET/{}:xtdf/' # path in the HDF5 file to images\n",
    "h5path_idx = 'INDEX/{}/DET/{}:xtdf/' # path in the HDF5 file to the image index\n",
    "h5path_ctrl = '/CONTROL/{}/MDL/FPGA_COMP' # path to control information\n",
    "karabo_id_control = \"HED_EXP_AGIPD500K2G\" # karabo-id for control device\n",
    "karabo_da_control = 'AGIPD500K2G00' # karabo DA for control information\n",
    "\n",
    "slopes_ff_from_files = \"\" # Path to locally stored SlopesFF and BadPixelsFF constants\n",
    "\n",
    "use_dir_creation_date = True # use the creation date of the input dir for database queries\n",
    "cal_db_interface = \"tcp://max-exfl016:8015#8045\" # the database interface to use\n",
    "cal_db_timeout = 30000 # in milliseconds\n",
    "creation_date_offset = \"00:00:00\" # add an offset to creation date, e.g. to get different constants\n",
    "\n",
    "max_cells = 0 # number of memory cells used, set to 0 to automatically infer\n",
    "bias_voltage = 300 # Bias voltage\n",
    "acq_rate = 0. # the detector acquisition rate, use 0 to try to auto-determine\n",
    "gain_setting = 0.1 # the gain setting, use 0.1 to try to auto-determine\n",
    "photon_energy = 9.2 # photon energy in keV\n",
    "overwrite = True # set to True if existing data should be overwritten\n",
    "max_pulses = [0, 500, 1] # range list [st, end, step] of maximum pulse indices within a train. At most 3 list elements are allowed.\n",
    "mem_cells_db = 0 # set to a value different than 0 to use this value for DB queries\n",
    "cell_id_preview = 1 # cell Id used for preview in single-shot plots\n",
    "\n",
    "# Correction parameters\n",
    "blc_noise_threshold = 5000 # above this mean signal intensity no baseline correction via noise is attempted\n",
    "cm_dark_fraction = 0.66 # threshold for the fraction of empty pixels to consider a module dark enough to perform CM correction\n",
    "cm_dark_range = [-50.,30] # signal range (ADU) for a pixel to be considered a dark pixel\n",
    "cm_n_itr = 4 # number of iterations for common mode correction\n",
    "hg_hard_threshold = 1000 # threshold to force medium gain offset subtracted pixel to high gain\n",
    "mg_hard_threshold = 1000 # threshold to force medium gain offset subtracted pixel from low to medium gain\n",
    "noisy_adc_threshold = 0.25 # threshold to mask a complete ADC\n",
    "ff_gain = 7.2 # conversion gain for absolute FlatField constants, while applying xray_gain\n",
    "\n",
    "# Correction Booleans\n",
    "only_offset = False # Apply only offset correction. If False, offset is applied together with the other selected corrections; if True, only offset is applied.\n",
    "rel_gain = False # do relative gain correction based on PC data\n",
    "xray_gain = False # do relative gain correction based on xray data\n",
    "blc_noise = False # if set, baseline correction via noise peak location is attempted\n",
    "blc_stripes = False # if set, baseline corrected via stripes\n",
    "blc_hmatch = False # if set, base line correction via histogram matching is attempted\n",
    "match_asics = False # if set, inner ASIC borders are matched to the same signal level\n",
    "adjust_mg_baseline = False # adjust medium gain baseline to match highest high gain value\n",
    "zero_nans = False # set NaN values in corrected data to 0\n",
    "zero_orange = False # set to 0 very negative and very large values in corrected data\n",
    "blc_set_min = False # Shift to 0 negative medium gain pixels after offset corr\n",
    "corr_asic_diag = False # if set, diagonal drop offs on ASICs are corrected\n",
    "force_hg_if_below = False # set high gain if mg offset subtracted value is below hg_hard_threshold\n",
    "force_mg_if_below = False # set medium gain if mg offset subtracted value is below mg_hard_threshold\n",
    "mask_noisy_adc = False # Mask entire ADC if they are noisy above a relative threshold\n",
    "common_mode = False # Common mode correction\n",
    "melt_snow = False # Identify (and optionally interpolate) 'snowy' pixels\n",
    "mask_zero_std = False # Mask pixels with zero standard deviation across train\n",
    "low_medium_gap = False # 5 sigma separation in thresholding between low and medium gain\n",
    "\n",
    "# Parallelization parameters\n",
    "chunk_size = 1000 # Size of chunk for image-wise correction\n",
    "chunk_size_idim = 1  # chunking size of imaging dimension, adjust if user software is sensitive to this.\n",
    "n_cores_correct = 16 # Number of chunks to be processed in parallel\n",
    "n_cores_files = 4 # Number of files to be processed in parallel\n",
    "sequences_per_node = 2 # number of sequence files per cluster node if run as slurm job, set to 0 to not run SLURM parallel\n",
    "\n",
    "def balance_sequences(in_folder, run, sequences, sequences_per_node, karabo_da):\n",
    "    \"\"\"Thin wrapper delegating to xfel_calibrate.calibrate.balance_sequences.\"\"\"\n",
    "    from xfel_calibrate.calibrate import balance_sequences as bs\n",
    "    return bs(in_folder, run, sequences, sequences_per_node, karabo_da)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "from datetime import timedelta\n",
    "from dateutil import parser\n",
    "import gc\n",
    "import glob\n",
    "import itertools\n",
    "from IPython.display import HTML, display, Markdown, Latex\n",
    "import math\n",
    "from multiprocessing import Pool\n",
    "import os\n",
    "import re\n",
    "import sys\n",
    "import traceback\n",
    "from time import time, sleep, perf_counter\n",
    "import tabulate\n",
    "import warnings\n",
    "warnings.filterwarnings('ignore')\n",
    "import yaml\n",
    "\n",
    "from extra_geom import AGIPD_1MGeometry, AGIPD_500K2GGeometry\n",
    "from extra_data import RunDirectory, stack_detector_data\n",
    "from iCalibrationDB import Detectors\n",
    "from mpl_toolkits.mplot3d import Axes3D\n",
    "from matplotlib.ticker import LinearLocator, FormatStrFormatter\n",
    "from matplotlib.colors import LogNorm\n",
    "from matplotlib import cm as colormap\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib\n",
    "matplotlib.use(\"agg\")\n",
    "%matplotlib inline\n",
    "import numpy as np\n",
    "import seaborn as sns\n",
    "# Plot style for all figures in this notebook (configured once)\n",
    "sns.set()\n",
    "sns.set_context(\"paper\", font_scale=1.4)\n",
    "sns.set_style(\"ticks\")\n",
    "\n",
    "from cal_tools.agipdlib import (AgipdCorrections, get_acq_rate,\n",
    "                                get_gain_setting, get_num_cells)\n",
    "from cal_tools.cython import agipdalgs as calgs\n",
    "from cal_tools.ana_tools import get_range\n",
    "from cal_tools.enums import BadPixels\n",
    "from cal_tools.tools import get_dir_creation_date, map_modules_from_folder\n",
    "from cal_tools.step_timing import StepTimer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluated parameters ##"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect the correction switches into one dictionary handed to AgipdCorrections.\n",
    "# Offset sits at the bottom of the AGIPD correction pyramid; every other switch\n",
    "# is only registered when more than the offset correction is requested.\n",
    "corr_bools = {\"only_offset\": only_offset}\n",
    "\n",
    "if not only_offset:\n",
    "    corr_bools.update(\n",
    "        adjust_mg_baseline=adjust_mg_baseline,\n",
    "        rel_gain=rel_gain,\n",
    "        xray_corr=xray_gain,  # note: key name differs from the notebook parameter\n",
    "        blc_noise=blc_noise,\n",
    "        blc_stripes=blc_stripes,\n",
    "        blc_hmatch=blc_hmatch,\n",
    "        blc_set_min=blc_set_min,\n",
    "        match_asics=match_asics,\n",
    "        corr_asic_diag=corr_asic_diag,\n",
    "        zero_nans=zero_nans,\n",
    "        zero_orange=zero_orange,\n",
    "        mask_noisy_adc=mask_noisy_adc,\n",
    "        force_hg_if_below=force_hg_if_below,\n",
    "        force_mg_if_below=force_mg_if_below,\n",
    "        common_mode=common_mode,\n",
    "        melt_snow=melt_snow,\n",
    "        mask_zero_std=mask_zero_std,\n",
    "        low_medium_gap=low_medium_gap,\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Normalise inputs: strip trailing slash(es) from in_folder (also safe for an\n",
    "# empty string); [-1] selects all sequences, encoded as None. The `sequences and`\n",
    "# guard keeps this cell re-runnable once sequences has become None.\n",
    "in_folder = in_folder.rstrip(\"/\")\n",
    "if sequences and sequences[0] == -1:\n",
    "    sequences = None\n",
    "\n",
    "# Control file: first sequence file of the control aggregator, named with the\n",
    "# same path_template as the raw data files\n",
    "control_fname = f'{in_folder}/r{run:04d}/' + path_template.format(run, karabo_da_control, 0)\n",
    "h5path_ctrl = h5path_ctrl.format(karabo_id_control)\n",
    "h5path = h5path.format(karabo_id, receiver_id)\n",
    "h5path_idx = h5path_idx.format(karabo_id, receiver_id)\n",
    "\n",
    "print(f'Path to control file {control_fname}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-02-21T11:30:07.086286Z",
     "start_time": "2019-02-21T11:30:06.929722Z"
    }
   },
   "outputs": [],
   "source": [
    "# Create output folder\n",
    "os.makedirs(out_folder, exist_ok=overwrite)\n",
    "\n",
    "# Evaluate detector instance for mapping: instrument -> (detector instance, number of modules)\n",
    "instrument = karabo_id.split(\"_\")[0]\n",
    "detector_instances = {\n",
    "    \"SPB\": (\"AGIPD1M1\", 16),\n",
    "    \"MID\": (\"AGIPD1M2\", 16),\n",
    "    \"HED\": (\"AGIPD500K\", 8),\n",
    "}\n",
    "if instrument not in detector_instances:\n",
    "    # Fail early: otherwise dinstance/nmods would be undefined (or stale from a previous run)\n",
    "    raise ValueError(f\"Unknown instrument {instrument!r} derived from karabo_id {karabo_id!r}, \"\n",
    "                     f\"expected one of {sorted(detector_instances)}\")\n",
    "dinstance, nmods = detector_instances[instrument]\n",
    "\n",
    "# Evaluate requested modules\n",
    "if karabo_da[0] == '-1':\n",
    "    if modules[0] == -1:\n",
    "        modules = list(range(nmods))\n",
    "    karabo_da = [\"AGIPD{:02d}\".format(i) for i in modules]\n",
    "else:\n",
    "    modules = [int(x[-2:]) for x in karabo_da]\n",
    "\n",
    "def mod_name(modno):\n",
    "    \"\"\"Return the quadrant/module name of a module index, e.g. 0 -> 'Q1M1', 5 -> 'Q2M2'.\"\"\"\n",
    "    return f\"Q{modno // 4 + 1}M{modno % 4 + 1}\"\n",
    "\n",
    "print(\"Process modules: \", ', '.join(\n",
    "    [mod_name(x) for x in modules]))\n",
    "print(f\"Detector in use is {karabo_id}\")\n",
    "print(f\"Instrument {instrument}\")\n",
    "print(f\"Detector instance {dinstance}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display information about the selected pulse indices for correction.\n",
    "# max_pulses is either [0] (kept as-is) or the arguments of range(): [start, stop, step].\n",
    "try:\n",
    "    # Build the list inside the try so malformed input (e.g. more than 3 elements)\n",
    "    # is reported as a max_pulses error as well\n",
    "    if len(max_pulses) == 1 and max_pulses[0] == 0:\n",
    "        pulses_lst = max_pulses\n",
    "    else:\n",
    "        pulses_lst = list(range(*max_pulses))\n",
    "\n",
    "    if len(pulses_lst) > 1:\n",
    "        pulse_step = pulses_lst[1] - pulses_lst[0]\n",
    "        print(\"A range of {} pulse indices is selected: from {} to {} with a step of {}\"\n",
    "              .format(len(pulses_lst), pulses_lst[0], pulses_lst[-1] + pulse_step, pulse_step))\n",
    "    else:\n",
    "        # An empty selection raises IndexError here and is reported below\n",
    "        print(\"one pulse is selected: a pulse of idx {}\".format(pulses_lst[0]))\n",
    "except Exception as e:\n",
    "    raise ValueError('max_pulses input Error: {}'.format(e)) from e"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Map the raw sequence files of every requested module\n",
    "mmf = map_modules_from_folder(in_folder, run, path_template,\n",
    "                              karabo_da, sequences)\n",
    "mapped_files, mod_ids, total_sequences, sequences_qm, _ = mmf\n",
    "\n",
    "# ToDo: Split table over pages\n",
    "print(f\"Processing a total of {total_sequences} sequence files in chunks of {n_cores_files}\")\n",
    "file_list = []\n",
    "table = []\n",
    "for k, files in mapped_files.items():\n",
    "    for i, f in enumerate(list(files.queue)):\n",
    "        file_list.append(f)\n",
    "        # running row number, module name only on its first row, index within module, file\n",
    "        table.append((len(table), k if i == 0 else \"\", i, f))\n",
    "md = display(Latex(tabulate.tabulate(table, tablefmt='latex',\n",
    "                                     headers=[\"#\", \"module\", \"# module\", \"file\"])))\n",
    "# Order files by their trailing sequence part so all modules of a sequence are adjacent\n",
    "file_list = sorted(file_list, key=lambda name: name[-10:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if not file_list:\n",
    "    raise ValueError(f\"No sequence files found for run {run} in {in_folder}\")\n",
    "filename = file_list[0]\n",
    "# Channel (module) number is taken from the '-AGIPDnn-' part of the file name\n",
    "channel_ids = re.findall(r\".*-AGIPD([0-9]+)-.*\", filename)\n",
    "if not channel_ids:\n",
    "    raise ValueError(f\"Cannot determine the AGIPD channel from file name {filename}\")\n",
    "channel = int(channel_ids[0])\n",
    "\n",
    "# Evaluate number of memory cells\n",
    "mem_cells = get_num_cells(filename, karabo_id, channel)\n",
    "if mem_cells is None:\n",
    "    raise ValueError(f\"No raw images found in {filename}\")\n",
    "\n",
    "# A value of 0 means: use the number of memory cells found in the data\n",
    "mem_cells_db = mem_cells if mem_cells_db == 0 else mem_cells_db\n",
    "max_cells = mem_cells if max_cells == 0 else max_cells\n",
    "\n",
    "# Evaluate acquisition rate (0 means auto-determine)\n",
    "if acq_rate == 0:\n",
    "    # NOTE(review): the (filename, karabo_id, channel) triple is passed as one tuple\n",
    "    # argument, presumably matching get_acq_rate's signature - confirm\n",
    "    acq_rate = get_acq_rate((filename, karabo_id, channel))\n",
    "\n",
    "print(f\"Maximum memory cells to calibrate: {max_cells}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate creation time\n",
    "creation_time = None\n",
    "if use_dir_creation_date:\n",
    "    creation_time = get_dir_creation_date(in_folder, run)\n",
    "    offset = parser.parse(creation_date_offset)\n",
    "    delta = timedelta(hours=offset.hour,\n",
    "                      minutes=offset.minute, seconds=offset.second)\n",
    "    creation_time += delta\n",
    "\n",
    "# Evaluate gain setting; 0.1 is the sentinel meaning \"auto-determine\"\n",
    "if gain_setting == 0.1:\n",
    "    # creation_time stays None when use_dir_creation_date is False. The run date is\n",
    "    # then unknown, so try reading the gain setting instead of crashing on None.\n",
    "    if creation_time is not None and creation_time.replace(tzinfo=None) < parser.parse('2020-01-31'):\n",
    "        print(\"Set gain-setting to None for runs taken before 2020-01-31\")\n",
    "        gain_setting = None\n",
    "    else:\n",
    "        try:\n",
    "            gain_setting = get_gain_setting(control_fname, h5path_ctrl)\n",
    "        except Exception as e:\n",
    "            print(f'ERROR: while reading gain setting from: \\n{control_fname}')\n",
    "            print(e)\n",
    "            print(\"Set gain setting to 0\")\n",
    "            gain_setting = 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(f\"Using {creation_time} as creation time\")\n",
    "# (label, value) pairs printed as one bullet line each\n",
    "operating_conditions = (\n",
    "    (\"Bias voltage\", bias_voltage),\n",
    "    (\"Memory cells\", mem_cells_db),\n",
    "    (\"Acquisition rate\", acq_rate),\n",
    "    (\"Gain setting\", gain_setting),\n",
    "    (\"Photon Energy\", photon_energy),\n",
    ")\n",
    "print(\"Operating conditions are:\\n\" +\n",
    "      \"\".join(f\"• {name}: {value}\\n\" for name, value in operating_conditions))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data processing ##"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "agipd_corr = AgipdCorrections(max_cells, max_pulses,\n",
    "                              h5_data_path=h5path,\n",
    "                              h5_index_path=h5path_idx,\n",
    "                              corr_bools=corr_bools)\n",
    "\n",
    "# Forward the tunable correction parameters to the corrections object.\n",
    "# Note the noise threshold is stored with a negated sign.\n",
    "correction_settings = {\n",
    "    \"baseline_corr_noise_threshold\": -blc_noise_threshold,\n",
    "    \"hg_hard_threshold\": hg_hard_threshold,\n",
    "    \"mg_hard_threshold\": mg_hard_threshold,\n",
    "    \"cm_dark_min\": cm_dark_range[0],\n",
    "    \"cm_dark_max\": cm_dark_range[1],\n",
    "    \"cm_dark_fraction\": cm_dark_fraction,\n",
    "    \"cm_n_itr\": cm_n_itr,\n",
    "    \"noisy_adc_threshold\": noisy_adc_threshold,\n",
    "    \"ff_gain\": ff_gain,\n",
    "}\n",
    "for attr_name, attr_value in correction_settings.items():\n",
    "    setattr(agipd_corr, attr_name, attr_value)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Retrieve calibration constants to RAM\n",
    "agipd_corr.allocate_constants(modules, (3, mem_cells_db, 512, 128))\n",
    "\n",
    "# A leading notebook may already have dumped the constants' metadata to out_folder\n",
    "const_yaml = None\n",
    "const_yaml_path = f'{out_folder}/retrieved_constants.yml'\n",
    "if os.path.isfile(const_yaml_path):\n",
    "    with open(const_yaml_path, \"r\") as f:\n",
    "        const_yaml = yaml.safe_load(f.read())\n",
    "\n",
    "def retrieve_constants(mod):\n",
    "    \"\"\"Load the calibration constants of one module into agipd_corr.\n",
    "\n",
    "    The constants' metadata is read from const_yaml when it lists this module's\n",
    "    device; otherwise the calibration database is queried.\n",
    "\n",
    "    :param mod: module index\n",
    "    :return: (error message or '', mod, value returned by the initializer or\n",
    "              None on failure, device name)\n",
    "    \"\"\"\n",
    "    device = getattr(getattr(Detectors, dinstance), mod_name(mod))\n",
    "    try:\n",
    "        if const_yaml and device.device_name in const_yaml:\n",
    "            when = agipd_corr.initialize_from_yaml(const_yaml, mod, device)\n",
    "        else:\n",
    "            when = agipd_corr.initialize_from_db(cal_db_interface, creation_time, mem_cells_db, bias_voltage,\n",
    "                                                 photon_energy, gain_setting, acq_rate, mod, device, False)\n",
    "    except Exception as e:\n",
    "        # Report instead of raising so the other modules' results are still collected\n",
    "        return f\"Error: {e}\\nError traceback: {traceback.format_exc()}\", mod, None, device.device_name\n",
    "    return '', mod, when, device.device_name\n",
    "\n",
    "\n",
    "ts = perf_counter()\n",
    "with Pool(processes=len(modules)) as pool:\n",
    "    const_out = pool.map(retrieve_constants, modules)\n",
    "print(f\"Constants were loaded in {perf_counter()-ts:.01f}s\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Allocate memory for images and hists, sized for n_cores_files files at once.\n",
    "# NOTE(review): 256 presumably bounds the number of trains per sequence file - confirm.\n",
    "data_shape = (max_cells * 256, 512, 128)\n",
    "n_images_max = data_shape[0]\n",
    "agipd_corr.allocate_images(data_shape, n_cores_files)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def batches(l, batch_size):\n",
    "    \"\"\"Group a list into batches of (up to) batch_size elements\"\"\"\n",
    "    start = 0\n",
    "    while start < len(l):\n",
    "        yield l[start:start + batch_size]\n",
    "        start += batch_size"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def imagewise_chunks(img_counts):\n",
    "    \"\"\"Split each file slot's images into near-equal chunks of at most chunk_size.\n",
    "\n",
    "    :param img_counts: number of images loaded in each file data slot\n",
    "    Yields (file data slot, start index, stop index) tuples, stop exclusive.\n",
    "    \"\"\"\n",
    "    for slot, n_images in enumerate(img_counts):\n",
    "        n_chunks = math.ceil(n_images / chunk_size)\n",
    "        yield from (\n",
    "            (slot, k * n_images // n_chunks, (k + 1) * n_images // n_chunks)\n",
    "            for k in range(n_chunks)\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Records the duration of each processing step in the correction loop;\n",
    "# the total time and a per-step summary are printed after the loop\n",
    "step_timer = StepTimer()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Correct the files in batches of n_cores_files; every step below is parallelised\n",
    "# over the same worker pool and timed with step_timer.\n",
    "with Pool() as pool:\n",
    "    for file_batch in batches(file_list, n_cores_files):\n",
    "        # TODO: Move some printed output to logging or similar\n",
    "        print(f\"Processing next {len(file_batch)} files:\")\n",
    "        for file_name in file_batch:\n",
    "            print(\" \", file_name)\n",
    "        step_timer.start()\n",
    "        \n",
    "        # Each file is read into the data slot given by its position in the batch;\n",
    "        # img_counts (one image count per slot) drives the image-wise chunking below\n",
    "        img_counts = pool.starmap(agipd_corr.read_file, enumerate(file_batch))\n",
    "        step_timer.done_step('Loading data from files')\n",
    "        \n",
    "        # Evaluate zero-data-std mask, one job per (file slot, part of the memory cells)\n",
    "        # NOTE(review): called regardless of the mask_zero_std flag; presumably the\n",
    "        # method checks corr_bools itself - confirm\n",
    "        pool.starmap(agipd_corr.mask_zero_std, itertools.product(\n",
    "            range(len(file_batch)), np.array_split(np.arange(agipd_corr.max_cells), n_cores_correct)\n",
    "        ))\n",
    "        step_timer.done_step('Mask 0 std')\n",
    "\n",
    "        # Perform offset image-wise correction\n",
    "        pool.starmap(agipd_corr.offset_correction, imagewise_chunks(img_counts))\n",
    "        step_timer.done_step(\"Offset correction\")\n",
    "        \n",
    "        \n",
    "        if blc_noise or blc_stripes or blc_hmatch:\n",
    "            # Perform image-wise baseline-shift correction\n",
    "            pool.starmap(agipd_corr.baseline_correction, imagewise_chunks(img_counts))\n",
    "            step_timer.done_step(\"Base-line shift correction\")\n",
    "        \n",
    "        if common_mode:\n",
    "            # Perform common-mode correction, one job per (file slot, ASIC) pair\n",
    "            pool.starmap(agipd_corr.cm_correction, itertools.product(\n",
    "                range(len(file_batch)), range(16)  # 16 ASICs per module\n",
    "            ))\n",
    "            step_timer.done_step(\"Common-mode correction\")\n",
    "        \n",
    "        # Perform image-wise gain correction\n",
    "        pool.starmap(agipd_corr.gain_correction, imagewise_chunks(img_counts))\n",
    "        step_timer.done_step(\"Image-wise correction\")\n",
    "        \n",
    "        # Save corrected data next to each other in out_folder, RAW -> CORR in the file name\n",
    "        pool.starmap(agipd_corr.write_file, [\n",
    "            (i_proc, file_name, os.path.join(out_folder, os.path.basename(file_name).replace(\"RAW\", \"CORR\")))\n",
    "            for i_proc, file_name in enumerate(file_batch)\n",
    "        ])\n",
    "        step_timer.done_step(\"Save\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Overall result and timing, followed by the per-step breakdown\n",
    "summary_header = [\n",
    "    f\"Correction of {len(file_list)} files is finished\",\n",
    "    f\"Total processing time {step_timer.timespan():.01f} s\",\n",
    "    f\"Timing summary per batch of {n_cores_files} files:\",\n",
    "]\n",
    "print(\"\\n\".join(summary_header))\n",
    "step_timer.print_summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# if there is a yml file that means a leading notebook got processed\n",
    "# and the reporting would be generated from it.\n",
    "fst_print = True\n",
    "\n",
    "to_store = []\n",
    "line = []\n",
    "for i, (error, modno, when, mod_dev) in enumerate(const_out):\n",
    "    qm = mod_name(modno)\n",
    "    # expose errors while retrieving constants (the message already starts with \"Error:\")\n",
    "    if error:\n",
    "        print(error)\n",
    "\n",
    "    # Only summarise modules whose constants were not taken from the yml file\n",
    "    if not const_yaml or mod_dev not in const_yaml:\n",
    "        if fst_print:\n",
    "            print(\"Constants are retrieved with creation time: \")\n",
    "            fst_print = False\n",
    "\n",
    "        line = [qm]\n",
    "\n",
    "        # Print the creation time of every constant, unless retrieval crashed\n",
    "        if not error:\n",
    "            print(f\"{qm}:\")\n",
    "            for key, item in when.items():\n",
    "                if hasattr(item, 'strftime'):\n",
    "                    item = item.strftime('%y-%m-%d %H:%M')\n",
    "                when[key] = item\n",
    "                print('{:.<12s}'.format(key), item)\n",
    "\n",
    "        # Store a few time stamps if they exist; add 'Err' (retrieval failed)\n",
    "        # or 'NA' (constant not available) to keep the array structure.\n",
    "        # error is '' on success, so it is tested for truthiness, not `is not None`.\n",
    "        for key in ['Offset', 'SlopesPC', 'SlopesFF']:\n",
    "            if when and key in when and when[key]:\n",
    "                line.append(when[key])\n",
    "            elif error:\n",
    "                line.append('Err')\n",
    "            else:\n",
    "                line.append('NA')\n",
    "\n",
    "        to_store.append(line)\n",
    "\n",
    "seq = sequences[0] if sequences else 0\n",
    "\n",
    "if len(to_store) > 0:\n",
    "    with open(f\"{out_folder}/retrieved_constants_s{seq}.yml\",\"w\") as fyml:\n",
    "        yaml.safe_dump({\"time-summary\": {f\"S{seq}\":to_store}}, fyml)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-02-18T17:28:51.765030Z",
     "start_time": "2019-02-18T17:28:51.714783Z"
    }
   },
   "outputs": [],
   "source": [
    "def do_3d_plot(data, edges, x_axis, y_axis):\n",
    "    \"\"\"Plot a 2D histogram as a 3D surface.\n",
    "\n",
    "    :param data: histogram counts indexed as [x bin, y bin]\n",
    "    :param edges: (x bin edges, y bin edges); the last edge of each is dropped\n",
    "    :param x_axis: label of the x axis\n",
    "    :param y_axis: label of the y axis\n",
    "    \"\"\"\n",
    "    fig = plt.figure(figsize=(10, 10))\n",
    "    # Figure.gca(projection=...) is deprecated since Matplotlib 3.4 and removed\n",
    "    # in 3.6; add_subplot creates the 3D axes explicitly\n",
    "    ax = fig.add_subplot(111, projection='3d')\n",
    "\n",
    "    # Make data.\n",
    "    X = edges[0][:-1]\n",
    "    Y = edges[1][:-1]\n",
    "    X, Y = np.meshgrid(X, Y)\n",
    "    Z = data.T\n",
    "\n",
    "    # Plot the surface.\n",
    "    ax.plot_surface(X, Y, Z, cmap=colormap.coolwarm,\n",
    "                    linewidth=0, antialiased=False)\n",
    "    ax.set_xlabel(x_axis)\n",
    "    ax.set_ylabel(y_axis)\n",
    "    ax.set_zlabel(\"Counts\")\n",
    "\n",
    "\n",
    "def do_2d_plot(data, edges, y_axis, x_axis):\n",
    "    \"\"\"Plot a 2D histogram as an image with a logarithmic colour scale.\n",
    "\n",
    "    :param data: histogram counts indexed as [edges[0] bin, edges[1] bin]\n",
    "    :param edges: (bin edges shown on the y axis, bin edges shown on the x axis)\n",
    "    :param y_axis: label of the y axis (first histogram variable)\n",
    "    :param x_axis: label of the x axis (second histogram variable)\n",
    "    \"\"\"\n",
    "    fig = plt.figure(figsize=(10, 10))\n",
    "    ax = fig.add_subplot(111)\n",
    "    extent = [np.min(edges[1]), np.max(edges[1]),\n",
    "              np.min(edges[0]), np.max(edges[0])]\n",
    "    # Rows are flipped so the first variable increases upwards\n",
    "    # (imshow draws row 0 at the top by default)\n",
    "    im = ax.imshow(data[::-1, :], extent=extent, aspect=\"auto\",\n",
    "                   norm=LogNorm(vmin=1, vmax=max(10, np.max(data))))\n",
    "    ax.set_xlabel(x_axis)\n",
    "    ax.set_ylabel(y_axis)\n",
    "    cb = fig.colorbar(im)\n",
    "    cb.set_label(\"Counts\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_trains_data(run_folder, source, include, tid=None, path='*/DET/*', modules=16, fillvalue=np.nan):\n",
    "    \"\"\"\n",
    "    Load single train for all modules\n",
    "    \n",
    "    :param run_folder: Path to folder with data\n",
    "    :param source: Data source (key) to be loaded, e.g. 'image.data'\n",
    "    :param include: Glob of file names to be considered\n",
    "    :param tid: Train Id to be loaded. If None, the first train with data from all selected sources is used\n",
    "    :param path: Source glob used to select the detector sources\n",
    "    :param modules: Number of modules passed to stack_detector_data\n",
    "    :param fillvalue: Value stack_detector_data uses for modules missing from the train\n",
    "    :return: (train id, stacked data), or (None, None) if no train is found\n",
    "    \"\"\"\n",
    "    run_data = RunDirectory(run_folder, include)\n",
    "    # Select detector sources matching `path` (previously ignored in favour of '*/DET/*')\n",
    "    selection = run_data.select(path, source)\n",
    "    if tid is not None:\n",
    "        tid, data = selection.train_from_id(tid)\n",
    "        return tid, stack_detector_data(train=data, data=source, fillvalue=fillvalue, modules=modules)\n",
    "    for tid, data in selection.trains(require_all=True):\n",
    "        return tid, stack_detector_data(train=data, data=source, fillvalue=fillvalue, modules=modules)\n",
    "    return None, None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Geometry for the preview plots: AGIPD500K from its origin,\n",
    "# AGIPD1M from fixed quadrant positions\n",
    "agipd1m_quad_pos = [(-525, 625), (-550, -10), (520, -160), (542.5, 475)]\n",
    "geom = (AGIPD_500K2GGeometry.from_origin() if dinstance == \"AGIPD500K\"\n",
    "        else AGIPD_1MGeometry.from_quad_positions(quad_pos=agipd1m_quad_pos))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "include = '*S00000*' if sequences is None else f'*S{sequences[0]:05d}*'\n",
    "tid, corrected = get_trains_data(f'{out_folder}/', 'image.data', include, modules=nmods)\n",
    "if tid is None:\n",
    "    # Fail with a clear message instead of an AttributeError on None in the preview cells\n",
    "    raise ValueError(f\"No train with corrected data found in {out_folder} (files {include})\")\n",
    "_, gains = get_trains_data(f'{out_folder}/', 'image.gain', include, tid, modules=nmods)\n",
    "_, mask = get_trains_data(f'{out_folder}/', 'image.mask', include, tid, modules=nmods)\n",
    "_, blshift = get_trains_data(f'{out_folder}/', 'image.blShift', include, tid, modules=nmods)\n",
    "_, cellId = get_trains_data(f'{out_folder}/', 'image.cellId', include, tid, modules=nmods)\n",
    "_, pulseId = get_trains_data(f'{out_folder}/', 'image.pulseId', include, tid,\n",
    "                             modules=nmods, fillvalue=0)\n",
    "_, raw = get_trains_data(f'{in_folder}/r{run:04d}/', 'image.data', include, tid, modules=nmods)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "preview_heading = f'## Preview and statistics for {gains.shape[0]} images of the train {tid} ##\\n'\n",
    "display(Markdown(preview_heading))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Signal vs. Analogue Gain ###"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Raw data plane 0 vs. plane 1, labelled as signal and analogue gain in the plots\n",
    "raw_signal, raw_gain = (raw[:, plane, ...].flatten().astype(np.float32) for plane in (0, 1))\n",
    "hist, bins_x, bins_y = calgs.histogram2d(raw_signal, raw_gain,\n",
    "                                         bins=(100, 100),\n",
    "                                         range=[[4000, 8192], [4000, 8192]])\n",
    "del raw_signal, raw_gain  # free the flattened copies\n",
    "for plot_hist in (do_2d_plot, do_3d_plot):\n",
    "    plot_hist(hist, (bins_x, bins_y), \"Signal (ADU)\", \"Analogue gain (ADU)\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Signal vs. Digitized Gain ###\n",
    "\n",
    "The following plot shows signal vs. digitized gain."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Corrected signal against the digitized gain bit, one bin per gain value\n",
    "signal_f32 = corrected.flatten().astype(np.float32)\n",
    "gain_f32 = gains.flatten().astype(np.float32)\n",
    "hist, bins_x, bins_y = calgs.histogram2d(signal_f32, gain_f32, bins=(100, 3),\n",
    "                                         range=[[-50, 8192], [0, 3]])\n",
    "del signal_f32, gain_f32  # free the flattened copies\n",
    "do_2d_plot(hist, (bins_x, bins_y), \"Signal (ADU)\", \"Gain bit value\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"Gain statistics in %\")\n",
    "# Share of entries per gain value (0 = High, 1 = Medium, 2 = Low), all with the same precision;\n",
    "# count_nonzero avoids materialising boolean-indexed copies of the gain array\n",
    "table = [[f'{np.count_nonzero(gains == stage) / gains.size * 100:.03f}' for stage in (0, 1, 2)]]\n",
    "md = display(Latex(tabulate.tabulate(table, tablefmt='latex',\n",
    "                                     headers=[\"High\", \"Medium\", \"Low\"])))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Intensity per Pulse ###"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "pulse_range = [np.min(pulseId[pulseId>=0]), np.max(pulseId[pulseId>=0])]\n",
    "# The maximum pulse id is used as the number of pulse-id bins, but at least\n",
    "# one bin so a train whose only pulse id is 0 does not yield zero bins\n",
    "n_pulse_bins = max(1, int(pulse_range[1]))\n",
    "\n",
    "mean_data = np.nanmean(corrected, axis=(2, 3))\n",
    "mean_data_f32 = mean_data.flatten().astype(np.float32)\n",
    "pulse_ids_f32 = pulseId.flatten().astype(np.float32)\n",
    "\n",
    "# The same pair of plots for two signal ranges\n",
    "for signal_range in ([-50, 1000], [-50, 200000]):\n",
    "    hist, bins_x, bins_y = calgs.histogram2d(mean_data_f32, pulse_ids_f32,\n",
    "                                             bins=(100, n_pulse_bins),\n",
    "                                             range=[signal_range, pulse_range])\n",
    "    do_2d_plot(hist, (bins_x, bins_y), \"Signal (ADU)\", \"Pulse id\")\n",
    "    do_3d_plot(hist, (bins_x, bins_y), \"Signal (ADU)\", \"Pulse id\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Baseline shift ###\n",
    "\n",
    "Estimated baseline shift with respect to the total ADU counts of the corrected image."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Distribution of the estimated baseline shift, with log-scaled counts\n",
    "fig, ax = plt.subplots(figsize=(20, 10))\n",
    "h = ax.hist(blshift.flatten(), bins=100, log=True)\n",
    "_ = ax.set_xlabel('Baseline shift [ADU]')\n",
    "_ = ax.set_ylabel('Counts')\n",
    "_ = ax.grid()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Baseline shift vs. the nansum of the corrected signal over axes 2 and 3\n",
    "# (despite its name, corrected_ave is a sum), shown in mega-ADU.\n",
    "fig, ax = plt.subplots(figsize=(10, 10))\n",
    "corrected_ave = np.nansum(corrected, axis=(2, 3))\n",
    "ax.scatter(corrected_ave.flatten() / 10**6, blshift.flatten(), s=0.9)\n",
    "ax.set_xlim(-1, 1000)\n",
    "ax.grid()\n",
    "ax.set_xlabel('Illuminated corrected [MADU] ')\n",
    "_ = ax.set_ylabel('Estimated baseline shift [ADU]')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Section heading and caption for the raw-data previews\n",
    "display(Markdown('### Raw preview ###\\n'),\n",
    "        Markdown('Mean over images of the RAW data\\n'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-02-18T17:29:33.226396Z",
     "start_time": "2019-02-18T17:29:27.027758Z"
    }
   },
   "outputs": [],
   "source": [
    "# Average of raw[:, 0] (the signal component) over axis 0, colour range from get_range\n",
    "fig, ax = plt.subplots(figsize=(20, 10))\n",
    "data = raw[:, 0, ...].mean(axis=0)\n",
    "vmin, vmax = get_range(data, 5)\n",
    "ax = geom.plot_data_fast(data, ax=ax, cmap=\"jet\", vmin=vmin, vmax=vmax)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Single raw shot (signal component) at index cell_id_preview\n",
    "preview_cell = np.max(cellId[cell_id_preview])\n",
    "display(Markdown(f'Single shot of the RAW data from cell {preview_cell} \\n'))\n",
    "raw_shot = raw[cell_id_preview, 0, ...]\n",
    "fig, ax = plt.subplots(figsize=(20, 10))\n",
    "vmin, vmax = get_range(raw_shot, 5)\n",
    "ax = geom.plot_data_fast(raw_shot, ax=ax, cmap=\"jet\", vmin=vmin, vmax=vmax)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Section heading and caption for the corrected single-shot preview\n",
    "display(Markdown('### Corrected preview ###\\n'),\n",
    "        Markdown(f'A single shot image from cell {np.max(cellId[cell_id_preview])} \\n'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-02-18T17:29:33.761015Z",
     "start_time": "2019-02-18T17:29:33.227922Z"
    }
   },
   "outputs": [],
   "source": [
    "# Only the upper colour limit comes from get_range; the lower one is pinned to -50 ADU\n",
    "fig, ax = plt.subplots(figsize=(20, 10))\n",
    "_, vmax = get_range(corrected[cell_id_preview], 7, -50)\n",
    "vmin = -50\n",
    "ax = geom.plot_data_fast(corrected[cell_id_preview], ax=ax, cmap=\"jet\", vmin=vmin, vmax=vmax)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-02-18T17:29:35.903487Z",
     "start_time": "2019-02-18T17:29:33.762568Z"
    }
   },
   "outputs": [],
   "source": [
    "# Histogram of the corrected single shot: lower edge fixed at -50 ADU, ~2 ADU wide bins\n",
    "fig, ax = plt.subplots(figsize=(20, 10))\n",
    "vmin, vmax = get_range(corrected[cell_id_preview], 5, -50)\n",
    "# np.int was removed in NumPy 1.24, so use the builtin int; keep at least one bin.\n",
    "nbins = max(1, int((vmax + 50) / 2))\n",
    "h = ax.hist(corrected[cell_id_preview].flatten(),\n",
    "            bins=nbins, range=(-50, vmax),\n",
    "            histtype='stepfilled', log=True)\n",
    "_ = ax.set_xlabel('[ADU]')\n",
    "_ = ax.set_ylabel('Counts')\n",
    "_ = ax.grid()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Section heading and caption for the mean corrected preview\n",
    "display(Markdown('### Mean CORRECTED Preview ###\\n'),\n",
    "        Markdown('A mean across one train \\n'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-02-18T17:29:39.369686Z",
     "start_time": "2019-02-18T17:29:35.905152Z"
    }
   },
   "outputs": [],
   "source": [
    "# Mean of the corrected data over axis 0; the lower colour limit is fixed at -50 ADU\n",
    "fig, ax = plt.subplots(figsize=(20, 10))\n",
    "data = corrected.mean(axis=0)\n",
    "vmin, vmax = get_range(data, 7)\n",
    "ax = geom.plot_data_fast(data, ax=ax, cmap=\"jet\", vmin=-50, vmax=vmax)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-02-18T17:29:49.217848Z",
     "start_time": "2019-02-18T17:29:39.371232Z"
    }
   },
   "outputs": [],
   "source": [
    "# Histogram of all corrected values, overlaid with the share from each gain stage.\n",
    "fig, ax = plt.subplots(figsize=(20, 10))\n",
    "# Upper edge: data maximum, capped at 50000 ADU; lower edge fixed at -100 ADU.\n",
    "vmax = min(np.nanmax(corrected), 50000)\n",
    "# np.int was removed in NumPy 1.24, so use the builtin int (~5 ADU bins, at least one).\n",
    "nbins = max(1, int((vmax + 100) / 5))\n",
    "h = ax.hist(corrected.flatten(), bins=nbins,\n",
    "            range=(-100, vmax), histtype='step', log=True, label='All')\n",
    "for gain_value, gain_label, colour in ((0, 'High gain', 'green'),\n",
    "                                       (1, 'Medium gain', 'red'),\n",
    "                                       (2, 'Low gain', 'yellow')):\n",
    "    _ = ax.hist(corrected[gains == gain_value].flatten(), bins=nbins, range=(-100, vmax),\n",
    "                alpha=0.5, log=True, label=gain_label, color=colour)\n",
    "_ = ax.legend()\n",
    "_ = ax.grid()\n",
    "_ = ax.set_xlabel('[ADU]')\n",
    "_ = ax.set_ylabel('Counts')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Section heading and caption for the maximum-gain map\n",
    "display(Markdown('### Maximum GAIN Preview ###\\n'),\n",
    "        Markdown('The per pixel maximum across one train for the digitized gain'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-02-18T17:29:49.641675Z",
     "start_time": "2019-02-18T17:29:49.224167Z"
    }
   },
   "outputs": [],
   "source": [
    "# Highest digitized gain value reached by each pixel along axis 0\n",
    "fig, ax = plt.subplots(figsize=(20, 10))\n",
    "highest_gain = gains.max(axis=0)\n",
    "ax = geom.plot_data_fast(highest_gain, ax=ax,\n",
    "                         cmap=\"jet\", vmin=-1, vmax=3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "## Bad Pixels ##\n",
    "The mask contains dedicated entries for all pixels and memory cells as well as all three gain stages. Each mask entry is encoded in 32 bits as:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-02-18T17:29:49.651913Z",
     "start_time": "2019-02-18T17:29:49.643556Z"
    }
   },
   "outputs": [],
   "source": [
    "# Bit pattern of every BadPixels flag, zero-padded to the 32-bit width of a mask entry\n",
    "# (16-bit padding did not match the documented 32-bit encoding).\n",
    "table = [(item.name, \"{:032b}\".format(item.value)) for item in BadPixels]\n",
    "md = display(Latex(tabulate.tabulate(table, tablefmt='latex',\n",
    "                                     headers=[\"Bad pixel type\", \"Bit mask\"])))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Section heading and caption for the single-shot bad-pixel map\n",
    "display(Markdown('### Single Shot Bad Pixels ### \\n'),\n",
    "        Markdown(f'A single shot bad pixel map from cell {np.max(cellId[cell_id_preview])} \\n'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-02-18T17:29:50.086169Z",
     "start_time": "2019-02-18T17:29:49.653391Z"
    }
   },
   "outputs": [],
   "source": [
    "# log2 of the mask value: a single flag 2**k maps to k. Unflagged pixels (mask 0)\n",
    "# give -inf, which the colour scale clips at vmin=0, so numpy's divide-by-zero\n",
    "# RuntimeWarning is silenced rather than emitted for every good pixel.\n",
    "fig, ax = plt.subplots(figsize=(20, 10))\n",
    "with np.errstate(divide='ignore'):\n",
    "    log_mask = np.log2(mask[cell_id_preview])\n",
    "ax = geom.plot_data_fast(log_mask, ax=ax, vmin=0, vmax=32, cmap=\"jet\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "### Percentage of Bad Pixels across one train  ###"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-02-18T17:29:51.686562Z",
     "start_time": "2019-02-18T17:29:50.088883Z"
    }
   },
   "outputs": [],
   "source": [
    "# Fraction (0-1) of entries along axis 0 in which each pixel has any flag set\n",
    "fig, ax = plt.subplots(figsize=(20, 10))\n",
    "ax = geom.plot_data_fast((mask > 0).mean(axis=0),\n",
    "                         ax=ax, vmin=0, vmax=1, cmap=\"jet\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Percentage of Bad Pixels across One Train: Dark-Related Only ###"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-02-18T17:29:55.483270Z",
     "start_time": "2019-02-18T17:29:53.664226Z"
    }
   },
   "outputs": [],
   "source": [
    "# Zero every mask entry whose value exceeds BadPixels.NO_DARK_DATA.value, then plot the\n",
    "# fraction of entries along axis 0 in which each pixel is still flagged.\n",
    "# NOTE(review): assumes the dark-related flags are the bits up to NO_DARK_DATA -- confirm\n",
    "# against the BadPixels bit layout.\n",
    "cm = np.copy(mask)\n",
    "cm[cm > BadPixels.NO_DARK_DATA.value] = 0\n",
    "fig, ax = plt.subplots(figsize=(20, 10))\n",
    "ax = geom.plot_data_fast((cm > 0).mean(axis=0),\n",
    "                         ax=ax, vmin=0, vmax=1, cmap=\"jet\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}