{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# AGIPD Offline Correction #\n",
    "\n",
    "Author: European XFEL Detector Group, Version: 2.0\n",
    "\n",
    "Offline Calibration for the AGIPD Detector"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "in_folder = \"/gpfs/exfel/exp/MID/202201/p002834/raw\" # the folder to read data from, required\n",
    "out_folder = \"/gpfs/exfel/data/scratch/esobolev/pycal_litfrm/p002834/r0225\"  # the folder to output to, required\n",
    "metadata_folder = \"\"  # Directory containing calibration_metadata.yml when run by xfel-calibrate\n",
    "sequences = [-1] # sequences to correct, set to -1 for all, range allowed\n",
    "overwrite = False  # IGNORED, NEEDED FOR COMPATIBILITY.\n",
    "modules = [-1] # modules to correct, set to -1 for all, range allowed\n",
    "train_ids = [-1] # train IDs to correct, set to -1 for all, range allowed\n",
    "run = 225 # runs to process, required\n",
    "\n",
    "karabo_id = \"MID_DET_AGIPD1M-1\" # karabo_id of the detector\n",
    "karabo_da = ['-1']  # a list of data aggregators names, Default [-1] for selecting all data aggregators\n",
    "receiver_template = \"{}CH0\" # inset for receiver devices\n",
    "path_template = 'RAW-R{:04d}-{}-S{:05d}.h5' # the template to use to access data\n",
    "instrument_source_template = '{}/DET/{}:xtdf'  # path in the HDF5 file to images\n",
    "index_source_template = 'INDEX/{}/DET/{}:xtdf/'  # path in the HDF5 file to the index of the images\n",
    "ctrl_source_template = '{}/MDL/FPGA_COMP'  # path to control information\n",
    "karabo_id_control = \"MID_EXP_AGIPD1M1\" # karabo-id for control device\n",
    "\n",
    "slopes_ff_from_files = \"\" # Path to locally stored SlopesFF and BadPixelsFF constants, loaded in precorrection notebook\n",
    "\n",
    "creation_time = \"\"  # To overwrite the measured creation_time. Required Format: YYYY-MM-DD HR:MN:SC e.g. \"2022-06-28 13:00:00\"\n",
    "cal_db_interface = \"tcp://max-exfl016:8015#8045\" # the database interface to use\n",
    "cal_db_timeout = 30000 # in milliseconds\n",
    "creation_date_offset = \"00:00:00\" # add an offset to creation date, e.g. to get different constants\n",
    "\n",
    "mem_cells = -1  # Number of memory cells used, set to -1 to automatically infer\n",
    "bias_voltage = -1  # bias voltage, set to -1 to use stored value in slow data.\n",
    "acq_rate = -1. # the detector acquisition rate, use -1 to try to auto-determine\n",
    "gain_setting = -1  # the gain setting, use -1 to use value stored in slow data.\n",
    "gain_mode = -1  # gain mode (0: adaptive, 1-3 fixed high/med/low, -1: read from CONTROL data)\n",
    "max_pulses = [0, 352, 1] # range list [st, end, step] of memory cell indices to be processed within a train. 3 allowed maximum list input elements.\n",
    "mem_cells_db = -1  # set to a value different than -1 to use this value for DB queries\n",
    "integration_time = -1 # integration time, -1 for auto-detection.\n",
    "\n",
    "# Correction parameters\n",
    "blc_noise_threshold = 5000 # above this mean signal intensity no baseline correction via noise is attempted\n",
    "cm_dark_fraction = 0.66 # threshold for fraction of  empty pixels to consider module enough dark to perform CM correction\n",
    "cm_dark_range = [-50.,30] # range for signal value ADU for pixel to be considered as a dark pixel\n",
    "cm_n_itr = 4 # number of iterations for common mode correction\n",
    "hg_hard_threshold = 1000 # threshold to force medium gain offset subtracted pixel to high gain\n",
    "mg_hard_threshold = 1000 # threshold to force medium gain offset subtracted pixel from low to medium gain\n",
    "noisy_adc_threshold = 0.25 # threshold to mask complete adc\n",
    "ff_gain = 7.2 # conversion gain for absolute FlatField constants, while applying xray_gain\n",
    "photon_energy = -1.0 # photon energy in keV, non-positive value for XGM autodetection\n",
    "\n",
    "# Correction Booleans\n",
    "only_offset = False # Apply only Offset correction. if False, Offset is applied by Default. if True, Offset is only applied.\n",
    "rel_gain = False # do relative gain correction based on PC data\n",
    "xray_gain = False # do relative gain correction based on xray data\n",
    "blc_noise = False # if set, baseline correction via noise peak location is attempted\n",
    "blc_stripes = False # if set, baseline corrected via stripes\n",
    "blc_hmatch = False # if set, base line correction via histogram matching is attempted\n",
    "match_asics = False # if set, inner ASIC borders are matched to the same signal level\n",
    "adjust_mg_baseline = False # adjust medium gain baseline to match highest high gain value\n",
    "zero_nans = False # set NaN values in corrected data to 0\n",
    "zero_orange = False # set to 0 very negative and very large values in corrected data\n",
    "blc_set_min = False # Shift to 0 negative medium gain pixels after offset corr\n",
    "corr_asic_diag = False # if set, diagonal drop offs on ASICs are corrected\n",
    "force_hg_if_below = False # set high gain if mg offset subtracted value is below hg_hard_threshold\n",
    "force_mg_if_below = False # set medium gain if mg offset subtracted value is below mg_hard_threshold\n",
    "mask_noisy_adc = False # Mask entire ADC if they are noise above a relative threshold\n",
    "common_mode = False # Common mode correction\n",
    "melt_snow = False # Identify (and optionally interpolate) 'snowy' pixels\n",
    "mask_zero_std = False # Mask pixels with zero standard deviation across train\n",
    "low_medium_gap = False # 5 sigma separation in thresholding between low and medium gain\n",
    "round_photons = False  # Round to absolute number of photons, only use with gain corrections\n",
    "\n",
    "# Optional auxiliary devices\n",
    "use_ppu_device = ''  # Device ID for a pulse picker device to only process picked trains, empty string to disable\n",
    "ppu_train_offset = 0  # When using the pulse picker, offset between the PPU's sequence start and actually picked train\n",
    "\n",
    "use_litframe_finder = 'off' # Process only illuminated frames: 'off' - disable, 'online' - use online device data, 'offline' - use offline algorithm, 'auto' - choose online/offline source automatically\n",
    "litframe_device_id = '' # Device ID for a lit frame finder device, empty string to auto detection\n",
    "energy_threshold = -1000 # The low limit for the energy (uJ) exposed by frames subject to processing. If -1000, selection by pulse energy is disabled\n",
    "use_super_selection = 'cm' # Make a common selection for entire run: 'off' - disable, 'final' - enable for final selection, 'cm' - enable only for common mode correction\n",
    "\n",
    "use_xgm_device = ''  # DoocsXGM device ID to obtain actual photon energy, operating condition else.\n",
    "\n",
    "# Output parameters\n",
    "recast_image_data = ''  # Cast data to a different dtype before saving\n",
    "compress_fields = ['gain', 'mask']  # Datasets in image group to compress.\n",
    "\n",
    "# Plotting parameters\n",
    "skip_plots = False # exit after writing corrected files and metadata\n",
    "cell_id_preview = 1 # cell Id used for preview in single-shot plots\n",
    "\n",
    "# Parallelization parameters\n",
    "chunk_size = 1000  # Size of chunk for image-wise correction\n",
    "n_cores_correct = 16 # Number of chunks to be processed in parallel\n",
    "n_cores_files = 4 # Number of files to be processed in parallel\n",
    "sequences_per_node = 2 # number of sequence files per cluster node if run as SLURM job, set to 0 to not run SLURM parallel\n",
    "max_nodes = 8 # Maximum number of SLURM jobs to split correction work into\n",
    "max_tasks_per_worker = 1  # the number of tasks a correction pool worker process can complete before it will exit and be replaced with a fresh worker process. Set to -1 to keep worker alive as long as pool.\n",
    "\n",
    "def balance_sequences(in_folder, run, sequences, sequences_per_node, karabo_da, max_nodes):\n",
    "    \"\"\"Forward all arguments to xfel_calibrate.calibrate.balance_sequences.\n",
    "\n",
    "    max_nodes is passed on by keyword, every other argument positionally.\n",
    "    Returns whatever the wrapped function returns.\n",
    "    \"\"\"\n",
    "    from xfel_calibrate.calibrate import balance_sequences as bs\n",
    "    return bs(in_folder, run, sequences, sequences_per_node, karabo_da, max_nodes=max_nodes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import itertools\n",
    "import os\n",
    "import math\n",
    "import multiprocessing\n",
    "import re\n",
    "import warnings\n",
    "from datetime import timedelta\n",
    "from logging import warning\n",
    "from pathlib import Path\n",
    "from time import perf_counter\n",
    "\n",
    "import tabulate\n",
    "from dateutil import parser\n",
    "from IPython.display import Latex, Markdown, display\n",
    "\n",
    "warnings.filterwarnings('ignore')\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "import yaml\n",
    "from extra_data import by_id, RunDirectory, stack_detector_data\n",
    "from extra_geom import AGIPD_1MGeometry, AGIPD_500K2GGeometry\n",
    "from matplotlib import cm as colormap\n",
    "from matplotlib.colors import LogNorm\n",
    "\n",
    "matplotlib.use(\"agg\")\n",
    "%matplotlib inline\n",
    "import numpy as np\n",
    "import seaborn as sns\n",
    "\n",
    "sns.set()\n",
    "sns.set_context(\"paper\", font_scale=1.4)\n",
    "sns.set_style(\"ticks\")\n",
    "\n",
    "from cal_tools import agipdalgs as calgs\n",
    "from cal_tools.agipdlib import (\n",
    "    AgipdCorrections,\n",
    "    AgipdCtrl,\n",
    "    CellRange,\n",
    "    LitFrameSelection,\n",
    ")\n",
    "from cal_tools.ana_tools import get_range\n",
    "from cal_tools.enums import AgipdGainMode, BadPixels\n",
    "from cal_tools.step_timing import StepTimer\n",
    "from cal_tools.tools import (\n",
    "    CalibrationMetadata,\n",
    "    calcat_creation_time,\n",
    "    map_modules_from_folder,\n",
    "    module_index_to_qm,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Work with pathlib objects from here on\n",
    "in_folder, out_folder = Path(in_folder), Path(out_folder)\n",
    "run_folder = in_folder.joinpath(f'r{run:04d}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluated parameters ##"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect the correction switches in one dictionary.\n",
    "# Offset correction sits at the bottom of the AGIPD correction pyramid,\n",
    "# so its flag is always present.\n",
    "corr_bools = {\"only_offset\": only_offset}\n",
    "\n",
    "# The remaining corrections are only registered when more than the\n",
    "# offset correction is requested.\n",
    "if not only_offset:\n",
    "    corr_bools.update(\n",
    "        adjust_mg_baseline=adjust_mg_baseline,\n",
    "        rel_gain=rel_gain,\n",
    "        xray_corr=xray_gain,\n",
    "        blc_noise=blc_noise,\n",
    "        blc_stripes=blc_stripes,\n",
    "        blc_hmatch=blc_hmatch,\n",
    "        blc_set_min=blc_set_min,\n",
    "        match_asics=match_asics,\n",
    "        corr_asic_diag=corr_asic_diag,\n",
    "        zero_nans=zero_nans,\n",
    "        zero_orange=zero_orange,\n",
    "        mask_noisy_adc=mask_noisy_adc,\n",
    "        force_hg_if_below=force_hg_if_below,\n",
    "        force_mg_if_below=force_mg_if_below,\n",
    "        common_mode=common_mode,\n",
    "        melt_snow=melt_snow,\n",
    "        mask_zero_std=mask_zero_std,\n",
    "        low_medium_gap=low_medium_gap,\n",
    "        round_photons=round_photons,\n",
    "    )\n",
    "\n",
    "# Corrections that make no sense in fixed gain mode; they get switched\n",
    "# off further down once the gain mode is known.\n",
    "disable_for_fixed_gain = [\n",
    "    \"adjust_mg_baseline\", \"blc_set_min\",\n",
    "    \"force_hg_if_below\", \"force_mg_if_below\",\n",
    "    \"low_medium_gap\", \"melt_snow\", \"rel_gain\",\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# [-1] is the placeholder for \"all sequences\"\n",
    "sequences = None if sequences == [-1] else sequences\n",
    "\n",
    "dc = RunDirectory(run_folder)\n",
    "\n",
    "# Fill the detector / control IDs into the source name templates\n",
    "instrument_src = instrument_source_template.format(karabo_id, receiver_template)\n",
    "index_src = index_source_template.format(karabo_id, receiver_template)\n",
    "ctrl_src = ctrl_source_template.format(karabo_id_control)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create output folder\n",
    "out_folder.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "# Evaluate detector instance for mapping:\n",
    "# instrument prefix of karabo_id -> (detector instance, number of modules)\n",
    "detector_instances = {\n",
    "    \"SPB\": (\"AGIPD1M1\", 16),\n",
    "    \"MID\": (\"AGIPD1M2\", 16),\n",
    "    \"HED\": (\"AGIPD500K\", 8),\n",
    "}\n",
    "instrument = karabo_id.split(\"_\")[0]\n",
    "if instrument not in detector_instances:\n",
    "    # Fail here with a clear message; otherwise dinstance / nmods would stay\n",
    "    # undefined and a NameError would show up further down.\n",
    "    raise ValueError(\n",
    "        f\"Unknown instrument '{instrument}' in karabo_id '{karabo_id}', \"\n",
    "        f\"expected one of {sorted(detector_instances)}\"\n",
    "    )\n",
    "dinstance, nmods = detector_instances[instrument]\n",
    "\n",
    "# Evaluate requested modules\n",
    "if karabo_da[0] == '-1':\n",
    "    if modules[0] == -1:\n",
    "        modules = list(range(nmods))\n",
    "    karabo_da = [\"AGIPD{:02d}\".format(i) for i in modules]\n",
    "else:\n",
    "    # The module index is taken from the last two characters of each aggregator name\n",
    "    modules = [int(x[-2:]) for x in karabo_da]\n",
    "\n",
    "print(\"Process modules:\", ', '.join(module_index_to_qm(x) for x in modules))\n",
    "print(f\"Detector in use is {karabo_id}\")\n",
    "print(f\"Instrument {instrument}\")\n",
    "print(f\"Detector instance {dinstance}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Decide which trains get corrected. Precedence: pulse picker device,\n",
    "# then the explicit train_ids parameter, otherwise all trains (None).\n",
    "if use_ppu_device:\n",
    "    # Obtain trains to process if using a pulse picker device.\n",
    "\n",
    "    # Will throw an uncaught exception if the device is wrong.\n",
    "    seq_start = dc[use_ppu_device, 'trainTrigger.sequenceStart.value'].ndarray()\n",
    "\n",
    "    # The trains picked are the unique values of trainTrigger.sequenceStart\n",
    "    # minus the first (previous trigger before this run).\n",
    "    start_train_ids = np.unique(seq_start)[1:] + ppu_train_offset\n",
    "\n",
    "    # Expand every sequence start into the train IDs it covers, using the\n",
    "    # numberOfTrains value read at that start train.\n",
    "    train_ids = []\n",
    "    for train_id in start_train_ids:\n",
    "        n_trains = dc[\n",
    "            use_ppu_device, 'trainTrigger.numberOfTrains'\n",
    "        ].select_trains(by_id[[train_id]]).ndarray()[0]\n",
    "        train_ids.extend(list(range(train_id, train_id + n_trains)))\n",
    "\n",
    "    print(f'PPU device {use_ppu_device} triggered for {len(train_ids)} train(s)')\n",
    "\n",
    "elif train_ids != [-1]:\n",
    "    # Specific trains passed by parameter, convert to ndarray.\n",
    "    train_ids = np.array(train_ids)\n",
    "    \n",
    "    print(f'Processing up to {len(train_ids)} manually selected train(s)')\n",
    "else:\n",
    "    # Process all trains.\n",
    "    train_ids = None\n",
    "    \n",
    "    print(f'Processing all valid trains')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Map the raw files of the run to the requested modules\n",
    "mapped_files, _, total_sequences, _, _ = map_modules_from_folder(\n",
    "    str(in_folder), run, path_template, karabo_da, sequences\n",
    ")\n",
    "\n",
    "# ToDo: Split table over pages\n",
    "print(f\"Processing a total of {total_sequences} sequence files in chunks of {n_cores_files}\")\n",
    "\n",
    "file_list = []\n",
    "table = []\n",
    "ti = 0  # running index over all files\n",
    "for k, files in mapped_files.items():\n",
    "    i = 0  # index within the current module\n",
    "    for f in list(files.queue):\n",
    "        file_list.append(f)\n",
    "        # The module name is only shown on the first row of each module\n",
    "        table.append((ti, k if i == 0 else \"\", i, f))\n",
    "        i += 1\n",
    "        ti += 1\n",
    "\n",
    "md = display(Latex(tabulate.tabulate(\n",
    "    table,\n",
    "    tablefmt='latex',\n",
    "    headers=[\"#\", \"module\", \"# module\", \"file\"],\n",
    ")))\n",
    "\n",
    "# Order the files by the last 10 characters of their names\n",
    "file_list.sort(key=lambda name: name[-10:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "first_mod_channel = sorted(modules)[0]\n",
    "\n",
    "# Find the image source of the first requested module.\n",
    "# The module number must not be preceded by another digit: a plain substring\n",
    "# test for e.g. \"0CH\" would also accept the sources of modules 10, 20, ...\n",
    "mod_channel_re = re.compile(rf\"(?<![0-9]){first_mod_channel}CH([0-9]+):\")\n",
    "# Sorted, so the choice does not depend on the iteration order of all_sources\n",
    "matching_sources = sorted(s for s in dc.all_sources if mod_channel_re.search(s))\n",
    "if not matching_sources:\n",
    "    raise ValueError(f\"No source found in {run_folder} for module {first_mod_channel}\")\n",
    "\n",
    "instrument_src_mod = matching_sources[0]\n",
    "mod_channel = int(mod_channel_re.search(instrument_src_mod).group(1))\n",
    "\n",
    "agipd_cond = AgipdCtrl(\n",
    "    run_dc=dc,\n",
    "    image_src=instrument_src_mod,\n",
    "    ctrl_src=ctrl_src,\n",
    "    raise_error=False,  # to be able to process very old data without gain_setting value\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run's creation time, shifted by the requested offset\n",
    "creation_time = calcat_creation_time(in_folder, run, creation_time)\n",
    "offset = parser.parse(creation_date_offset)\n",
    "delta = timedelta(hours=offset.hour, minutes=offset.minute, seconds=offset.second)\n",
    "creation_time = creation_time + delta\n",
    "print(f\"Creation time: {creation_time}\")\n",
    "\n",
    "# Every condition left at -1 is read from the run data instead\n",
    "acq_rate = agipd_cond.get_acq_rate() if acq_rate == -1. else acq_rate\n",
    "mem_cells = agipd_cond.get_num_cells() if mem_cells == -1 else mem_cells\n",
    "# TODO: look for alternative for passing creation_time\n",
    "gain_setting = (\n",
    "    agipd_cond.get_gain_setting(creation_time) if gain_setting == -1 else gain_setting\n",
    ")\n",
    "bias_voltage = (\n",
    "    agipd_cond.get_bias_voltage(karabo_id_control) if bias_voltage == -1 else bias_voltage\n",
    ")\n",
    "integration_time = (\n",
    "    agipd_cond.get_integration_time() if integration_time == -1 else integration_time\n",
    ")\n",
    "# An explicitly given gain mode is converted to the enum\n",
    "gain_mode = agipd_cond.get_gain_mode() if gain_mode == -1 else AgipdGainMode(gain_mode)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if mem_cells is None:\n",
    "    raise ValueError(f\"No raw images found for {instrument_src_mod}\")\n",
    "\n",
    "# Query the database with the number of cells in use unless overridden\n",
    "if mem_cells_db == -1:\n",
    "    mem_cells_db = mem_cells\n",
    "\n",
    "print(f\"Maximum memory cells to calibrate: {mem_cells}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Report the conditions in use, one item per line\n",
    "print(\n",
    "    f\"Using {creation_time} as creation time\",\n",
    "    \"Operating conditions are:\",\n",
    "    f\"• Bias voltage: {bias_voltage}\",\n",
    "    f\"• Memory cells: {mem_cells_db}\",\n",
    "    f\"• Acquisition rate: {acq_rate}\",\n",
    "    f\"• Gain setting: {gain_setting}\",\n",
    "    f\"• Gain mode: {gain_mode.name}\",\n",
    "    f\"• Integration time: {integration_time}\",\n",
    "    \"• Photon Energy: 9.2\",\n",
    "    sep=\"\\n\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A truthy gain mode triggers the fixed gain handling: requested corrections\n",
    "# from disable_for_fixed_gain are switched off with a warning.\n",
    "if gain_mode:\n",
    "    for to_disable in disable_for_fixed_gain:\n",
    "        if not corr_bools.get(to_disable, False):\n",
    "            continue\n",
    "        warning(f\"{to_disable} correction was requested, but does not apply to fixed gain mode\")\n",
    "        corr_bools[to_disable] = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the cell selection: lit-frame based if requested, otherwise the\n",
    "# plain memory cell range given by max_pulses.\n",
    "if use_litframe_finder != 'off':\n",
    "    # extra_redu is only imported when the lit-frame finder is enabled\n",
    "    from extra_redu import make_litframe_finder, LitFrameFinderError\n",
    "    \n",
    "    if use_litframe_finder not in ['auto', 'offline', 'online']:\n",
    "        raise ValueError(\"Unexpected value in 'use_litframe_finder'.\")\n",
    "\n",
    "    # The first three characters of the control ID are passed on as instrument name\n",
    "    inst = karabo_id_control[:3]\n",
    "    litfrm = make_litframe_finder(inst, dc, litframe_device_id)\n",
    "    try:\n",
    "        # Dispatch on the requested data source\n",
    "        get_data = {'auto': litfrm.read_or_process, 'offline': litfrm.process, 'online': litfrm.read}\n",
    "        r = get_data[use_litframe_finder]()\n",
    "        cell_sel = LitFrameSelection(r, train_ids, max_pulses, energy_threshold, use_super_selection)\n",
    "        cell_sel.print_report()\n",
    "    except LitFrameFinderError as err:\n",
    "        # Best effort: fall back to the plain range selection\n",
    "        warning(f\"Cannot use AgipdLitFrameFinder due to:\\n{err}\")\n",
    "        cell_sel = CellRange(max_pulses, max_cells=mem_cells)\n",
    "else:\n",
    "    # Use range selection\n",
    "    cell_sel = CellRange(max_pulses, max_cells=mem_cells)\n",
    "\n",
    "print(cell_sel.msg())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Photon rounding needs a photon energy: use the explicit parameter if it is\n",
    "# positive, otherwise try the XGM device, otherwise disable the rounding.\n",
    "if round_photons and photon_energy <= 0.0:\n",
    "    if use_xgm_device:\n",
    "        # Try to obtain photon energy from XGM device.\n",
    "        wavelength_data = dc[use_xgm_device, 'pulseEnergy.wavelengthUsed']\n",
    "\n",
    "        try:\n",
    "            from scipy.constants import h, c, e\n",
    "\n",
    "            # Read wavelength as a single value and convert to hv.\n",
    "            # NOTE(review): the 1e-6 factor presumably combines nm -> m and eV -> keV,\n",
    "            # i.e. assumes wavelengthUsed is stored in nm - TODO confirm\n",
    "            photon_energy = (h * c / e) / (wavelength_data.as_single_value(rtol=1e-2) * 1e-6)\n",
    "            print(f'Obtained photon energy {photon_energy:.3f} keV from {use_xgm_device}')\n",
    "        except ValueError:\n",
    "            # Handled as: the wavelength is not constant within rtol=1e-2\n",
    "            warning('XGM source available but photon energy varies greater than 1%, '\n",
    "                 'photon rounding disabled!')\n",
    "            round_photons = False\n",
    "    else:\n",
    "        warning('Neither explicit photon energy nor XGM device configured, photon rounding disabled!')\n",
    "        round_photons = False\n",
    "elif round_photons:\n",
    "    print(f'Photon energy for rounding: {photon_energy:.3f} keV')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data processing ##"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the correction object and hand over the notebook parameters\n",
    "agipd_corr = AgipdCorrections(\n",
    "    mem_cells,\n",
    "    cell_sel,\n",
    "    h5_data_path=instrument_src,\n",
    "    h5_index_path=index_src,\n",
    "    corr_bools=corr_bools,\n",
    "    gain_mode=gain_mode,\n",
    "    # available CPUs divided by the number of files processed in parallel\n",
    "    comp_threads=os.cpu_count() // n_cores_files,\n",
    "    train_ids=train_ids\n",
    ")\n",
    "\n",
    "# Note the sign flip: the threshold is handed over negated\n",
    "agipd_corr.baseline_corr_noise_threshold = -blc_noise_threshold\n",
    "agipd_corr.hg_hard_threshold = hg_hard_threshold\n",
    "agipd_corr.mg_hard_threshold = mg_hard_threshold\n",
    "\n",
    "# Common mode settings; cm_dark_range holds [min, max]\n",
    "agipd_corr.cm_dark_min = cm_dark_range[0]\n",
    "agipd_corr.cm_dark_max = cm_dark_range[1]\n",
    "agipd_corr.cm_dark_fraction = cm_dark_fraction\n",
    "agipd_corr.cm_n_itr = cm_n_itr\n",
    "agipd_corr.noisy_adc_threshold = noisy_adc_threshold\n",
    "agipd_corr.ff_gain = ff_gain\n",
    "agipd_corr.photon_energy = photon_energy\n",
    "\n",
    "# Output options; the image data dtype is only changed if recast_image_data is set\n",
    "agipd_corr.compress_fields = compress_fields\n",
    "if recast_image_data:\n",
    "    agipd_corr.recast_image_fields['data'] = np.dtype(recast_image_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Lookup table: module index -> data aggregator name\n",
    "module_index_to_karabo_da = dict(zip(modules, karabo_da))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Retrieve calibration constants to RAM\n",
    "agipd_corr.allocate_constants(modules, (3, mem_cells_db, 512, 128))\n",
    "\n",
    "metadata = CalibrationMetadata(metadata_folder or out_folder)\n",
    "# NOTE: this notebook will not overwrite calibration metadata file\n",
    "const_yaml = metadata.get(\"retrieved-constants\", {})\n",
    "\n",
    "def retrieve_constants(mod):\n",
    "    \"\"\"\n",
    "    Retrieve calibration constants and load them to shared memory\n",
    "\n",
    "    Metadata for constants is taken from yml file or retrieved from the DB\n",
    "\n",
    "    :param mod: module index, a key of module_index_to_karabo_da\n",
    "    :return: tuple (mod, when, k_da); when is None if the DB retrieval raised\n",
    "    \"\"\"\n",
    "    k_da = module_index_to_karabo_da[mod]\n",
    "    # check if there is a yaml file in out_folder that has the device constants.\n",
    "    if k_da in const_yaml:\n",
    "        when = agipd_corr.initialize_from_yaml(k_da, const_yaml, mod)\n",
    "        print(f\"Found constants for {k_da} in calibration_metadata.yml\")\n",
    "    else:\n",
    "        try:\n",
    "            # TODO: replace with proper retrieval (as done in pre-correction)\n",
    "            # The query uses a fixed photon energy of 9.2, not the\n",
    "            # photon_energy notebook parameter.\n",
    "            when = agipd_corr.initialize_from_db(\n",
    "                karabo_id=karabo_id,\n",
    "                karabo_da=k_da,\n",
    "                cal_db_interface=cal_db_interface,\n",
    "                creation_time=creation_time,\n",
    "                memory_cells=mem_cells_db,\n",
    "                bias_voltage=bias_voltage,\n",
    "                photon_energy=9.2,\n",
    "                gain_setting=gain_setting,\n",
    "                acquisition_rate=acq_rate,\n",
    "                integration_time=integration_time,\n",
    "                module_idx=mod,\n",
    "                only_dark=False,\n",
    "            )\n",
    "            print(f\"Queried CalCat for {k_da}\")\n",
    "        except Exception as e:\n",
    "            # Any failure is reduced to a warning; the module is reported\n",
    "            # with when = None.\n",
    "            warning(f\"Module: {k_da}, {e}\")\n",
    "            when = None\n",
    "    return mod, when, k_da\n",
    "\n",
    "\n",
    "print(f'Preparing constants (FF: {agipd_corr.corr_bools.get(\"xray_corr\", False)}, PC: {any(agipd_corr.pc_bools)}, '\n",
    "      f'BLC: {any(agipd_corr.blc_bools)})')\n",
    "ts = perf_counter()\n",
    "# One worker process per requested module\n",
    "with multiprocessing.Pool(processes=len(modules)) as pool:\n",
    "    const_out = pool.map(retrieve_constants, modules)\n",
    "print(f\"Constants were loaded in {perf_counter()-ts:.01f}s\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# allocate memory for images and hists\n",
    "# NOTE(review): 256 is presumably the maximum number of trains per file - TODO confirm\n",
    "n_images_max = mem_cells * 256\n",
    "data_shape = (n_images_max, 512, 128)\n",
    "# n_cores_files is passed as second argument, the number of files processed in parallel\n",
    "agipd_corr.allocate_images(data_shape, n_cores_files)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def batches(l, batch_size):\n",
    "    \"\"\"Group a list into batches of (up to) batch_size elements\n",
    "\n",
    "    :param l: sequence supporting len() and slicing\n",
    "    :param batch_size: maximum number of elements per batch, at least 1\n",
    "    :raises ValueError: if batch_size is smaller than 1; as this is a\n",
    "        generator, the error is raised when iteration starts\n",
    "    \"\"\"\n",
    "    if batch_size < 1:\n",
    "        # A non-positive step would never move past the end of the list\n",
    "        raise ValueError(f\"batch_size must be at least 1, got {batch_size}\")\n",
    "    for start in range(0, len(l), batch_size):\n",
    "        yield l[start:start + batch_size]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def imagewise_chunks(img_counts):\n",
    "    \"\"\"Split the images of every loaded file into near-equal chunks\n",
    "\n",
    "    No chunk holds more than chunk_size images; files without images\n",
    "    produce no chunks.\n",
    "\n",
    "    Yields (file data slot, start index, stop index)\n",
    "    \"\"\"\n",
    "    for slot, count in enumerate(img_counts):\n",
    "        parts = math.ceil(count / chunk_size)\n",
    "        yield from (\n",
    "            (slot, part * count // parts, (part + 1) * count // parts)\n",
    "            for part in range(parts)\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Timer used to report the duration of the individual processing steps\n",
    "step_timer = StepTimer()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "step_timer.start()\n",
    "# multiprocessing expects None (not -1) for workers that live as long as the pool\n",
    "if max_tasks_per_worker == -1:\n",
    "    max_tasks_per_worker = None\n",
    "with multiprocessing.Pool(maxtasksperchild=max_tasks_per_worker) as pool:\n",
    "    step_timer.done_step('Started pool')\n",
    "\n",
    "    for file_batch in batches(file_list, n_cores_files):\n",
    "        # TODO: Move some printed output to logging or similar\n",
    "        print(f\"Processing next {len(file_batch)} files\")\n",
    "        step_timer.start()\n",
    "        # One image count per file, indexed by the file's data slot\n",
    "        img_counts = pool.starmap(\n",
    "            agipd_corr.read_file,\n",
    "            zip(range(len(file_batch)), file_batch, [not common_mode]*len(file_batch))\n",
    "        )\n",
    "        step_timer.done_step(f'Loading data from files')\n",
    "\n",
    "        # img_counts is a list, so it has to be tested element-wise;\n",
    "        # comparing the list itself to 0 is never true.\n",
    "        if not any(img_counts):\n",
    "            # Skip any further processing and output if there are no images to\n",
    "            # correct in any file of this batch.\n",
    "            continue\n",
    "\n",
    "        if mask_zero_std:\n",
    "            # Evaluate zero-data-std mask\n",
    "            pool.starmap(\n",
    "                agipd_corr.mask_zero_std, itertools.product(\n",
    "                    range(len(file_batch)),\n",
    "                    np.array_split(np.arange(agipd_corr.max_cells), n_cores_correct)\n",
    "                )\n",
    "            )\n",
    "            step_timer.done_step('Mask 0 std')\n",
    "\n",
    "        # Perform offset image-wise correction\n",
    "        pool.starmap(agipd_corr.offset_correction, imagewise_chunks(img_counts))\n",
    "        step_timer.done_step(\"Offset correction\")\n",
    "\n",
    "        if blc_noise or blc_stripes or blc_hmatch:\n",
    "            # Perform image-wise correction\n",
    "            pool.starmap(agipd_corr.baseline_correction, imagewise_chunks(img_counts))\n",
    "            step_timer.done_step(\"Base-line shift correction\")\n",
    "\n",
    "        if common_mode:\n",
    "            # In common mode corrected is enabled.\n",
    "            # Cell selection is only activated after common mode correction.\n",
    "            # Perform cross-file correction parallel over asics\n",
    "            image_files_idx = [i_proc for i_proc, n_img in enumerate(img_counts) if n_img > 0]\n",
    "            pool.starmap(agipd_corr.cm_correction, itertools.product(\n",
    "                image_files_idx, range(16)  # 16 ASICs per module\n",
    "            ))\n",
    "            step_timer.done_step(\"Common-mode correction\")\n",
    "\n",
    "            n_selected = pool.map(agipd_corr.apply_selected_pulses, image_files_idx)\n",
    "            # Write the new counts back at their file slots. Replacing img_counts\n",
    "            # by the shorter result list would shift the counts to the wrong\n",
    "            # slots whenever a file without images was left out above.\n",
    "            for i_proc, n_img in zip(image_files_idx, n_selected):\n",
    "                img_counts[i_proc] = n_img\n",
    "            step_timer.done_step(\"Applying selected cells after common mode correction\")\n",
    "\n",
    "        # Perform image-wise gain correction\n",
    "        pool.starmap(agipd_corr.gain_correction, imagewise_chunks(img_counts))\n",
    "        step_timer.done_step(\"Gain corrections\")\n",
    "\n",
    "        # Save corrected data\n",
    "        pool.starmap(agipd_corr.write_file, [\n",
    "            (i_proc, file_name, str(out_folder / Path(file_name).name.replace(\"RAW\", \"CORR\")))\n",
    "            for i_proc, file_name in enumerate(file_batch)\n",
    "        ])\n",
    "        step_timer.done_step(\"Save\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Report the total run time and the timing of the individual steps\n",
    "print(f\"Correction of {len(file_list)} files is finished\")\n",
    "print(f\"Total processing time {step_timer.timespan():.01f} s\")\n",
    "print(f\"Timing summary per batch of {n_cores_files} files:\")\n",
    "step_timer.print_summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# if the yml file contains \"retrieved-constants\", that means a leading\n",
    "# notebook got processed and the reporting would be generated from it.\n",
    "fst_print = True\n",
    "timestamps = {}\n",
    "\n",
    "for modno, when, k_da in const_out:\n",
    "    qm = module_index_to_qm(modno)\n",
    "\n",
    "    if k_da not in const_yaml:\n",
    "        if fst_print:\n",
    "            print(\"Constants are retrieved with creation time: \")\n",
    "            fst_print = False\n",
    "\n",
    "        module_timestamps = {}\n",
    "\n",
    "        print(f\"{qm}:\")\n",
    "        if when:\n",
    "            for key, item in when.items():\n",
    "                # Date-like entries are stored and printed as formatted strings\n",
    "                if hasattr(item, 'strftime'):\n",
    "                    item = item.strftime('%y-%m-%d %H:%M')\n",
    "                when[key] = item\n",
    "                print('{:.<12s}'.format(key), item)\n",
    "        else:\n",
    "            # retrieve_constants() hands back None when the retrieval failed\n",
    "            print(\"No constants were retrieved\")\n",
    "\n",
    "        # Store few time stamps if exists\n",
    "        # Add NA to keep array structure\n",
    "        for key in ['Offset', 'SlopesPC', 'SlopesFF']:\n",
    "            if when and key in when and when[key]:\n",
    "                module_timestamps[key] = when[key]\n",
    "            else:\n",
    "                module_timestamps[key] = \"NA\"\n",
    "        timestamps[qm] = module_timestamps\n",
    "\n",
    "# The summary file is named after the first requested sequence (0 for all)\n",
    "seq = sequences[0] if sequences else 0\n",
    "\n",
    "if timestamps:\n",
    "    with open(f\"{out_folder}/retrieved_constants_s{seq}.yml\",\"w\") as fd:\n",
    "        yaml.safe_dump({\"time-summary\": {f\"S{seq}\": timestamps}}, fd)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if skip_plots:\n",
    "    print('Skipping plots')\n",
    "    import sys\n",
    "    sys.exit(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def do_3d_plot(data, edges, x_axis, y_axis):\n",
    "    fig = plt.figure(figsize=(10, 10))\n",
    "    ax = fig.gca(projection='3d')\n",
    "\n",
    "    # Make data.\n",
    "    X = edges[0][:-1]\n",
    "    Y = edges[1][:-1]\n",
    "    X, Y = np.meshgrid(X, Y)\n",
    "    Z = data.T\n",
    "\n",
    "    # Plot the surface.\n",
    "    ax.plot_surface(X, Y, Z, cmap=colormap.coolwarm, linewidth=0, antialiased=False)\n",
    "    ax.set_xlabel(x_axis)\n",
    "    ax.set_ylabel(y_axis)\n",
    "    ax.set_zlabel(\"Counts\")\n",
    "\n",
    "\n",
    "def do_2d_plot(data, edges, y_axis, x_axis):\n",
    "    fig = plt.figure(figsize=(10, 10))\n",
    "    ax = fig.add_subplot(111)\n",
    "    extent = [np.min(edges[1]), np.max(edges[1]),\n",
    "              np.min(edges[0]), np.max(edges[0])]\n",
    "    im = ax.imshow(data[::-1, :], extent=extent, aspect=\"auto\",\n",
    "                   norm=LogNorm(vmin=1, vmax=max(10, np.max(data))))\n",
    "    ax.set_xlabel(x_axis)\n",
    "    ax.set_ylabel(y_axis)\n",
    "    cb = fig.colorbar(im)\n",
    "    cb.set_label(\"Counts\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_trains_data(data_folder, source, include, detector_id, tid=None, modules=16, fillvalue=None):\n",
    "    \"\"\"Load single train for all module\n",
    "\n",
    "    :param data_folder: Path to folder with data\n",
    "    :param source: Data source to be loaded\n",
    "    :param include: Inset of file name to be considered\n",
    "    :param detector_id: The karabo id of the detector to get data for\n",
    "    :param tid: Train Id to be loaded. First train is considered if None is given\n",
    "    :param path: Path to find image data inside h5 file\n",
    "    \"\"\"\n",
    "    try:\n",
    "        run_data = RunDirectory(data_folder, include)\n",
    "    except FileNotFoundError:\n",
    "        warning(f'No corrected files for {include}. Skipping plots.')\n",
    "        import sys\n",
    "        sys.exit(0)\n",
    "    if tid is not None:\n",
    "        tid, data = run_data.select(\n",
    "            f'{detector_id}/DET/*', source).train_from_id(tid, keep_dims=True)\n",
    "    else:\n",
    "        # A first full trainId for all available modules is of interest.\n",
    "        tid, data = next(run_data.select(\n",
    "            f'{detector_id}/DET/*', source).trains(require_all=True, keep_dims=True))\n",
    "\n",
    "    stacked_data = stack_detector_data(\n",
    "        train=data, data=source, fillvalue=fillvalue, modules=modules)\n",
    "\n",
    "    return tid, stacked_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if dinstance == \"AGIPD500K\":\n",
    "    geom = AGIPD_500K2GGeometry.from_origin()\n",
    "else:\n",
    "    geom = AGIPD_1MGeometry.from_quad_positions(quad_pos=[\n",
    "        (-525, 625),\n",
    "        (-550, -10),\n",
    "        (520, -160),\n",
    "        (542.5, 475),\n",
    "    ])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "include = '*S00000*' if sequences is None else f'*S{sequences[0]:05d}*'\n",
    "tid, corrected = get_trains_data(out_folder, 'image.data', include, karabo_id, modules=nmods)\n",
    "\n",
    "_, gains = get_trains_data(out_folder, 'image.gain', include, karabo_id, tid, modules=nmods)\n",
    "_, mask = get_trains_data(out_folder, 'image.mask', include, karabo_id, tid, modules=nmods)\n",
    "_, blshift = get_trains_data(out_folder, 'image.blShift', include, karabo_id, tid, modules=nmods)\n",
    "_, cellId = get_trains_data(out_folder, 'image.cellId', include, karabo_id, tid, modules=nmods)\n",
    "_, pulseId = get_trains_data(out_folder, 'image.pulseId', include, karabo_id, tid, modules=nmods, fillvalue=0)\n",
    "_, raw = get_trains_data(run_folder, 'image.data', include, karabo_id, tid, modules=nmods)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "display(Markdown(f'## Preview and statistics for {gains.shape[0]} images of the train {tid} ##\\n'))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Signal vs. Analogue Gain ###"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "hist, bins_x, bins_y = calgs.histogram2d(raw[:,0,...].flatten().astype(np.float32),\n",
    "                                         raw[:,1,...].flatten().astype(np.float32),\n",
    "                                         bins=(100, 100),\n",
    "                                         range=[[4000, 8192], [4000, 8192]])\n",
    "do_2d_plot(hist, (bins_x, bins_y), \"Signal (ADU)\", \"Analogue gain (ADU)\")\n",
    "do_3d_plot(hist, (bins_x, bins_y), \"Signal (ADU)\", \"Analogue gain (ADU)\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Signal vs. Digitized Gain ###\n",
    "\n",
    "The following plot shows plots signal vs. digitized gain"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "hist, bins_x, bins_y = calgs.histogram2d(corrected.flatten().astype(np.float32),\n",
    "                                         gains.flatten().astype(np.float32), bins=(100, 3),\n",
    "                                         range=[[-50, 8192], [0, 3]])\n",
    "do_2d_plot(hist, (bins_x, bins_y), \"Signal (ADU)\", \"Gain bit value\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(f\"Gain statistics in %\")\n",
    "table = [[f'{gains[gains==0].size/gains.size*100:.02f}',\n",
    "          f'{gains[gains==1].size/gains.size*100:.03f}',\n",
    "          f'{gains[gains==2].size/gains.size*100:.03f}']]\n",
    "md = display(Latex(tabulate.tabulate(table, tablefmt='latex',\n",
    "                                     headers=[\"High\", \"Medium\", \"Low\"])))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Intensity per Pulse ###"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "pulse_range = [np.min(pulseId[pulseId>=0]), np.max(pulseId[pulseId>=0])]\n",
    "\n",
    "# Modify pulse_range, if only one pulse is selected.\n",
    "if pulse_range[0] == pulse_range[1]:\n",
    "    pulse_range = [0, pulse_range[1]+int(acq_rate)]\n",
    "\n",
    "mean_data = np.nanmean(corrected, axis=(2, 3))\n",
    "hist, bins_x, bins_y = calgs.histogram2d(mean_data.flatten().astype(np.float32),\n",
    "                                      pulseId.flatten().astype(np.float32),\n",
    "                                      bins=(100, int(pulse_range[1])),\n",
    "                                      range=[[-50, 1000], pulse_range])\n",
    "\n",
    "do_2d_plot(hist, (bins_x, bins_y), \"Signal (ADU)\", \"Pulse id\")\n",
    "do_3d_plot(hist, (bins_x, bins_y), \"Signal (ADU)\", \"Pulse id\")\n",
    "\n",
    "hist, bins_x, bins_y = calgs.histogram2d(mean_data.flatten().astype(np.float32),\n",
    "                                      pulseId.flatten().astype(np.float32),\n",
    "                                      bins=(100,  int(pulse_range[1])),\n",
    "                                      range=[[-50, 200000], pulse_range])\n",
    "\n",
    "do_2d_plot(hist, (bins_x, bins_y), \"Signal (ADU)\", \"Pulse id\")\n",
    "do_3d_plot(hist, (bins_x, bins_y), \"Signal (ADU)\", \"Pulse id\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Baseline shift ###\n",
    "\n",
    "Estimated base-line shift with respect to the total ADU counts of corrected image."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize=(20, 10))\n",
    "ax = fig.add_subplot(111)\n",
    "h = ax.hist(blshift.flatten(), bins=100, log=True)\n",
    "_ = plt.xlabel('Baseline shift [ADU]')\n",
    "_ = plt.ylabel('Counts')\n",
    "_ = ax.grid()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize=(10, 10))\n",
    "corrected_ave = np.nansum(corrected, axis=(2, 3))\n",
    "plt.scatter(corrected_ave.flatten()/10**6, blshift.flatten(), s=0.9)\n",
    "plt.xlim(-1, 1000)\n",
    "plt.grid()\n",
    "plt.xlabel('Illuminated corrected [MADU] ')\n",
    "_ = plt.ylabel('Estimated baseline shift [ADU]')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if cell_id_preview not in cellId[:, 0]:\n",
    "    print(f\"WARNING: The selected cell_id_preview value {cell_id_preview} is not available in the corrected data.\")\n",
    "    cell_id_preview = cellId[:, 0][0]\n",
    "    cell_idx_preview = 0\n",
    "    print(f\"Previewing the first available cellId: {cell_id_preview}.\")\n",
    "else:\n",
    "    cell_idx_preview = np.where(cellId[:, 0] == cell_id_preview)[0][0] "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "display(Markdown('### Raw preview ###\\n'))\n",
    "if cellId.shape[0] != 1:\n",
    "    display(Markdown(f'Mean over images of the RAW data\\n'))\n",
    "    fig = plt.figure(figsize=(20, 10))\n",
    "    ax = fig.add_subplot(111)\n",
    "    data = np.mean(raw[slice(*cell_sel.crange), 0, ...], axis=0)\n",
    "    vmin, vmax = get_range(data, 5)\n",
    "    ax = geom.plot_data_fast(data, ax=ax, cmap=\"jet\", vmin=vmin, vmax=vmax)\n",
    "else:\n",
    "    print(\"Skipping mean RAW preview for single memory cell, \"\n",
    "          f\"see single shot image for selected cell ID {cell_id_preview}.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "display(Markdown(f'Single shot of the RAW data from cell {cell_id_preview} \\n'))\n",
    "fig = plt.figure(figsize=(20, 10))\n",
    "ax = fig.add_subplot(111)\n",
    "vmin, vmax = get_range(raw[cell_idx_preview, 0, ...], 5)\n",
    "ax = geom.plot_data_fast(raw[cell_idx_preview, 0, ...], ax=ax, cmap=\"jet\", vmin=vmin, vmax=vmax)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "display(Markdown('### Corrected preview ###\\n'))\n",
    "if cellId.shape[0] != 1:\n",
    "    display(Markdown('### Mean CORRECTED Preview ###\\n'))\n",
    "    display(Markdown(f'A mean across train: {tid}\\n'))\n",
    "    fig = plt.figure(figsize=(20, 10))\n",
    "    ax = fig.add_subplot(111)\n",
    "    data = np.mean(corrected, axis=0)\n",
    "    vmin, vmax = get_range(data, 7)\n",
    "    ax = geom.plot_data_fast(data, ax=ax, cmap=\"jet\", vmin=-50, vmax=vmax)\n",
    "else:\n",
    "    print(\"Skipping mean CORRECTED preview for single memory cell, \"\n",
    "          f\"see single shot image for selected cell ID {cell_id_preview}.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "display(Markdown(f'A single shot of the CORRECTED image from cell {cell_id_preview} \\n'))\n",
    "fig = plt.figure(figsize=(20, 10))\n",
    "ax = fig.add_subplot(111)\n",
    "vmin, vmax = get_range(corrected[cell_idx_preview], 7, -50)\n",
    "vmin = - 50\n",
    "ax = geom.plot_data_fast(corrected[cell_idx_preview], ax=ax, cmap=\"jet\", vmin=vmin, vmax=vmax)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize=(20, 10))\n",
    "ax = fig.add_subplot(111)\n",
    "vmin, vmax = get_range(corrected[cell_idx_preview], 5, -50)\n",
    "nbins = np.int((vmax + 50) / 2)\n",
    "h = ax.hist(corrected[cell_idx_preview].flatten(),\n",
    "            bins=nbins, range=(-50, vmax),\n",
    "            histtype='stepfilled', log=True)\n",
    "plt.xlabel('[ADU]')\n",
    "plt.ylabel('Counts')\n",
    "ax.grid()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize=(20, 10))\n",
    "ax = fig.add_subplot(111)\n",
    "vmin, vmax = get_range(corrected, 10, -100)\n",
    "vmax = np.nanmax(corrected)\n",
    "if vmax > 50000:\n",
    "    vmax=50000\n",
    "nbins = np.int((vmax + 100) / 5)\n",
    "h = ax.hist(corrected.flatten(), bins=nbins,\n",
    "            range=(-100, vmax), histtype='step', log=True, label = 'All')\n",
    "ax.hist(corrected[gains == 0].flatten(), bins=nbins, range=(-100, vmax),\n",
    "        alpha=0.5, log=True, label='High gain', color='green')\n",
    "ax.hist(corrected[gains == 1].flatten(), bins=nbins, range=(-100, vmax),\n",
    "        alpha=0.5, log=True, label='Medium gain', color='red')\n",
    "ax.hist(corrected[gains == 2].flatten(), bins=nbins, range=(-100, vmax),\n",
    "        alpha=0.5, log=True, label='Low gain', color='yellow')\n",
    "ax.legend()\n",
    "ax.grid()\n",
    "plt.xlabel('[ADU]')\n",
    "plt.ylabel('Counts')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "display(Markdown('### Maximum GAIN Preview ###\\n'))\n",
    "display(Markdown(f'The per pixel maximum across one train for the digitized gain'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize=(20, 10))\n",
    "ax = fig.add_subplot(111)\n",
    "ax = geom.plot_data_fast(np.max(gains, axis=0), ax=ax,\n",
    "                         cmap=\"jet\", vmin=-1, vmax=3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Bad Pixels ##\n",
    "The mask contains dedicated entries for all pixels and memory cells as well as all three gains stages. Each mask entry is encoded in 32 bits as:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "table = []\n",
    "for item in BadPixels:\n",
    "    table.append((item.name, \"{:016b}\".format(item.value)))\n",
    "md = display(Latex(tabulate.tabulate(table, tablefmt='latex',\n",
    "                                     headers=[\"Bad pixel type\", \"Bit mask\"])))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "display(Markdown(f'### Single Shot Bad Pixels ### \\n'))\n",
    "display(Markdown(f'A single shot bad pixel map from cell {cell_id_preview} \\n'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize=(20, 10))\n",
    "ax = fig.add_subplot(111)\n",
    "geom.plot_data_fast(np.log2(mask[cell_idx_preview]), ax=ax, vmin=0, vmax=32, cmap=\"jet\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if round_photons:\n",
    "    display(Markdown('### Photonization histograms ###'))\n",
    "    \n",
    "    x_preround = (agipd_corr.hist_bins_preround[1:] + agipd_corr.hist_bins_preround[:-1]) / 2\n",
    "    x_postround = (agipd_corr.hist_bins_postround[1:] + agipd_corr.hist_bins_postround[:-1]) / 2\n",
    "    x_photons = np.arange(0, (x_postround[-1] + 1) / photon_energy)\n",
    "\n",
    "    fig, ax = plt.subplots(ncols=1, nrows=1, clear=True)\n",
    "    ax.plot(x_preround, agipd_corr.shared_hist_preround, '.-', color='C0')\n",
    "    ax.bar(x_postround, agipd_corr.shared_hist_postround, photon_energy, color='C1', alpha=0.5)\n",
    "    ax.set_yscale('log')\n",
    "    ax.set_ylim(0, max(agipd_corr.shared_hist_preround.max(), agipd_corr.shared_hist_postround.max())*3)\n",
    "    ax.set_xlim(x_postround[0], x_postround[-1]+1)\n",
    "    ax.set_xlabel('Photon energy / keV')\n",
    "    ax.set_ylabel('Intensity')\n",
    "    ax.vlines(x_photons * photon_energy, *ax.get_ylim(), color='k', linestyle='dashed')\n",
    "\n",
    "    phx = ax.twiny()\n",
    "    phx.set_xlim(x_postround[0] / photon_energy, (x_postround[-1]+1)/photon_energy)\n",
    "    phx.set_xticks(x_photons)\n",
    "    phx.set_xlabel('# Photons')\n",
    "    pass"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Percentage of Bad Pixels across one train  ###"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize=(20, 10))\n",
    "ax = fig.add_subplot(111)\n",
    "geom.plot_data_fast(np.mean(mask>0, axis=0), vmin=0, ax=ax, vmax=1, cmap=\"jet\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Percentage of Bad Pixels across one train. Only Dark Related ###"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize=(20, 10))\n",
    "ax = fig.add_subplot(111)\n",
    "cm = np.copy(mask)\n",
    "cm[cm > BadPixels.NO_DARK_DATA.value] = 0\n",
    "ax = geom.plot_data_fast(np.mean(cm>0, axis=0),\n",
    "                         vmin=0, ax=ax, vmax=1, cmap=\"jet\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}