{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# AGIPD Offline Correction #\n",
    "\n",
    "Author: European XFEL Detector Group, Version: 2.0\n",
    "\n",
    "Offline Calibration for the AGIPD Detector"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "in_folder = \"/gpfs/exfel/exp/MID/202201/p002834/raw\" # the folder to read data from, required\n",
    "out_folder = \"/gpfs/exfel/data/scratch/esobolev/pycal_litfrm/p002834/r0225\"  # the folder to output to, required\n",
    "metadata_folder = \"\"  # Directory containing calibration_metadata.yml when run by xfel-calibrate\n",
    "sequences = [-1] # sequences to correct, set to -1 for all, range allowed\n",
    "overwrite = False  # IGNORED, NEEDED FOR COMPATIBILITY.\n",
    "modules = [-1] # modules to correct, set to -1 for all, range allowed\n",
    "train_ids = [-1] # train IDs to correct, set to -1 for all, range allowed\n",
    "run = 225 # runs to process, required\n",
    "\n",
    "karabo_id = \"MID_DET_AGIPD1M-1\" # karabo karabo_id\n",
    "karabo_da = ['-1']  # a list of data aggregators names, Default [-1] for selecting all data aggregators\n",
    "receiver_template = \"{}CH0\" # template for receiver device names, formatted with the module index\n",
    "path_template = 'RAW-R{:04d}-{}-S{:05d}.h5' # the template to use to access data\n",
    "instrument_source_template = '{}/DET/{}:xtdf'  # path in the HDF5 file to images\n",
    "index_source_template = 'INDEX/{}/DET/{}:xtdf/'  # path in the HDF5 file to the image index\n",
    "ctrl_source_template = '{}/MDL/FPGA_COMP'  # path to control information\n",
    "karabo_id_control = \"MID_EXP_AGIPD1M1\" # karabo-id for control device\n",
    "\n",
    "slopes_ff_from_files = \"\" # Path to locally stored SlopesFF and BadPixelsFF constants, loaded in precorrection notebook\n",
    "\n",
    "creation_time = \"\"  # To overwrite the measured creation_time. Required Format: YYYY-MM-DD HR:MN:SC e.g. \"2022-06-28 13:00:00\"\n",
    "cal_db_interface = \"tcp://max-exfl-cal001:8015#8045\" # the database interface to use\n",
    "cal_db_timeout = 30000 # in milliseconds\n",
    "creation_date_offset = \"00:00:00\" # add an offset to creation date, e.g. to get different constants\n",
    "cal_db_root = \"\"  # The calibration database root path to access constant files. e.g. accessing constants from the test database /gpfs/exfel/d/cal_tst/caldb_store.\n",
    "\n",
    "mem_cells = -1  # Number of memory cells used, set to -1 to automatically infer\n",
    "bias_voltage = -1  # bias voltage, set to -1 to use stored value in slow data.\n",
    "acq_rate = -1. # the detector acquisition rate, use -1 to try to auto-determine\n",
    "gain_setting = -1  # the gain setting, use -1 to use value stored in slow data.\n",
    "gain_mode = -1  # gain mode (0: adaptive, 1-3 fixed high/med/low, -1: read from CONTROL data)\n",
    "max_pulses = [0, 352, 1] # range list [st, end, step] of memory cell indices to be processed within a train. 3 allowed maximum list input elements.\n",
    "mem_cells_db = -1  # set to a value different than -1 to use this value for DB queries\n",
    "integration_time = -1 # integration time, set to -1 for auto-detection.\n",
    "\n",
    "# Correction parameters\n",
    "blc_noise_threshold = 5000 # above this mean signal intensity no baseline correction via noise is attempted\n",
    "cm_dark_fraction = 0.66 # threshold for the fraction of empty pixels to consider a module dark enough to perform CM correction\n",
    "cm_dark_range = [-50.,30] # range of signal values (ADU) for a pixel to be considered a dark pixel\n",
    "cm_n_itr = 4 # number of iterations for common mode correction\n",
    "hg_hard_threshold = 1000 # threshold to force medium gain offset subtracted pixel to high gain\n",
    "mg_hard_threshold = 1000 # threshold to force medium gain offset subtracted pixel from low to medium gain\n",
    "noisy_adc_threshold = 0.25 # threshold to mask complete adc\n",
    "ff_gain = 7.2 # conversion gain for absolute FlatField constants, while applying xray_gain\n",
    "photon_energy = -1.0 # photon energy in keV, non-positive value for XGM autodetection\n",
    "rounding_threshold = 0.5 # the fractional threshold for rounding, must be in (0, 1); 0.5 for standard rounding rule\n",
    "cs_mg_adjust = 7e3  # Value to adjust medium gain when correcting with current source. This is used when `adjust_mg_baseline` is True.\n",
    "\n",
    "# Correction Booleans\n",
    "only_offset = False # Apply only Offset correction. if False, Offset is applied by Default. if True, Offset is only applied.\n",
    "# TODO: Remove this boolean parameter and replace it with rel_gain_mode.\n",
    "rel_gain = False  # if True and rel_gain_mode is `off`, apply PC relative gain correction (kept for compatibility)\n",
    "rel_gain_mode = \"off\"  # Select relative gain correction. Choices [`PC`, `CS`, `off`]. (`PC`: Pulse Capacitor, `CS`: Current Source, `off`: Disable relative gain correction). Default: off.\n",
    "xray_gain = False # do relative gain correction based on xray data\n",
    "blc_noise = False # if set, baseline correction via noise peak location is attempted\n",
    "blc_stripes = False # if set, baseline corrected via stripes\n",
    "blc_hmatch = False # if set, base line correction via histogram matching is attempted\n",
    "match_asics = False # if set, inner ASIC borders are matched to the same signal level\n",
    "adjust_mg_baseline = False # adjust medium gain baseline to match highest high gain value\n",
    "zero_nans = False # set NaN values in corrected data to 0\n",
    "zero_orange = False # set to 0 very negative and very large values in corrected data\n",
    "blc_set_min = False # Shift to 0 negative medium gain pixels after offset corr\n",
    "corr_asic_diag = False # if set, diagonal drop offs on ASICs are corrected\n",
    "force_hg_if_below = False # set high gain if mg offset subtracted value is below hg_hard_threshold\n",
    "force_mg_if_below = False # set medium gain if mg offset subtracted value is below mg_hard_threshold\n",
    "mask_noisy_adc = False # Mask entire ADC if it is noisy above a relative threshold\n",
    "common_mode = False # Common mode correction\n",
    "melt_snow = False # Identify (and optionally interpolate) 'snowy' pixels\n",
    "mask_zero_std = False # Mask pixels with zero standard deviation across train\n",
    "low_medium_gap = False # 5 sigma separation in thresholding between low and medium gain\n",
    "round_photons = False  # Round to absolute number of photons, only use with gain corrections\n",
    "\n",
    "# Additional processing\n",
    "count_lit_pixels = False # Count the number of pixels registering photons\n",
    "spi_hitfinding = False  # Find hits using lit-pixel counter\n",
    "\n",
    "# Optional auxiliary devices\n",
    "use_ppu_device = ''  # Device ID for a pulse picker device to only process picked trains, empty string to disable\n",
    "ppu_train_offset = 0  # When using the pulse picker, offset between the PPU's sequence start and actually picked train\n",
    "require_ppu_trigger = False  # Optional protection against running without PPU or without triggering trains.\n",
    "\n",
    "use_litframe_finder = 'off' # Process only illuminated frames: 'off' - disable (default), 'online' - use online device data, 'offline' - use offline algorithm, 'auto' - choose online/offline source automatically\n",
    "litframe_device_id = '' # Device ID for a lit frame finder device, empty string for auto detection\n",
    "energy_threshold = -1000 # The low limit for the energy (uJ) exposed by frames subject to processing. If -1000, selection by pulse energy is disabled\n",
    "use_super_selection = 'cm' # Make a common selection for entire run: 'off' - disable, 'final' - enable for final selection, 'cm' - enable only for common mode correction\n",
    "\n",
    "use_xgm_device = ''  # DoocsXGM device ID to obtain actual photon energy, empty string to disable\n",
    "\n",
    "# Output parameters\n",
    "recast_image_data = ''  # Cast data to a different dtype before saving\n",
    "compress_fields = ['gain', 'mask']  # Datasets in image group to compress.\n",
    "\n",
    "# SPI hit-finder parameters\n",
    "spi_hf_modules = [3, 4, 8, 15]  # Use specified modules for SPI hitfinding\n",
    "spi_hf_mode = \"adaptive\"  # The method to compute threshold for hitscores in SPI hitfinding: `fixed` or `adaptive`\n",
    "spi_hf_snr = 4.0  # Signal-to-noise ratio for adaptive threshold in SPI hitfinding\n",
    "spi_hf_min_scores = 100  # The minimal size of events to compute adaptive threshold in SPI hitfinding\n",
    "spi_hf_fixed_threshold = 0  # The fixed threshold value\n",
    "spi_hf_hitrate_window_size = 200  # The window size for running average of hitrate in trains\n",
    "spi_hf_miss_fraction = 1  # The fraction of misses to select along with hits\n",
    "spi_hf_miss_fraction_base = \"hit\"  # The base to compute the number of misses to select: the number of hits (`hit`) or misses (`miss`)\n",
    "\n",
    "# Plotting parameters\n",
    "skip_plots = False # exit after writing corrected files and metadata\n",
    "cell_id_preview = 1 # cell Id used for preview in single-shot plots\n",
    "cmap = \"viridis\"  # matplotlib colormap for almost all heatmaps. Other options ['plasma', 'inferno', 'magma', 'cividis', 'jet', ...]\n",
    "\n",
    "# Parallelization parameters\n",
    "chunk_size = 1000  # Size of chunk for image-wise correction\n",
    "n_cores_correct = 16 # Number of chunks to be processed in parallel\n",
    "n_cores_files = 4 # Number of files to be processed in parallel\n",
    "sequences_per_node = 2 # number of sequence files per cluster node if run as SLURM job, set to 0 to not run SLURM parallel\n",
    "max_nodes = 8 # Maximum number of SLURM jobs to split correction work into\n",
    "max_tasks_per_worker = 1  # the number of tasks a correction pool worker process can complete before it will exit and be replaced with a fresh worker process. Leave as -1 to keep worker alive as long as pool.\n",
    "\n",
    "\n",
    "def balance_sequences(in_folder, run, sequences, sequences_per_node, karabo_da, max_nodes):\n",
    "    from xfel_calibrate.calibrate import balance_sequences as bs\n",
    "    return bs(in_folder, run, sequences, sequences_per_node, karabo_da, max_nodes=max_nodes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import itertools\n",
    "import math\n",
    "import multiprocessing\n",
    "import os\n",
    "import sys\n",
    "import warnings\n",
    "from datetime import timedelta\n",
    "from logging import warning\n",
    "from pathlib import Path\n",
    "\n",
    "import tabulate\n",
    "from dateutil import parser\n",
    "from IPython.display import Latex, Markdown, display\n",
    "\n",
    "warnings.filterwarnings('ignore')\n",
    "import h5py\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "import yaml\n",
    "from extra_data import by_id, RunDirectory, stack_detector_data\n",
    "from extra_geom import AGIPD_1MGeometry, AGIPD_500K2GGeometry\n",
    "from matplotlib.colors import LogNorm\n",
    "\n",
    "matplotlib.use(\"agg\")\n",
    "%matplotlib inline\n",
    "import numpy as np\n",
    "import seaborn as sns\n",
    "\n",
    "sns.set()\n",
    "sns.set_context(\"paper\", font_scale=1.4)\n",
    "sns.set_style(\"ticks\")\n",
    "\n",
    "import cal_tools.restful_config as rest_cfg\n",
    "from cal_tools import agipdalgs as calgs\n",
    "from cal_tools.agipdlib import (\n",
    "    AgipdCorrections,\n",
    "    AgipdCtrl,\n",
    "    CellRange,\n",
    "    LitFrameSelection,\n",
    ")\n",
    "from cal_tools.ana_tools import get_range\n",
    "from cal_tools.calcat_interface import (\n",
    "    AGIPD_CalibrationData,\n",
    "    CalCatError,\n",
    ")\n",
    "from cal_tools.enums import AgipdGainMode, BadPixels\n",
    "from cal_tools.plotting import agipd_single_module_geometry\n",
    "from cal_tools.step_timing import StepTimer\n",
    "from cal_tools.tools import (\n",
    "    calcat_creation_time,\n",
    "    latex_warning,\n",
    "    map_modules_from_folder,\n",
    "    module_index_to_qm,\n",
    "    write_constants_fragment,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Normalise the folder parameters to Path objects and locate the raw run folder.\n",
    "in_folder, out_folder = Path(in_folder), Path(out_folder)\n",
    "run_folder = in_folder / f\"r{run:04d}\"\n",
    "\n",
    "# Timer used to report the duration of individual processing steps.\n",
    "step_timer = StepTimer()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluated parameters ##"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fill dictionaries comprising bools and arguments for correction and data analysis\n",
    "\n",
    "# Here the hierarchy and dependability for correction booleans are defined\n",
    "corr_bools = {}\n",
    "\n",
    "# Resolve the requested relative gain correction; the mode is case-insensitive.\n",
    "rel_gain_mode_lc = rel_gain_mode.lower()\n",
    "if rel_gain_mode_lc not in (\"off\", \"cs\", \"pc\"):\n",
    "    raise ValueError(\n",
    "        f\"Selected `rel_gain_mode` ({rel_gain_mode!r}) is unexpected. \"\n",
    "        \"Please select between CS, PC or off.\")\n",
    "\n",
    "# CS: current source, PC: pulse capacitor relative gain correction.\n",
    "cs_corr = rel_gain_mode_lc == \"cs\"\n",
    "# TODO: Remove the `rel_gain` fallback after replacing rel_gain with rel_gain_mode.\n",
    "pc_corr = rel_gain_mode_lc == \"pc\" or (rel_gain_mode_lc == \"off\" and bool(rel_gain))\n",
    "\n",
    "# offset is at the bottom of AGIPD correction pyramid.\n",
    "corr_bools[\"only_offset\"] = only_offset\n",
    "\n",
    "# Don't apply any other corrections if only_offset is requested.\n",
    "if not only_offset:\n",
    "    corr_bools.update({\n",
    "        \"cs_corr\": cs_corr,\n",
    "        \"pc_corr\": pc_corr,\n",
    "        \"adjust_mg_baseline\": adjust_mg_baseline,\n",
    "        \"xray_corr\": xray_gain,\n",
    "        \"blc_noise\": blc_noise,\n",
    "        \"blc_stripes\": blc_stripes,\n",
    "        \"blc_hmatch\": blc_hmatch,\n",
    "        \"blc_set_min\": blc_set_min,\n",
    "        \"match_asics\": match_asics,\n",
    "        \"corr_asic_diag\": corr_asic_diag,\n",
    "        \"zero_nans\": zero_nans,\n",
    "        \"zero_orange\": zero_orange,\n",
    "        \"mask_noisy_adc\": mask_noisy_adc,\n",
    "        \"force_hg_if_below\": force_hg_if_below,\n",
    "        \"force_mg_if_below\": force_mg_if_below,\n",
    "        \"common_mode\": common_mode,\n",
    "        \"melt_snow\": melt_snow,\n",
    "        \"mask_zero_std\": mask_zero_std,\n",
    "        \"low_medium_gap\": low_medium_gap,\n",
    "        \"round_photons\": round_photons,\n",
    "        \"count_lit_pixels\": count_lit_pixels,\n",
    "    })\n",
    "\n",
    "# Many corrections don't apply to fixed gain mode; will explicitly disable later if detected\n",
    "disable_for_fixed_gain = [\n",
    "    \"adjust_mg_baseline\",\n",
    "    \"blc_set_min\",\n",
    "    \"force_hg_if_below\",\n",
    "    \"force_mg_if_below\",\n",
    "    \"low_medium_gap\",\n",
    "    \"melt_snow\",\n",
    "    \"pc_corr\",\n",
    "    \"cs_corr\",\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Normalise the `[-1]` sentinel (\"all sequences\") to None before it is passed on.\n",
    "sequences = None if sequences == [-1] else sequences\n",
    "\n",
    "dc = RunDirectory(run_folder)\n",
    "\n",
    "# Build the source names from the configured templates. `instrument_src` and\n",
    "# `index_src` keep the placeholder from `receiver_template`, which is filled\n",
    "# with a module index later.\n",
    "ctrl_src = ctrl_source_template.format(karabo_id_control)\n",
    "instrument_src, index_src = (\n",
    "    template.format(karabo_id, receiver_template)\n",
    "    for template in (instrument_source_template, index_source_template)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create output folder\n",
    "out_folder.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "# Evaluate detector instance for mapping\n",
    "instrument = karabo_id.split(\"_\")[0]\n",
    "# Number of modules by detector type (first matching tag wins), 1 otherwise.\n",
    "nmods = next(\n",
    "    (n for tag, n in ((\"AGIPD1M\", 16), (\"AGIPD500K\", 8)) if tag in karabo_id),\n",
    "    1,\n",
    ")\n",
    "\n",
    "# Evaluate requested modules\n",
    "if karabo_da[0] != '-1':\n",
    "    # TODO: fix this with the new CALCAT metadata for module indices.\n",
    "    modules = [int(da[-2:]) for da in karabo_da]\n",
    "else:\n",
    "    if modules[0] == -1:\n",
    "        modules = list(range(nmods))\n",
    "    karabo_da = [f\"AGIPD{mod:02d}\" for mod in modules]\n",
    "# Single-module detectors always use module index 0.\n",
    "mod_indices = modules if nmods > 1 else [0]\n",
    "\n",
    "print(\"Process modules:\", ', '.join(module_index_to_qm(x) for x in mod_indices))\n",
    "print(f\"Detector in use is {karabo_id}\")\n",
    "print(f\"Instrument {instrument}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Make sure at least one selected module has trains with image data.\n",
    "train_available = False\n",
    "\n",
    "for m in modules:\n",
    "    try:\n",
    "        n_trains = len(\n",
    "            dc[instrument_src.format(m), 'image.data'].drop_empty_trains().train_ids)\n",
    "    except (KeyError, ValueError) as e:\n",
    "        # extra_data reports an unknown source/key with KeyError subclasses\n",
    "        # (SourceNameError / PropertyNameError), so a module missing from the\n",
    "        # run is warned about instead of aborting. ValueError might be raised\n",
    "        # if no trains are available.\n",
    "        warning(f\"Missing module {m} from data: {e}\")\n",
    "        continue\n",
    "    if n_trains > 0:\n",
    "        train_available = True\n",
    "        break  # Found a module with available trains.\n",
    "\n",
    "if not train_available:\n",
    "    # Execute this block if no modules with trains were found.\n",
    "    latex_warning(\"No trains available to correct for selected modules.\")\n",
    "    sys.exit(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Decide which trains to process: PPU-picked trains, manually given trains,\n",
    "# or all valid trains (train_ids = None).\n",
    "if not use_ppu_device:\n",
    "    if train_ids != [-1]:\n",
    "        # Specific trains passed by parameter, convert to ndarray.\n",
    "        train_ids = np.array(train_ids)\n",
    "        print(f'Processing up to {len(train_ids)} manually selected train(s)')\n",
    "    else:\n",
    "        # Neither a PPU nor explicit trains configured.\n",
    "        print('Processing all valid trains')\n",
    "        train_ids = None\n",
    "\n",
    "elif use_ppu_device not in dc.control_sources:\n",
    "    # PPU configured but not present in this run.\n",
    "    if require_ppu_trigger:\n",
    "        raise RuntimeError(f'PPU device {use_ppu_device} required but not found, aborting!')\n",
    "    print(f'PPU device {use_ppu_device} configured but not found, processing all valid trains')\n",
    "    train_ids = None\n",
    "\n",
    "else:\n",
    "    # Obtain trains to process from the pulse picker device.\n",
    "    seq_start = dc[use_ppu_device, 'trainTrigger.sequenceStart.value'].ndarray()\n",
    "\n",
    "    # The trains picked are the unique values of trainTrigger.sequenceStart\n",
    "    # minus the first (previous trigger before this run).\n",
    "    start_train_ids = np.unique(seq_start)[1:] + ppu_train_offset\n",
    "\n",
    "    train_ids = []\n",
    "    for train_id in start_train_ids:\n",
    "        # Number of consecutive trains picked starting at this train.\n",
    "        n_trains = dc[\n",
    "            use_ppu_device, 'trainTrigger.numberOfTrains'\n",
    "        ].select_trains(by_id[[train_id]]).ndarray()[0]\n",
    "        train_ids.extend(range(train_id, train_id + n_trains))\n",
    "\n",
    "    if not train_ids:\n",
    "        if require_ppu_trigger:\n",
    "            raise RuntimeError(f'PPU device {use_ppu_device} not triggered but required, aborting!')\n",
    "        print(f'PPU device {use_ppu_device} not triggered, processing all valid trains')\n",
    "        train_ids = None\n",
    "    else:\n",
    "        print(f'PPU device {use_ppu_device} triggered for {len(train_ids)} train(s)')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# set everything up filewise\n",
    "mapped_files, _, total_sequences, _, _ = map_modules_from_folder(\n",
    "    str(in_folder), run, path_template, karabo_da, sequences\n",
    ")\n",
    "\n",
    "# ToDo: Split table over pages\n",
    "print(f\"Processing a total of {total_sequences} sequence files in chunks of {n_cores_files}\")\n",
    "# Collect every sequence file plus one table row per file:\n",
    "# (running index, module name on its first row only, index within module, file).\n",
    "file_list = []\n",
    "table = []\n",
    "for mod_name, files in mapped_files.items():\n",
    "    for seq_no, fname in enumerate(list(files.queue)):\n",
    "        table.append((len(file_list), mod_name if seq_no == 0 else \"\", seq_no, fname))\n",
    "        file_list.append(fname)\n",
    "md = display(Latex(tabulate.tabulate(\n",
    "    table, tablefmt='latex', headers=[\"#\", \"module\", \"# module\", \"file\"])))\n",
    "# Order files by the last 10 characters of their names (the sequence suffix\n",
    "# with the default path_template).\n",
    "file_list = sorted(file_list, key=lambda name: name[-10:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read detector conditions from the image source of the lowest-numbered\n",
    "# selected module present in the run. Use exact source names built from\n",
    "# `instrument_src`: a substring match such as f\"{n}CH\" also matches other\n",
    "# modules (e.g. \"1CH\" is contained in \"11CH0\").\n",
    "first_mod_channel = sorted(modules)[0]\n",
    "present_mod_srcs = [\n",
    "    instrument_src.format(mod) for mod in sorted(modules)\n",
    "    if instrument_src.format(mod) in dc.all_sources\n",
    "]\n",
    "instrument_src_first_mod = (\n",
    "    present_mod_srcs[0] if present_mod_srcs\n",
    "    else instrument_src.format(first_mod_channel)\n",
    ")\n",
    "\n",
    "agipd_cond = AgipdCtrl(\n",
    "    run_dc=dc,\n",
    "    image_src=instrument_src_first_mod,\n",
    "    ctrl_src=ctrl_src,\n",
    "    raise_error=False,  # to be able to process very old data without gain_setting value\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run's creation time, shifted by the optional `creation_date_offset` (HH:MM:SS).\n",
    "creation_time = calcat_creation_time(in_folder, run, creation_time)\n",
    "offset = parser.parse(creation_date_offset)\n",
    "delta = timedelta(\n",
    "    hours=offset.hour, minutes=offset.minute, seconds=offset.second)\n",
    "creation_time = creation_time + delta\n",
    "print(f\"Creation time: {creation_time}\")\n",
    "\n",
    "# A value of -1 means the condition is read from the run via `agipd_cond`.\n",
    "acq_rate = agipd_cond.get_acq_rate() if acq_rate == -1. else acq_rate\n",
    "mem_cells = agipd_cond.get_num_cells() if mem_cells == -1 else mem_cells\n",
    "# TODO: look for alternative for passing creation_time\n",
    "gain_setting = (\n",
    "    agipd_cond.get_gain_setting(creation_time) if gain_setting == -1 else gain_setting\n",
    ")\n",
    "if bias_voltage == -1:\n",
    "    # Accept non-zero value for a bias voltage from any module.\n",
    "    for m in modules:\n",
    "        bias_voltage = agipd_cond.get_bias_voltage(karabo_id_control, module=m)\n",
    "        if bias_voltage != 0.:\n",
    "            break\n",
    "integration_time = (\n",
    "    agipd_cond.get_integration_time() if integration_time == -1 else integration_time\n",
    ")\n",
    "gain_mode = (\n",
    "    agipd_cond.get_gain_mode() if gain_mode == -1 else AgipdGainMode(gain_mode)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# -1 means: use the number of memory cells in use (`mem_cells`).\n",
    "if mem_cells_db == -1:\n",
    "    mem_cells_db = mem_cells\n",
    "\n",
    "print(f\"Maximum memory cells to calibrate: {mem_cells}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(f\"Using {creation_time} as creation time\")\n",
    "print(\"Operating conditions are:\")\n",
    "operating_conditions = [\n",
    "    (\"Bias voltage\", bias_voltage),\n",
    "    (\"Memory cells\", mem_cells_db),\n",
    "    (\"Acquisition rate\", acq_rate),\n",
    "    (\"Gain setting\", gain_setting),\n",
    "    (\"Gain mode\", gain_mode.name),\n",
    "    (\"Integration time\", integration_time),\n",
    "    # Matches the fixed `source_energy=9.2` used when retrieving constants.\n",
    "    (\"Photon Energy\", \"9.2\"),\n",
    "]\n",
    "for cond_name, cond_value in operating_conditions:\n",
    "    print(f\"• {cond_name}: {cond_value}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fixed gain mode (non-zero gain_mode): switch off corrections listed in\n",
    "# disable_for_fixed_gain, warning about each one that was requested.\n",
    "if gain_mode:\n",
    "    requested = [name for name in disable_for_fixed_gain if corr_bools.get(name, False)]\n",
    "    for to_disable in requested:\n",
    "        warning(f\"{to_disable} correction was requested, but does not apply to fixed gain mode\")\n",
    "        corr_bools[to_disable] = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if use_litframe_finder != 'off':\n",
    "    from extra_redu import make_litframe_finder, LitFrameFinderError\n",
    "\n",
    "    # 'device' is accepted as an alias of 'online' (read the online device data).\n",
    "    if use_litframe_finder not in ['auto', 'offline', 'online', 'device']:\n",
    "        raise ValueError(\"Unexpected value in 'use_litframe_finder'.\")\n",
    "\n",
    "    # Instrument prefix of the control device id, e.g. 'MID'.\n",
    "    inst = karabo_id_control[:3]\n",
    "    litfrm = make_litframe_finder(inst, dc, litframe_device_id)\n",
    "    try:\n",
    "        get_data = {\n",
    "            'auto': litfrm.read_or_process,\n",
    "            'offline': litfrm.process,\n",
    "            'online': litfrm.read,\n",
    "            'device': litfrm.read,\n",
    "        }\n",
    "        r = get_data[use_litframe_finder]()\n",
    "        cell_sel = LitFrameSelection(r, train_ids, max_pulses, energy_threshold, use_super_selection)\n",
    "        print(f\"{cell_sel.msg()}\\n\")\n",
    "        cell_sel.print_report()\n",
    "        if np.count_nonzero(r.nLitFrame.value) == 0:  # No lit frames.\n",
    "            latex_warning(\n",
    "                \"No lit frames identified using AGIPD LitFrameFinder.\"\n",
    "                \" Offline correction will not be performed.\")\n",
    "            sys.exit(0)\n",
    "    except LitFrameFinderError as err:\n",
    "        # Fall back to the plain cell range. `cell_sel` may not exist yet at\n",
    "        # this point, so its message is printed only after the fallback is built.\n",
    "        warning(f\"Cannot use AgipdLitFrameFinder due to:\\n{err}\")\n",
    "        cell_sel = CellRange(max_pulses, max_cells=mem_cells)\n",
    "        print(cell_sel.msg())\n",
    "else:\n",
    "    # Use range selection\n",
    "    cell_sel = CellRange(max_pulses, max_cells=mem_cells)\n",
    "    print(cell_sel.msg())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "photon_energy_warn = None\n",
    "if photon_energy <= 0.0:\n",
    "    if not use_xgm_device:\n",
    "        photon_energy_warn = 'Neither explicit photon energy nor XGM device configured'\n",
    "    elif use_xgm_device not in dc.all_sources:\n",
    "        # Warn instead of failing on the data lookup when the device is absent.\n",
    "        photon_energy_warn = f'XGM device {use_xgm_device} configured but not found'\n",
    "    else:\n",
    "        # Try to obtain photon energy from XGM device.\n",
    "        wavelength_data = dc[use_xgm_device, 'pulseEnergy.wavelengthUsed']\n",
    "\n",
    "        try:\n",
    "            from scipy.constants import h, c, e\n",
    "\n",
    "            # Read wavelength as a single value and convert to hv in keV.\n",
    "            # h*c/e is in eV*m, so the 1e-6 scaling assumes the wavelength is in nm.\n",
    "            photon_energy = (h * c / e) / (wavelength_data.as_single_value(rtol=1e-2) * 1e-6)\n",
    "            print(f'Obtained photon energy {photon_energy:.3f} keV from {use_xgm_device}')\n",
    "        except ValueError:\n",
    "            photon_energy_warn = 'XGM source available but photon energy varies greater than 1%'\n",
    "\n",
    "rounding_threshold_warn = None\n",
    "# `not 0 < x < 1` also rejects NaN.\n",
    "if not 0. < rounding_threshold < 1.:\n",
    "    rounding_threshold_warn = 'Round threshold is out of (0, 1) range. Use standard 0.5 value.'\n",
    "    rounding_threshold = 0.5\n",
    "\n",
    "if round_photons:\n",
    "    if photon_energy_warn:\n",
    "        warning(photon_energy_warn + ', photon rounding disabled!')\n",
    "        round_photons = False\n",
    "        # corr_bools was filled in an earlier cell and agipd_corr is built\n",
    "        # from it, so keep it in sync with the disabled rounding.\n",
    "        if corr_bools.get(\"round_photons\"):\n",
    "            corr_bools[\"round_photons\"] = False\n",
    "    else:\n",
    "        print(f'Photon energy for rounding: {photon_energy:.3f} keV')\n",
    "        if rounding_threshold_warn:\n",
    "            warning(rounding_threshold_warn)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if not count_lit_pixels:\n",
    "    # dummy value, that is not expected to be used\n",
    "    litpx_threshold = 42\n",
    "else:\n",
    "    # Base threshold: one photon, or the photon energy in keV.\n",
    "    if round_photons:\n",
    "        litpx_threshold, data_units = 1., 'photon'\n",
    "    else:\n",
    "        if photon_energy_warn:\n",
    "            warning(photon_energy_warn + '. Use 12 keV for lit-pixel counter threshold')\n",
    "            litpx_threshold = 12.\n",
    "        else:\n",
    "            litpx_threshold = photon_energy\n",
    "        data_units = 'keV'\n",
    "        if rounding_threshold_warn:\n",
    "            warning(rounding_threshold_warn)\n",
    "\n",
    "    if not xray_gain:\n",
    "        # convert photon energy to ADU (for AGIPD approx. 1 keV = 7 ADU)\n",
    "        # it looks that rounding to photons can be applied to data in ADU as well\n",
    "        litpx_threshold, data_units = litpx_threshold * 7., 'ADU'\n",
    "\n",
    "    # A pixel counts as lit above this fraction of the base threshold.\n",
    "    litpx_threshold = litpx_threshold * rounding_threshold\n",
    "    print(f\"Count lit-pixels with signal above {litpx_threshold:.3g} {data_units}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "agipd_corr = AgipdCorrections(\n",
    "    mem_cells,\n",
    "    cell_sel,\n",
    "    h5_data_path=instrument_src,\n",
    "    h5_index_path=index_src,\n",
    "    corr_bools=corr_bools,\n",
    "    gain_mode=gain_mode,\n",
    "    # Split the CPUs between the files processed in parallel. os.cpu_count()\n",
    "    # may return None and the division yields 0 when n_cores_files exceeds\n",
    "    # the CPU count, so always request at least one thread.\n",
    "    comp_threads=max(1, (os.cpu_count() or 1) // max(1, n_cores_files)),\n",
    "    train_ids=train_ids\n",
    ")\n",
    "\n",
    "# NOTE(review): the noise baseline threshold is passed negated; confirm the\n",
    "# sign convention expected by AgipdCorrections.\n",
    "agipd_corr.baseline_corr_noise_threshold = -blc_noise_threshold\n",
    "agipd_corr.hg_hard_threshold = hg_hard_threshold\n",
    "agipd_corr.mg_hard_threshold = mg_hard_threshold\n",
    "\n",
    "# Common mode correction: dark-pixel signal range, dark fraction and iterations.\n",
    "agipd_corr.cm_dark_min = cm_dark_range[0]\n",
    "agipd_corr.cm_dark_max = cm_dark_range[1]\n",
    "agipd_corr.cm_dark_fraction = cm_dark_fraction\n",
    "agipd_corr.cm_n_itr = cm_n_itr\n",
    "agipd_corr.noisy_adc_threshold = noisy_adc_threshold\n",
    "agipd_corr.ff_gain = ff_gain\n",
    "agipd_corr.photon_energy = photon_energy\n",
    "agipd_corr.rounding_threshold = rounding_threshold\n",
    "agipd_corr.cs_mg_adjust = cs_mg_adjust\n",
    "agipd_corr.litpx_threshold = litpx_threshold\n",
    "\n",
    "agipd_corr.compress_fields = compress_fields\n",
    "if recast_image_data:\n",
    "    agipd_corr.recast_image_fields['data'] = np.dtype(recast_image_data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Retrieving constants"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_constants_and_update_metadata(cal_data, main_metadata, constants):\n",
    "    \"\"\"Retrieve metadata for `constants` and merge it into `main_metadata`.\n",
    "\n",
    "    Every top-level entry returned by `cal_data.metadata(constants)` is merged\n",
    "    (dict.update) into the same key of `main_metadata`, which is created as an\n",
    "    empty dict if missing. A CalCatError is reported as a warning instead of\n",
    "    being raised.\n",
    "    \"\"\"\n",
    "    try:\n",
    "        retrieved = cal_data.metadata(constants)\n",
    "        for entry_key, entry_value in retrieved.items():\n",
    "            main_metadata.setdefault(entry_key, {}).update(entry_value)\n",
    "    except CalCatError as e:  # TODO: replace when API errors are improved.\n",
    "        warning(f\"CalCatError: {e}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "step_timer.start()\n",
    "# Instantiate agipd_cal with the read operating conditions.\n",
    "agipd_cal = AGIPD_CalibrationData(\n",
    "    detector_name=karabo_id,\n",
    "    modules=karabo_da,\n",
    "    sensor_bias_voltage=bias_voltage,\n",
    "    # mem_cells_db equals mem_cells unless explicitly overridden for DB queries.\n",
    "    memory_cells=mem_cells_db,\n",
    "    acquisition_rate=acq_rate,\n",
    "    integration_time=integration_time,\n",
    "    # Fixed source energy for constant retrieval, independent of `photon_energy`.\n",
    "    source_energy=9.2,\n",
    "    gain_mode=gain_mode,\n",
    "    gain_setting=gain_setting,\n",
    "    event_at=creation_time,\n",
    "    client=rest_cfg.calibration_client(),\n",
    "    caldb_root=Path(cal_db_root) if cal_db_root else None,\n",
    ")\n",
    "\n",
    "# Prepare lists of expected calibrations\n",
    "dark_constants = [\"Offset\", \"Noise\", \"BadPixelsDark\"]\n",
    "if not gain_mode:  # Adaptive gain\n",
    "    dark_constants.append(\"ThresholdsDark\")\n",
    "\n",
    "agipd_metadata = agipd_cal.metadata(dark_constants)\n",
    "\n",
    "agipd_cal.gain_mode = None  # gain_mode is not used for gain constants\n",
    "pc_constants, ff_constants, cs_constants = [], [], []\n",
    "\n",
    "if agipd_corr.corr_bools.get('xray_corr'):\n",
    "    ff_constants = list(agipd_cal.illuminated_calibrations)\n",
    "    get_constants_and_update_metadata(\n",
    "        agipd_cal, agipd_metadata, ff_constants)\n",
    "\n",
    "if any(agipd_corr.relgain_bools):\n",
    "    if cs_corr:\n",
    "        # Integration time is not used with CS\n",
    "        agipd_cal.integration_time = None\n",
    "        cs_constants = [\"SlopesCS\", \"BadPixelsCS\"]\n",
    "        get_constants_and_update_metadata(\n",
    "            agipd_cal, agipd_metadata, cs_constants)\n",
    "    else:  # rel_gain_mode == \"pc\" or \"off\"\n",
    "        pc_constants = [\"SlopesPC\", \"BadPixelsPC\"]\n",
    "        get_constants_and_update_metadata(\n",
    "            agipd_cal, agipd_metadata, pc_constants)\n",
    "\n",
    "step_timer.done_step(\"Constants were retrieved in\")\n",
    "\n",
    "relgain_alias = \"CS\" if cs_corr else \"PC\"\n",
    "print(\"Preparing constants (\"\n",
    "      f\"FF: {agipd_corr.corr_bools.get('xray_corr', False)}, \"\n",
    "      f\"{relgain_alias}: {any(agipd_corr.relgain_bools)}, \"\n",
    "      f\"BLC: {any(agipd_corr.blc_bools)})\")\n",
    "# Display retrieved calibration constants timestamps\n",
    "agipd_cal.display_markdown_retrieved_constants(metadata=agipd_metadata)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Validate constants availability and exclude modules that miss required\n",
    "# dark constants (Offset, plus ThresholdsDark in adaptive gain mode).\n",
    "for da, calibrations in agipd_metadata.items():\n",
    "    mod = modules[karabo_da.index(da)]\n",
    "    # Constants to error out for when missing.\n",
    "    error_missing_constants = {\"Offset\"}\n",
    "    if not gain_mode:  # Adaptive gain also needs the gain thresholds.\n",
    "        error_missing_constants |= {\"ThresholdsDark\"}\n",
    "\n",
    "    error_missing_constants -= set(calibrations)\n",
    "    if error_missing_constants:\n",
    "        # Report the constants that are actually missing; this may be\n",
    "        # ThresholdsDark rather than Offset.\n",
    "        warning(\n",
    "            f\"Constants {sorted(error_missing_constants)} are not available \"\n",
    "            f\"to correct {da}. Skipping this module.\")\n",
    "        # Remove module from files to process.\n",
    "        del mapped_files[module_index_to_qm(mod)]\n",
    "        karabo_da.remove(da)\n",
    "        modules.remove(mod)\n",
    "\n",
    "    # All other expected constants are optional: only warn about them.\n",
    "    warn_missing_constants = set(dark_constants + pc_constants + ff_constants + cs_constants)\n",
    "    warn_missing_constants -= error_missing_constants\n",
    "    warn_missing_constants -= set(calibrations)\n",
    "    if warn_missing_constants:\n",
    "        warning(f\"Constants {warn_missing_constants} were not retrieved for {da}.\")\n",
    "\n",
    "if not mapped_files:  # Required dark constants are missing for all modules.\n",
    "    raise Exception(\"Could not find offset constants for any modules, will not correct data.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Record constant details in YAML metadata, placed in metadata_folder\n",
    "# when it is set and in out_folder otherwise.\n",
    "fragment_folder = metadata_folder or out_folder\n",
    "write_constants_fragment(\n",
    "    out_folder=fragment_folder,\n",
    "    det_metadata=agipd_metadata,\n",
    "    caldb_root=agipd_cal.caldb_root,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load calibration constants to RAM\n",
    "agipd_corr.allocate_constants(modules, (3, mem_cells_db, 512, 128))\n",
    "\n",
    "def load_constants(da, module):\n",
    "    \"\"\"\n",
    "    Initialize constants data from previously retrieved metadata.\n",
    "\n",
    "    Reads every constant listed in `agipd_metadata[da]` from its file under\n",
    "    `agipd_cal.caldb_root` and passes the arrays, together with their\n",
    "    variants, to `agipd_corr.init_constants` for this module.\n",
    "\n",
    "    Args:\n",
    "        da (str): Data Aggregator (Karabo DA)\n",
    "        module (int): Module index\n",
    "\n",
    "    Returns:\n",
    "        None. The constants are handed over via `agipd_corr.init_constants`.\n",
    "    \"\"\"\n",
    "    const_data = dict()\n",
    "    variant = dict()\n",
    "    for cname, mdata in agipd_metadata[da].items():\n",
    "        dataset = mdata[\"dataset\"]\n",
    "        with h5py.File(agipd_cal.caldb_root / mdata[\"path\"], \"r\") as cf:  # noqa\n",
    "            const_data[cname] = np.copy(cf[f\"{dataset}/data\"])\n",
    "            # Look the attribute up by name: a dataset can carry other\n",
    "            # attributes without \"variant\". Missing variant means 0.\n",
    "            variant[cname] = cf[dataset].attrs.get(\"variant\", 0)  # noqa\n",
    "    agipd_corr.init_constants(const_data, module, variant)\n",
    "\n",
    "\n",
    "step_timer.start()\n",
    "# One worker per module; each worker loads the constants of one module.\n",
    "with multiprocessing.Pool(processes=len(modules)) as pool:\n",
    "    pool.starmap(load_constants, zip(karabo_da, modules))\n",
    "step_timer.done_step('Constants were loaded in ')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Store timestamps for Offset, SlopesPC/SlopesCS, and SlopesFF\n",
    "# in YAML file for time-summary table.\n",
    "timestamp_keys = ['Offset', f'Slopes{relgain_alias}', 'SlopesFF']\n",
    "timestamps = {}\n",
    "\n",
    "for mod, mod_mdata in agipd_metadata.items():\n",
    "    # Module number taken from the last two characters of the DA name.\n",
    "    modno = int(mod[-2:])\n",
    "    # Constants that were not retrieved are stored as \"NA\" so every\n",
    "    # module has the same set of keys.\n",
    "    timestamps[module_index_to_qm(modno)] = {\n",
    "        key: mod_mdata[key][\"begin_validity_at\"] if key in mod_mdata else \"NA\"\n",
    "        for key in timestamp_keys\n",
    "    }\n",
    "\n",
    "seq = sequences[0] if sequences else 0\n",
    "\n",
    "with open(f\"{out_folder}/retrieved_constants_s{seq}.yml\", \"w\") as fd:\n",
    "    yaml.safe_dump({\"time-summary\": {f\"S{seq}\": timestamps}}, fd)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data processing ##"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Allocate the image and histogram buffers used by the correction loop.\n",
    "# NOTE(review): 256 looks like an upper bound of trains per file - confirm.\n",
    "n_images_max = 256 * mem_cells\n",
    "data_shape = (n_images_max,) + (512, 128)\n",
    "agipd_corr.allocate_images(data_shape, n_cores_files)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def batches(l, batch_size):\n",
    "    \"\"\"Group a sequence into consecutive batches of (up to) batch_size elements.\n",
    "\n",
    "    Args:\n",
    "        l: Sliceable sequence, e.g. the list of files to process.\n",
    "        batch_size (int): Maximum batch length, must be >= 1.\n",
    "\n",
    "    Yields:\n",
    "        Consecutive slices of `l`; the last one may be shorter.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if batch_size < 1, which could never finish batching.\n",
    "    \"\"\"\n",
    "    if batch_size < 1:\n",
    "        raise ValueError(f\"batch_size must be a positive integer, got {batch_size}\")\n",
    "    for start in range(0, len(l), batch_size):\n",
    "        yield l[start:start + batch_size]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def imagewise_chunks(img_counts, max_chunk=None):\n",
    "    \"\"\"Break up the loaded data into chunks of up to chunk_size images.\n",
    "\n",
    "    Args:\n",
    "        img_counts: Number of loaded images for each file data slot.\n",
    "        max_chunk (int, optional): Maximum chunk length. Defaults to the\n",
    "            notebook parameter `chunk_size`.\n",
    "\n",
    "    Yields (file data slot, start index, stop index). The images of each\n",
    "    slot are split into the fewest chunks that respect the limit, with chunk\n",
    "    lengths differing by at most one. Slots without images yield nothing.\n",
    "    \"\"\"\n",
    "    size = chunk_size if max_chunk is None else max_chunk\n",
    "    for i_proc, n_img in enumerate(img_counts):\n",
    "        n_chunks = math.ceil(n_img / size)\n",
    "        for i in range(n_chunks):\n",
    "            yield i_proc, i * n_img // n_chunks, (i + 1) * n_img // n_chunks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "step_timer.start()\n",
    "all_imgs_counts = []\n",
    "if max_tasks_per_worker == -1:\n",
    "    max_tasks_per_worker = None\n",
    "with multiprocessing.Pool(maxtasksperchild=max_tasks_per_worker) as pool:\n",
    "    step_timer.done_step('Started pool')\n",
    "\n",
    "    for file_batch in batches(file_list, n_cores_files):\n",
    "        # TODO: Move some printed output to logging or similar\n",
    "        print(f\"Processing next {len(file_batch)} files\")\n",
    "        step_timer.start()\n",
    "        # One file data slot per file of the batch; img_counts holds the\n",
    "        # number of images loaded into each slot, in slot order.\n",
    "        img_counts = pool.starmap(\n",
    "            agipd_corr.read_file,\n",
    "            zip(range(len(file_batch)), file_batch, [not common_mode]*len(file_batch))\n",
    "        )\n",
    "        step_timer.done_step('Loading data from files')\n",
    "\n",
    "        all_imgs_counts += img_counts\n",
    "        if not np.any(img_counts):\n",
    "            # Skip any further processing and output if there are no images to\n",
    "            # correct in this batch of files.\n",
    "            continue\n",
    "\n",
    "        if mask_zero_std:\n",
    "            # Evaluate zero-data-std mask\n",
    "            pool.starmap(\n",
    "                agipd_corr.mask_zero_std, itertools.product(\n",
    "                    range(len(file_batch)),\n",
    "                    np.array_split(np.arange(agipd_corr.max_cells), n_cores_correct)\n",
    "                )\n",
    "            )\n",
    "            step_timer.done_step('Mask 0 std')\n",
    "\n",
    "        # Perform offset image-wise correction\n",
    "        pool.starmap(agipd_corr.offset_correction, imagewise_chunks(img_counts))\n",
    "        step_timer.done_step(\"Offset correction\")\n",
    "\n",
    "        if blc_noise or blc_stripes or blc_hmatch:\n",
    "            # Perform image-wise correction\n",
    "            pool.starmap(agipd_corr.baseline_correction, imagewise_chunks(img_counts))\n",
    "            step_timer.done_step(\"Base-line shift correction\")\n",
    "\n",
    "        if common_mode:\n",
    "            # Common-mode correction is enabled.\n",
    "            # Cell selection is only activated after common mode correction.\n",
    "            # Perform cross-file correction parallel over asics\n",
    "            image_files_idx = [i_proc for i_proc, n_img in enumerate(img_counts) if n_img > 0]\n",
    "            pool.starmap(agipd_corr.cm_correction, itertools.product(\n",
    "                image_files_idx, range(16)  # 16 ASICs per module\n",
    "            ))\n",
    "            step_timer.done_step(\"Common-mode correction\")\n",
    "\n",
    "            # apply_selected_pulses only runs on the slots that hold images.\n",
    "            # Write its results back at those slot positions: imagewise_chunks()\n",
    "            # uses the list position as the file data slot, so a shortened list\n",
    "            # would send the following corrections to the wrong slots.\n",
    "            selected_counts = pool.map(agipd_corr.apply_selected_pulses, image_files_idx)\n",
    "            for i_proc, n_img in zip(image_files_idx, selected_counts):\n",
    "                img_counts[i_proc] = n_img\n",
    "            step_timer.done_step(\"Applying selected cells after common mode correction\")\n",
    "\n",
    "        # Perform image-wise correction\n",
    "        pool.starmap(agipd_corr.gain_correction, imagewise_chunks(img_counts))\n",
    "        step_timer.done_step(\"Gain corrections\")\n",
    "\n",
    "        # Perform additional processing\n",
    "        if count_lit_pixels:\n",
    "            pool.starmap(agipd_corr.litpixel_counter, imagewise_chunks(img_counts))\n",
    "            step_timer.done_step(\"Lit-pixel counting\")\n",
    "\n",
    "        # Save corrected data\n",
    "        pool.starmap(agipd_corr.write_file, [\n",
    "            (i_proc, file_name, str(out_folder / Path(file_name).name.replace(\"RAW\", \"CORR\")))\n",
    "            for i_proc, file_name in enumerate(file_batch)\n",
    "        ])\n",
    "        step_timer.done_step(\"Save\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Summarise the processing time of the correction loop.\n",
    "n_files = len(file_list)\n",
    "print(f\"Correction of {n_files} files is finished\")\n",
    "print(f\"Total processing time {step_timer.timespan():.01f} s\")\n",
    "print(f\"Timing summary per batch of {n_cores_files} files:\")\n",
    "step_timer.print_summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stop the notebook here when plots are disabled or nothing was corrected.\n",
    "if skip_plots:\n",
    "    print(\"Skipping plots as configured.\")\n",
    "    sys.exit(0)\n",
    "\n",
    "if not np.any(all_imgs_counts):\n",
    "    latex_warning(\"All sequence files contain no data for correction.\")\n",
    "    sys.exit(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def do_2d_plot(data, edges, y_axis, x_axis, title=\"\"):\n",
    "    \"\"\"Show a 2D histogram as a log-scaled image with a colorbar.\n",
    "\n",
    "    Args:\n",
    "        data: 2D histogram counts; rows follow `edges[0]` (vertical axis),\n",
    "            columns follow `edges[1]` (horizontal axis).\n",
    "        edges: Pair of bin-edge arrays (vertical, horizontal).\n",
    "        y_axis, x_axis: Axis labels.\n",
    "        title: Axes title.\n",
    "    \"\"\"\n",
    "    fig, ax = plt.subplots(figsize=(10, 10))\n",
    "    extent = np.array([\n",
    "        np.nanmin(edges[1]), np.nanmax(edges[1]),\n",
    "        np.nanmin(edges[0]), np.nanmax(edges[0]),\n",
    "    ])\n",
    "    # Colour scale starts at 1 count; its top is at least 10 counts.\n",
    "    color_norm = LogNorm(vmin=1, vmax=max(10, np.max(data)))\n",
    "    # Rows are flipped so the first histogram row ends up at the bottom.\n",
    "    image = ax.imshow(data[::-1, :], extent=extent, aspect=\"auto\", norm=color_norm)\n",
    "    ax.set(xlabel=x_axis, ylabel=y_axis, title=title)\n",
    "    colorbar = fig.colorbar(image)\n",
    "    colorbar.set_label(\"Counts\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_trains_data(data_folder, source, include, detector_id, tid=None, modules=16, fillvalue=None, mod_starts_at=0):\n",
    "    \"\"\"Load a single train of `source` for all modules and stack it.\n",
    "\n",
    "    :param data_folder: Path to folder with data\n",
    "    :param source: Data source to be loaded, e.g. 'image.data'\n",
    "    :param include: Glob pattern of the file names to be considered\n",
    "    :param detector_id: The karabo id of the detector to get data for\n",
    "    :param tid: Train Id to be loaded. If None, the first train with data\n",
    "        from all selected sources is used\n",
    "    :param modules: Number of modules, passed to stack_detector_data\n",
    "    :param fillvalue: Fill value for missing modules, passed to stack_detector_data\n",
    "    :param mod_starts_at: First module index, passed to stack_detector_data\n",
    "    :return: (train ID, stacked data)\n",
    "\n",
    "    Exits the notebook with sys.exit(0) when no matching files exist.\n",
    "    \"\"\"\n",
    "    try:\n",
    "        run_data = RunDirectory(data_folder, include)\n",
    "    except FileNotFoundError:\n",
    "        warning(f'No corrected files for {include}. Skipping plots.')\n",
    "        sys.exit(0)\n",
    "\n",
    "    selection = run_data.select(f'{detector_id}/DET/*', source)\n",
    "    if tid is None:\n",
    "        # A first full trainId for all available modules is of interest.\n",
    "        tid, data = next(selection.trains(require_all=True, keep_dims=True))\n",
    "    else:\n",
    "        tid, data = selection.train_from_id(tid, keep_dims=True)\n",
    "\n",
    "    return tid, stack_detector_data(\n",
    "        train=data, data=source, fillvalue=fillvalue, modules=modules,\n",
    "        starts_at=mod_starts_at,\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Select the detector geometry used by the preview plots below.\n",
    "if \"AGIPD500K\" in karabo_id:\n",
    "    geom = AGIPD_500K2GGeometry.from_origin()\n",
    "elif \"AGIPD1M\" in karabo_id:\n",
    "    # NOTE(review): fixed quadrant positions, not read from the experiment - confirm.\n",
    "    quadrant_positions = [\n",
    "        (-525, 625),\n",
    "        (-550, -10),\n",
    "        (520, -160),\n",
    "        (542.5, 475),\n",
    "    ]\n",
    "    geom = AGIPD_1MGeometry.from_quad_positions(quad_pos=quadrant_positions)\n",
    "else:  # single module AGIPD detector\n",
    "    geom = agipd_single_module_geometry()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load one train of corrected data, and the same train of raw data, for the previews.\n",
    "include = '*S00000*' if sequences is None else f'*S{sequences[0]:05d}*'\n",
    "mod_starts_at = 0 if nmods > 1 else modules[0]  # TODO: use CALCAT metadata for the detector.\n",
    "stack_kwargs = dict(modules=nmods, mod_starts_at=mod_starts_at)\n",
    "\n",
    "tid, corrected = get_trains_data(out_folder, 'image.data', include, karabo_id, **stack_kwargs)\n",
    "\n",
    "# Every other source is read for the same train ID.\n",
    "_, gains = get_trains_data(out_folder, 'image.gain', include, karabo_id, tid, **stack_kwargs)\n",
    "_, mask = get_trains_data(out_folder, 'image.mask', include, karabo_id, tid, **stack_kwargs)\n",
    "_, blshift = get_trains_data(out_folder, 'image.blShift', include, karabo_id, tid, **stack_kwargs)\n",
    "_, cellId = get_trains_data(out_folder, 'image.cellId', include, karabo_id, tid, **stack_kwargs)\n",
    "_, pulseId = get_trains_data(out_folder, 'image.pulseId', include, karabo_id, tid, fillvalue=0, **stack_kwargs)\n",
    "_, raw = get_trains_data(run_folder, 'image.data', include, karabo_id, tid, **stack_kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_preview_images = gains.shape[0]\n",
    "display(Markdown(f'## Preview and statistics for {n_preview_images} images of the train {tid} ##\\n'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# As part of data reduction efforts, the DAQ now has an option to discard AGIPD gain data\n",
    "# when it is known that all data is in the same gain stage. In such cases, the gain data\n",
    "# will be set to zeros. Consequently, the signal vs. analog gain 2D histogram can be skipped.\n",
    "gain = raw[:, 1, ...]\n",
    "if gain.max() > 0:\n",
    "    signal = raw[:, 0, ...]\n",
    "    display(Markdown(\"### Signal vs. Analogue Gain\"))\n",
    "    # Restrict both histogram axes to the 0.02-99.8 percentile range.\n",
    "    signal_range = np.percentile(signal, [0.02, 99.8])\n",
    "    gain_range = np.percentile(gain, [0.02, 99.8])\n",
    "    hist, bins_x, bins_y = calgs.histogram2d(\n",
    "        signal.flatten().astype(np.float32),\n",
    "        gain.flatten().astype(np.float32),\n",
    "        bins=(100, 100),\n",
    "        range=[signal_range, gain_range],\n",
    "    )\n",
    "    do_2d_plot(hist, (bins_x, bins_y), \"Signal (ADU)\", \"Analogue gain (ADU)\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Signal vs. Digitized Gain ###\n",
    "\n",
    "The following plot shows the corrected signal vs. the digitized gain."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "vmin, vmax = np.nanmin(corrected), np.nanmax(corrected)\n",
    "# The range boundaries are decided by DET experts: signal limited to\n",
    "# [-50, 8192] ADU, gain axis covering the gain bit values 0, 1 and 2.\n",
    "corrected_range = [max(vmin, -50), min(vmax, 8192)]\n",
    "hist, bins_x, bins_y = calgs.histogram2d(\n",
    "    corrected.flatten().astype(np.float32),\n",
    "    gains.flatten().astype(np.float32),\n",
    "    bins=(100, 3),\n",
    "    range=[corrected_range, [0, 3]],\n",
    ")\n",
    "do_2d_plot(hist, (bins_x, bins_y), \"Signal (ADU)\", \"Gain bit value\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"Gain statistics in %\")\n",
    "# Share of pixel values in each gain stage (0: high, 1: medium, 2: low),\n",
    "# formatted with the same precision for all three stages.\n",
    "gain_shares = [np.count_nonzero(gains == stage) / gains.size * 100 for stage in range(3)]\n",
    "table = [[f'{share:.03f}' for share in gain_shares]]\n",
    "md = display(Latex(tabulate.tabulate(table, tablefmt='latex',\n",
    "                                     headers=[\"High\", \"Medium\", \"Low\"])))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Intensity per Pulse ###"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pulse_range = [np.nanmin(pulseId[pulseId>=0]), np.nanmax(pulseId[pulseId>=0])]\n",
    "\n",
    "\n",
    "def clamp(value, min_value, max_value):\n",
    "    \"\"\"Limit `value` to the closed interval [min_value, max_value].\"\"\"\n",
    "    return max(min_value, min(value, max_value))\n",
    "\n",
    "\n",
    "# Modify pulse_range, if only one pulse is selected.\n",
    "if pulse_range[0] == pulse_range[1]:\n",
    "    pulse_range = [0, pulse_range[1]+int(acq_rate)]\n",
    "\n",
    "mean_data = np.nanmean(corrected, axis=(2, 3))\n",
    "vmin, vmax = np.nanmin(mean_data), np.nanmax(mean_data)\n",
    "\n",
    "\n",
    "def plot_signal_vs_pulse(signal_max, title):\n",
    "    \"\"\"Plot mean signal vs. pulse ID with the signal axis capped at `signal_max` ADU.\"\"\"\n",
    "    hist, bins_x, bins_y = calgs.histogram2d(\n",
    "        mean_data.flatten().astype(np.float32),\n",
    "        pulseId.flatten().astype(np.float32),\n",
    "        bins=(100, int(pulse_range[1])),\n",
    "        range=[[clamp(vmin, -50, -0.2), min(vmax, signal_max)], pulse_range],\n",
    "    )\n",
    "    do_2d_plot(hist, (bins_x, bins_y), \"Signal (ADU)\", \"Pulse id\", title=title)\n",
    "\n",
    "\n",
    "plot_signal_vs_pulse(1000, \"Signal-Pulse ID\")\n",
    "if vmax > 1000:  # a zoom out plot.\n",
    "    plot_signal_vs_pulse(20000, \"Signal-Pulse ID (Extended View)\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Baseline shift ###\n",
    "\n",
    "Estimated baseline shift with respect to the total ADU counts of the corrected image."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Log-scaled histogram of the estimated baseline shift of the preview train.\n",
    "fig, ax = plt.subplots(figsize=(20, 10))\n",
    "h = ax.hist(blshift.flatten(), bins=100, log=True)\n",
    "ax.set_xlabel('Baseline shift [ADU]')\n",
    "ax.set_ylabel('Counts')\n",
    "ax.grid()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize=(10, 10))\n",
    "# Sum of the corrected signal over the last two axes, shown in mega-ADU.\n",
    "corrected_ave = np.nansum(corrected, axis=(2, 3))\n",
    "corrected_madu = corrected_ave / 10**6\n",
    "plt.scatter(corrected_madu.flatten(), blshift.flatten(), s=0.9)\n",
    "plt.xlim(np.nanpercentile(corrected_madu, [2, 98]))\n",
    "plt.grid()\n",
    "plt.xlabel('Illuminated corrected [MADU] ')\n",
    "_ = plt.ylabel('Estimated baseline shift [ADU]')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fall back to the first available cell when the requested preview cell\n",
    "# is not in the corrected data (cell IDs taken from column 0 of cellId).\n",
    "available_cells = cellId[:, 0]\n",
    "if cell_id_preview in available_cells:\n",
    "    cell_idx_preview = np.where(available_cells == cell_id_preview)[0][0]\n",
    "else:\n",
    "    print(f\"WARNING: The selected cell_id_preview value {cell_id_preview} is not available in the corrected data.\")\n",
    "    cell_id_preview = available_cells[0]\n",
    "    cell_idx_preview = 0\n",
    "    print(f\"Previewing the first available cellId: {cell_id_preview}.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "display(Markdown('### Raw preview ###\\n'))\n",
    "if cellId.shape[0] == 1:\n",
    "    print(\"Skipping mean RAW preview for single memory cell, \"\n",
    "          f\"see single shot image for selected cell ID {cell_id_preview}.\")\n",
    "else:\n",
    "    display(Markdown(f'Mean over images of the RAW data\\n'))\n",
    "    fig, ax = plt.subplots(figsize=(20, 10))\n",
    "    # Mean of the analog signal (index 0 on axis 1) over the images\n",
    "    # selected by the cell range.\n",
    "    raw_mean = np.mean(raw[slice(*cell_sel.crange), 0, ...], axis=0)\n",
    "    vmin, vmax = np.percentile(raw_mean, [5, 95])\n",
    "    ax = geom.plot_data_fast(raw_mean, ax=ax, vmin=vmin, vmax=vmax, cmap=cmap)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "display(Markdown(f'Single shot of the RAW data from cell {cell_id_preview} \\n'))\n",
    "# NOTE(review): cell_idx_preview indexes the corrected train; if cell/pulse\n",
    "# selection dropped frames, this raw frame may belong to another cell - confirm.\n",
    "raw_single = raw[cell_idx_preview, 0, ...]\n",
    "vmin, vmax = np.percentile(raw_single, [5, 95])\n",
    "fig, ax = plt.subplots(figsize=(20, 10))\n",
    "ax = geom.plot_data_fast(raw_single, ax=ax, vmin=vmin, vmax=vmax, cmap=cmap)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "display(Markdown('### Corrected preview ###\\n'))\n",
    "if cellId.shape[0] == 1:\n",
    "    print(\"Skipping mean CORRECTED preview for single memory cell, \"\n",
    "          f\"see single shot image for selected cell ID {cell_id_preview}.\")\n",
    "else:\n",
    "    display(Markdown('### Mean CORRECTED Preview ###\\n'))\n",
    "    display(Markdown(f'A mean across train: {tid}\\n'))\n",
    "    fig, ax = plt.subplots(figsize=(20, 10))\n",
    "    corrected_mean = np.mean(corrected, axis=0)\n",
    "    # Colour scale from max(-50, minimum) up to the 99.8th percentile.\n",
    "    ax = geom.plot_data_fast(\n",
    "        corrected_mean, ax=ax,\n",
    "        vmin=max(-50, np.nanmin(corrected_mean)),\n",
    "        vmax=np.nanpercentile(corrected_mean, 99.8),\n",
    "        cmap=cmap)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "display(Markdown(f'A single shot of the CORRECTED image from cell {cell_id_preview} \\n'))\n",
    "corrected_single = corrected[cell_idx_preview]\n",
    "fig, ax = plt.subplots(figsize=(20, 10))\n",
    "# Colour scale from max(-50, minimum) up to the 99.8th percentile.\n",
    "ax = geom.plot_data_fast(\n",
    "    corrected_single,\n",
    "    ax=ax,\n",
    "    vmin=max(-50, np.nanmin(corrected_single)),\n",
    "    vmax=np.nanpercentile(corrected_single, 99.8),\n",
    "    cmap=cmap,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize=(20, 10))\n",
    "ax = fig.add_subplot(111)\n",
    "vmin, vmax = get_range(corrected[cell_idx_preview], 5, -50)\n",
    "# 2 ADU wide bins between -50 ADU and vmax.\n",
    "nbins = int((vmax + 50) / 2)\n",
    "h = ax.hist(corrected[cell_idx_preview].flatten(),\n",
    "            bins=nbins, range=(-50, vmax),\n",
    "            histtype='stepfilled', log=True)\n",
    "plt.xlabel('[ADU]')\n",
    "plt.ylabel('Counts')\n",
    "ax.grid()\n",
    "# Title shows the memory cell ID, like the other previews, not the frame index.\n",
    "plt.title(f'Log-scaled histogram for corrected data for cell {cell_id_preview}')\n",
    "pass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize=(20, 10))\n",
    "ax = fig.add_subplot(111)\n",
    "# Upper histogram edge: the data maximum, capped at 50000 ADU.\n",
    "vmax = min(np.nanmax(corrected), 50000)\n",
    "nbins = int((vmax + 100) / 5)  # 5 ADU wide bins starting at -100 ADU\n",
    "hist_range = (-100, vmax)\n",
    "h = ax.hist(corrected.flatten(), bins=nbins,\n",
    "            range=hist_range, histtype='step', log=True, label='All')\n",
    "# Overlay one histogram per gain stage.\n",
    "for gain_value, gain_label, gain_color in [\n",
    "    (0, 'High gain', 'green'),\n",
    "    (1, 'Medium gain', 'red'),\n",
    "    (2, 'Low gain', 'yellow'),\n",
    "]:\n",
    "    ax.hist(corrected[gains == gain_value].flatten(), bins=nbins, range=hist_range,\n",
    "            alpha=0.5, log=True, label=gain_label, color=gain_color)\n",
    "ax.legend()\n",
    "ax.grid()\n",
    "plt.xlabel('[ADU]')\n",
    "plt.ylabel('Counts')\n",
    "plt.title('Overlaid Histograms for corrected data for multiple gains')\n",
    "pass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for md_text in ('### Maximum GAIN Preview ###\\n',\n",
    "                'The per pixel maximum across one train for the digitized gain'):\n",
    "    display(Markdown(md_text))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-pixel maximum of the gain bit over all images of the preview train.\n",
    "max_gain = np.max(gains, axis=0)\n",
    "fig, ax = plt.subplots(figsize=(20, 10))\n",
    "ax = geom.plot_data_fast(\n",
    "    max_gain, ax=ax,\n",
    "    cmap=cmap, vmin=-0.3, vmax=2.3)  # Extend cmap for wrong gain values."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Bad Pixels ##\n",
    "The mask contains dedicated entries for all pixels and memory cells as well as all three gain stages. Each mask entry is encoded in 32 bits as:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One row per bad-pixel flag with its bit mask, zero-padded to 32 digits to\n",
    "# match the 32-bit mask encoding; a 16-digit pad gives rows of different\n",
    "# widths for any flag above bit 15.\n",
    "table = [(item.name, f\"{item.value:032b}\") for item in BadPixels]\n",
    "md = display(Latex(tabulate.tabulate(table, tablefmt='latex',\n",
    "                                     headers=[\"Bad pixel type\", \"Bit mask\"])))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Heading and description for the single-shot bad-pixel map.\n",
    "display(\n",
    "    Markdown('### Single Shot Bad Pixels ### \\n'),\n",
    "    Markdown(f'A single shot bad pixel map from cell {cell_id_preview} \\n'),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot log2 of the mask value: a single flag maps to its bit index,\n",
    "# and unflagged pixels (mask == 0) become -inf.\n",
    "fig, ax = plt.subplots(figsize=(20, 10))\n",
    "geom.plot_data_fast(np.log2(mask[cell_idx_preview]), ax=ax, vmin=0, vmax=32, cmap=cmap)\n",
    "pass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if round_photons:\n",
    "    display(Markdown('### Photonization histograms ###'))\n",
    "\n",
    "    # Bin centres of the histograms before and after rounding.\n",
    "    x_preround = (agipd_corr.hist_bins_preround[1:] + agipd_corr.hist_bins_preround[:-1]) / 2\n",
    "    x_postround = (agipd_corr.hist_bins_postround[1:] + agipd_corr.hist_bins_postround[:-1]) / 2\n",
    "    x_photons = np.arange(0, (x_postround[-1] + 1) / photon_energy)\n",
    "\n",
    "    fig, ax = plt.subplots(ncols=1, nrows=1, clear=True)\n",
    "    ax.plot(x_preround, agipd_corr.shared_hist_preround, '.-', color='C0')\n",
    "    ax.bar(x_postround, agipd_corr.shared_hist_postround, photon_energy, color='C1', alpha=0.5)\n",
    "    ax.set_yscale('log')\n",
    "    # Only set the top limit: a bottom limit of 0 is invalid on a log axis\n",
    "    # (matplotlib ignores it with a warning), so keep the autoscaled bottom.\n",
    "    ax.set_ylim(top=max(agipd_corr.shared_hist_preround.max(), agipd_corr.shared_hist_postround.max())*3)\n",
    "    ax.set_xlim(x_postround[0], x_postround[-1]+1)\n",
    "    ax.set_xlabel('Photon energy / keV')\n",
    "    ax.set_ylabel('Intensity')\n",
    "    # Dashed lines at integer multiples of the photon energy.\n",
    "    ax.vlines(x_photons * photon_energy, *ax.get_ylim(), color='k', linestyle='dashed')\n",
    "\n",
    "    # Secondary x axis in units of photons.\n",
    "    phx = ax.twiny()\n",
    "    phx.set_xlim(x_postround[0] / photon_energy, (x_postround[-1]+1)/photon_energy)\n",
    "    phx.set_xticks(x_photons)\n",
    "    phx.set_xlabel('# Photons')\n",
    "    pass"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Percentage of Bad Pixels across one train  ###"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fraction of images of the train in which each pixel has any bad-pixel flag.\n",
    "bad_fraction = np.mean(mask > 0, axis=0)\n",
    "fig, ax = plt.subplots(figsize=(20, 10))\n",
    "geom.plot_data_fast(bad_fraction, vmin=0, ax=ax, vmax=1, cmap=cmap)\n",
    "pass"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Percentage of Bad Pixels across one train. Only Dark Related ###"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize=(20, 10))\n",
    "ax = fig.add_subplot(111)\n",
    "# Assumes dark-related flags occupy the bits up to and including\n",
    "# NO_DARK_DATA - confirm against BadPixels. Test those bits directly:\n",
    "# comparing the whole mask value against NO_DARK_DATA drops pixels that\n",
    "# combine several dark flags or also carry non-dark flags.\n",
    "dark_bits = (BadPixels.NO_DARK_DATA.value << 1) - 1\n",
    "ax = geom.plot_data_fast(\n",
    "    np.mean((mask & dark_bits) > 0, axis=0), vmin=0, ax=ax, vmax=1, cmap=cmap)\n",
    "pass"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}