{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "bed7bd15-21d9-4735-82c1-c27c1a5e3346",
   "metadata": {},
   "source": [
    "# Gotthard2 Offline Correction\n",
    "\n",
    "Author: European XFEL Detector Group, Version: 1.0\n",
    "\n",
    "Offline Correction for Gotthard2 Detector.\n",
    "\n",
    "This notebook is able to correct 25um and 50um GH2 detectors using the same correction steps:\n",
    "- Convert 12bit raw data into 10bit, offset subtraction, then multiply with gain constant.\n",
    "\n",
    "| Correction | constants   | boolean to enable/disable   |\n",
    "|------------|-------------|-----------------------------|\n",
    "|   12bit to 10bit  | `LUTGotthard2` |  |\n",
    "|   Offset  | `OffsetGotthard2`|`offset_correction`|\n",
    "|   Relative gain  | `RelativeGainGotthard2` + `BadPixelsFFGotthard2` |`gain_correction`|\n",
    "\n",
    "Beside the corrected data, a mask is stored using the badpixels constant of the same parameter conditions and time.\n",
    "- `BadPixelsDarkGotthard2`\n",
    "- `BadPixelsFFGotthard2`, if relative gain correction is requested.\n",
    "\n",
    "The correction is done per sequence file. If all selected sequence files have no images to correct the notebook will fail.\n",
    "The same result would be reached in case the needed dark calibration constants were not retrieved for all modules and `offset_correction` is True.\n",
    "In case one of the gain constants were not retrieved `gain_correction` is switched to False and gain correction is disabled.\n",
    "\n",
    "The `data` datasets stored in the RECEIVER source along with the corrected image (`adc`) and `mask` are:\n",
    "\n",
    "  - `gain`\n",
    "\n",
    "  - `bunchId`\n",
    "\n",
    "  - `memoryCell`\n",
    "\n",
    "  - `frameNumber`\n",
    "\n",
    "  - `timestamp`\n",
    "\n",
    "  - `trainId`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "570322ed-f611-4fd1-b2ec-c12c13d55843",
   "metadata": {},
   "outputs": [],
   "source": [
    "in_folder = \"/gpfs/exfel/exp/DETLAB/202330/p900326/raw\"  # the folder to read data from, required\n",
    "out_folder = \"/gpfs/exfel/data/scratch/ahmedk/test/gotthard2\"  # the folder to output to, required\n",
    "metadata_folder = \"\"  # Directory containing calibration_metadata.yml when run by xfel-calibrate\n",
    "run = 20  # run to process, required\n",
    "sequences = [-1]  # sequences to correct, set to [-1] for all, range allowed\n",
    "sequences_per_node = 1  # number of sequence files per node if notebook executed through xfel-calibrate, set to 0 to not run SLURM parallel\n",
    "\n",
    "# Parameters used to access raw data.\n",
    "karabo_id = \"DETLAB_25UM_GH2\"  # karabo prefix of Gotthard-II devices\n",
    "karabo_da = [\"\"]  # data aggregators\n",
    "receiver_template = \"RECEIVER{}\"  # receiver template used to read INSTRUMENT keys.\n",
    "receiver_affixes = [\"\"]  # The affix to format into the receiver template to be able to load the correct receiver name from the data.\n",
    "control_template = \"CONTROL\"  # control template used to read CONTROL keys.\n",
    "ctrl_source_template = \"{}/DET/{}\"  # template for control source name (filled with karabo_id_control)\n",
    "karabo_id_control = \"\"  # Control karabo ID. Set to empty string to use the karabo-id\n",
    "corr_source_template = \"{}/CORR/{}:daqOutput\"  # Correction data source template. filled with karabo_id and correction receiver\n",
    "corr_receiver = \"\"  # The receiver name of the corrected data. Leave empty for using the same receiver name for the 50um GH2 or the first(Master) receiver for the 25um GH2.\n",
    "\n",
    "# Parameters for calibration database.\n",
    "cal_db_interface = \"tcp://max-exfl-cal001:8016#8025\"  # the database interface to use.\n",
    "cal_db_timeout = 180000  # timeout on caldb requests.\n",
    "creation_time = \"\"  # To overwrite the measured creation_time. Required Format: YYYY-MM-DD HR:MN:SC e.g. \"2022-06-28 13:00:00\"\n",
    "\n",
    "# Parameters affecting corrected data.\n",
    "constants_file = \"\"  # Use constants in given constant file path. /gpfs/exfel/data/scratch/ahmedk/dont_remove/gotthard2/constants/calibration_constants_GH2.h5\n",
    "offset_correction = True  # apply offset correction. This can be disabled to only apply LUT or apply LUT and gain correction for non-linear differential results.\n",
    "gain_correction = True  # apply gain correction.\n",
    "chunks_data = 1  # HDF chunk size for pixel data in number of frames.\n",
    "\n",
    "# Parameter conditions.\n",
    "bias_voltage = -1  # Detector bias voltage, set to -1 to use value in raw file.\n",
    "exposure_time = -1.  # Detector exposure time, set to -1 to use value in raw file.\n",
    "exposure_period = -1.  # Detector exposure period, set to -1 to use value in raw file.\n",
    "acquisition_rate = -1.  # Detector acquisition rate (1.1/4.5), set to -1 to use value in raw file.\n",
    "single_photon = -1  # Detector single photon mode (High/Low CDS), set to -1 to use value in raw file.\n",
    "reverse_second_module = -1  # Reverse 25um GH2 second module before interleaving. set to -1 to use value in raw file to reverse based on `CTRL/reverseSlaveReadOutMode`'s value.\n",
    "\n",
    "# Parameters for plotting\n",
    "skip_plots = False  # exit after writing corrected files\n",
    "pulse_idx_preview = 3  # pulse index to preview. The following even/odd pulse index is used for preview. # TODO: update to pulseId preview.\n",
    "\n",
    "\n",
    "def balance_sequences(in_folder, run, sequences, sequences_per_node, karabo_da):\n",
    "    from xfel_calibrate.calibrate import balance_sequences as bs\n",
    "    return bs(in_folder, run, sequences, sequences_per_node, karabo_da)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6e9730d8-3908-41d7-abe2-d78e046d5de2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "from logging import warning\n",
    "\n",
    "import h5py\n",
    "import pasha as psh\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from IPython.display import Markdown, display\n",
    "from extra_data import RunDirectory, H5File\n",
    "from pathlib import Path\n",
    "\n",
    "import cal_tools.restful_config as rest_cfg\n",
    "from cal_tools.calcat_interface import CalCatError, GOTTHARD2_CalibrationData\n",
    "from cal_tools.files import DataFile\n",
    "from cal_tools.gotthard2 import gotthard2algs, gotthard2lib\n",
    "from cal_tools.step_timing import StepTimer\n",
    "from cal_tools.tools import (\n",
    "    calcat_creation_time,\n",
    "    map_seq_files,\n",
    "    write_constants_fragment,\n",
    ")\n",
    "from XFELDetAna.plotting.heatmap import heatmapPlot\n",
    "\n",
    "warnings.filterwarnings('ignore')\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d7c02c48-4429-42ea-a42e-de45366d7fa3",
   "metadata": {},
   "outputs": [],
   "source": [
    "in_folder = Path(in_folder)\n",
    "run_folder = in_folder / f\"r{run:04d}\"\n",
    "out_folder = Path(out_folder)\n",
    "out_folder.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "if not karabo_id_control:\n",
    "    karabo_id_control = karabo_id\n",
    "\n",
    "ctrl_src = ctrl_source_template.format(karabo_id_control, control_template)\n",
    "\n",
    "# Run's creation time:\n",
    "creation_time = calcat_creation_time(in_folder, run, creation_time)\n",
    "print(f\"Creation time: {creation_time}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f9a8d1eb-ce6a-4ed0-abf4-4a6029734672",
   "metadata": {},
   "outputs": [],
   "source": [
    "step_timer = StepTimer()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "892172d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "run_dc = RunDirectory(run_folder)\n",
    "\n",
    "# Read slow data\n",
    "g2ctrl = gotthard2lib.Gotthard2Ctrl(run_dc=run_dc, ctrl_src=ctrl_src)\n",
    "\n",
    "if bias_voltage == -1:\n",
    "    bias_voltage = g2ctrl.get_bias_voltage()\n",
    "if exposure_time == -1:\n",
    "    exposure_time = g2ctrl.get_exposure_time()\n",
    "if exposure_period == -1:\n",
    "    exposure_period = g2ctrl.get_exposure_period()\n",
    "if acquisition_rate == -1:\n",
    "    acquisition_rate = g2ctrl.get_acquisition_rate()\n",
    "if single_photon == -1:\n",
    "    single_photon = g2ctrl.get_single_photon()\n",
    "\n",
    "gh2_detector = g2ctrl.get_det_type()\n",
    "if reverse_second_module == -1 and gh2_detector == \"25um\":\n",
    "    reverse_second_module = not g2ctrl.second_module_reversed()\n",
    "    print(\n",
    "        \"Second module is not reversed. \"\n",
    "        \"Reversing second module before interleaving.\")\n",
    "print(\"Bias Voltage:\", bias_voltage)\n",
    "print(\"Exposure Time:\", exposure_time)\n",
    "print(\"Exposure Period:\", exposure_period)\n",
    "print(\"Acquisition Rate:\", acquisition_rate)\n",
    "print(\"Single Photon:\", single_photon)\n",
    "print(f\"Processing {gh2_detector} Gotthard2.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "21a8953a-8c76-475e-8f4f-b201cc25c159",
   "metadata": {},
   "outputs": [],
   "source": [
    "# GH2 calibration data object.\n",
    "g2_cal = GOTTHARD2_CalibrationData(\n",
    "    detector_name=karabo_id,\n",
    "    sensor_bias_voltage=bias_voltage,\n",
    "    exposure_time=exposure_time,\n",
    "    exposure_period=exposure_period,\n",
    "    acquisition_rate=acquisition_rate,\n",
    "    single_photon=single_photon,\n",
    "    event_at=creation_time,\n",
    "    client=rest_cfg.calibration_client(),\n",
    ")\n",
    "\n",
    "da_to_pdu = None\n",
    "# Keep as long as it is essential to correct\n",
    "# RAW data (FXE p003225) before the data mapping was added to CALCAT.\n",
    "try:  # in case local constants are used with old RAW data. This can be removed in the future.\n",
    "    da_to_pdu = g2_cal.mod_to_pdu\n",
    "except CalCatError as e:\n",
    "    print(e)\n",
    "    db_modules = [None] * len(karabo_da)\n",
    "\n",
    "if da_to_pdu:\n",
    "    if karabo_da == [\"\"]:\n",
    "        karabo_da = sorted(da_to_pdu.keys())\n",
    "    else:\n",
    "        # Exclude non selected DA from processing.\n",
    "        karabo_da = [da for da in karabo_da if da in da_to_pdu]\n",
    "\n",
    "    db_modules = [da_to_pdu[da] for da in karabo_da]\n",
    "\n",
    "print(f\"Process modules: {db_modules} for run {run}\")\n",
    "\n",
    "# Create the correction receiver name.\n",
    "receiver_names = [f\"*{receiver_template.format(x)}:daqOutput\" for x in receiver_affixes]\n",
    "data_sources = list(run_dc.select(receiver_names).all_sources)\n",
    "\n",
    "if not corr_receiver:\n",
    "    # This part assumes this data_source structure: '{karabo_id}/DET/{receiver_name}:{output_channel}'\n",
    "    if gh2_detector == \"25um\":  # For 25um use virtual karabo_das for CALCAT data mapping.\n",
    "        corr_receiver = data_sources[0].split(\"/\")[-1].split(\":\")[0]\n",
    "    else:\n",
    "        corr_receiver = data_sources[0].split(\"/\")[-1].split(\":\")[0]\n",
    "    print(f\"Using {corr_receiver} as a receiver name for the corrected data.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2551b923",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Check the available trains to correct.\n",
    "total_trains = len(RunDirectory(run_folder).select(data_sources, require_all=True).train_ids)\n",
    "if total_trains:\n",
    "    print(f\"Correcting {total_trains}.\")\n",
    "else:\n",
    "    raise ValueError(f\"No trains to correct for run {run}.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8c852392-bb19-4c40-b2ce-3b787538a92d",
   "metadata": {},
   "source": [
    "### Retrieving calibration constants"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5717d722",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Used for old FXE (p003225) runs before adding Gotthard2 to CALCAT\n",
    "const_data = dict()\n",
    "\n",
    "if constants_file:\n",
    "    for mod in karabo_da:\n",
    "        const_data[mod] = dict()\n",
    "        # load constants temporarily using defined local paths.\n",
    "        with h5py.File(constants_file, \"r\") as cfile:\n",
    "            const_data[mod][\"LUTGotthard2\"] = cfile[\"LUT\"][()]\n",
    "            const_data[mod][\"OffsetGotthard2\"] = cfile[\"offset_map\"][()].astype(np.float32)\n",
    "            const_data[mod][\"RelativeGainGotthard2\"] = cfile[\"gain_map\"][()].astype(np.float32)\n",
    "            const_data[mod][\"Mask\"] = cfile[\"bpix_ff\"][()].astype(np.uint32)\n",
    "else:\n",
    "    constant_names = [\"LUTGotthard2\", \"OffsetGotthard2\", \"BadPixelsDarkGotthard2\"]\n",
    "    if gain_correction:\n",
    "        constant_names += [\"RelativeGainGotthard2\", \"BadPixelsFFGotthard2\"]\n",
    "\n",
    "    g2_metadata = g2_cal.metadata(calibrations=constant_names)\n",
    "    # Display retrieved calibration constants timestamps\n",
    "    g2_cal.display_markdown_retrieved_constants(metadata=g2_metadata)\n",
    "\n",
    "    # Validate the constants availability and raise/warn correspondingly.\n",
    "    for mod, calibrations in g2_metadata.items():\n",
    "\n",
    "        dark_constants = {\"LUTGotthard2\"}\n",
    "        if offset_correction:\n",
    "            dark_constants |= {\"OffsetGotthard2\", \"BadPixelsDarkGotthard2\"}\n",
    "\n",
    "        missing_dark_constants = dark_constants - set(calibrations)\n",
    "        if missing_dark_constants:\n",
    "            karabo_da.remove(mod)\n",
    "            warning(f\"Dark constants {missing_dark_constants} are not available to correct {mod}.\")  # noqa\n",
    "\n",
    "        missing_gain_constants = {\n",
    "            \"BadPixelsFFGotthard2\", \"RelativeGainGotthard2\"} - set(calibrations)\n",
    "        if gain_correction and missing_gain_constants:\n",
    "            warning(f\"Gain constants {missing_gain_constants} are not retrieved for mod {mod}.\")\n",
    "\n",
    "if not karabo_da:\n",
    "    raise ValueError(\"Dark constants are not available for all modules.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ac1cdec5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Record constant details in YAML metadata.\n",
    "write_constants_fragment(\n",
    "    out_folder=(metadata_folder or out_folder),\n",
    "    det_metadata=g2_metadata,\n",
    "    caldb_root=g2_cal.caldb_root)\n",
    "\n",
    "# Load constants data for all constants.\n",
    "const_data = g2_cal.ndarray_map(metadata=g2_metadata)\n",
    "\n",
    "# Prepare constant arrays.\n",
    "if not constants_file:\n",
    "    for mod in karabo_da:\n",
    "        # Create the mask array.\n",
    "        bpix = const_data[mod].get(\"BadPixelsDarkGotthard2\")\n",
    "        if bpix is None:\n",
    "            bpix = np.zeros((1280, 2, 3), dtype=np.uint32)\n",
    "        if const_data[mod].get(\"BadPixelsFFGotthard2\") is not None:\n",
    "            bpix |= const_data[mod][\"BadPixelsFFGotthard2\"]\n",
    "        const_data[mod][\"Mask\"] = bpix\n",
    "\n",
    "        # Prepare empty arrays for missing constants.\n",
    "        if const_data[mod].get(\"OffsetGotthard2\") is None:\n",
    "            const_data[mod][\"OffsetGotthard2\"] = np.zeros(\n",
    "                (1280, 2, 3), dtype=np.float32)\n",
    "\n",
    "        if const_data[mod].get(\"RelativeGainGotthard2\") is None:\n",
    "            const_data[mod][\"RelativeGainGotthard2\"] = np.ones(\n",
    "                (1280, 2, 3), dtype=np.float32)\n",
    "        const_data[mod][\"RelativeGainGotthard2\"] = const_data[mod][\"RelativeGainGotthard2\"].astype(  # noqa\n",
    "            np.float32, copy=False)  # Old gain constants are not float32."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2c7dd0bb",
   "metadata": {},
   "outputs": [],
   "source": [
    "file_da = list({kda.split('/')[0] for kda in karabo_da})\n",
    "mapped_files, total_files = map_seq_files(\n",
    "    run_folder,\n",
    "    file_da,\n",
    "    sequences,\n",
    ")\n",
    "# This notebook doesn't account for processing more\n",
    "# than one file data aggregator.\n",
    "seq_files = mapped_files[file_da[0]]\n",
    "\n",
    "if not len(seq_files):\n",
    "    raise IndexError(\n",
    "        \"No sequence files available to correct for the selected sequences and karabo_da.\")\n",
    "print(f\"Processing a total of {total_files} sequence files\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "23fcf7f4-351a-4df7-8829-d8497d94fecc",
   "metadata": {},
   "outputs": [],
   "source": [
    "context = psh.ProcessContext(num_workers=23)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "daecd662-26d2-4cb8-aa70-383a579cf9f9",
   "metadata": {},
   "outputs": [],
   "source": [
    "def correct_train(wid, index, d):\n",
    "    g = gain[index]\n",
    "    gotthard2algs.convert_to_10bit(d, const_data[mod][\"LUTGotthard2\"], data_corr[index, ...])\n",
    "    gotthard2algs.correct_train(\n",
    "        data_corr[index, ...],\n",
    "        mask[index, ...],\n",
    "        g,\n",
    "        const_data[mod][\"OffsetGotthard2\"].astype(np.float32),  # PSI map is in f8\n",
    "        const_data[mod][\"RelativeGainGotthard2\"],  \n",
    "        const_data[mod][\"Mask\"],\n",
    "        apply_offset=offset_correction,\n",
    "        apply_gain=gain_correction,\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f88c1aa6-a735-4b72-adce-b30162f5daea",
   "metadata": {},
   "outputs": [],
   "source": [
    "corr_data_source = corr_source_template.format(karabo_id, corr_receiver)\n",
    "\n",
    "for raw_file in seq_files:\n",
    "\n",
    "    out_file = out_folder / raw_file.name.replace(\"RAW\", \"CORR\")\n",
    "    # Select module INSTRUMENT sources and deselect empty trains.\n",
    "    dc = H5File(raw_file).select(data_sources, require_all=True)\n",
    "    n_trains = len(dc.train_ids)\n",
    "\n",
    "    if n_trains == 0:\n",
    "        warning(f\"Skipping correction. No trains to correct for this sequence file: {raw_file}.\")\n",
    "        continue\n",
    "    else:\n",
    "        print(f\"Correcting {n_trains} for {raw_file}.\")\n",
    "\n",
    "    # Initialize GH2 data and gain arrays to store in corrected files.\n",
    "    if gh2_detector == \"25um\":\n",
    "        dshape_stored = (dc[data_sources[0], \"data.adc\"].shape[:2] + (1280 * 2,))\n",
    "        data_stored = np.zeros(dshape_stored, dtype=np.float32)\n",
    "        gain_stored = np.zeros(dshape_stored, dtype=np.uint8)\n",
    "        mask_stored = np.zeros(dshape_stored, dtype=np.uint32)\n",
    "    else:\n",
    "        data_stored, gain_stored, mask_stored = None, None, None\n",
    "\n",
    "    for i, (src, mod) in enumerate(zip(data_sources, karabo_da)):\n",
    "        step_timer.start()\n",
    "        print(f\"Correcting {src} for {raw_file}\")\n",
    "\n",
    "        data = dc[src, \"data.adc\"].ndarray()\n",
    "        gain = dc[src, \"data.gain\"].ndarray()\n",
    "        step_timer.done_step(\"Preparing raw data\")\n",
    "        dshape = data.shape\n",
    "\n",
    "        step_timer.start()\n",
    "\n",
    "        # Allocate shared arrays.\n",
    "        data_corr = context.alloc(shape=dshape, dtype=np.float32)\n",
    "        mask = context.alloc(shape=dshape, dtype=np.uint32)\n",
    "        context.map(correct_train, data)\n",
    "        step_timer.done_step(f\"Correcting one receiver in one sequence file\")\n",
    "\n",
    "        step_timer.start()\n",
    "\n",
    "        # Provided PSI gain map has 0 values. Set inf values to nan.\n",
    "        # TODO: This can maybe be removed after creating XFEL gain maps.?\n",
    "        data_corr[np.isinf(data_corr)] = np.nan\n",
    "\n",
    "        # Create CORR files and add corrected data sections.\n",
    "        image_counts = dc[src, \"data.adc\"].data_counts(labelled=False)\n",
    "        if reverse_second_module and i == 1:\n",
    "            data_corr = np.flip(data_corr, axis=-1)\n",
    "            gain = np.flip(gain, axis=-1)\n",
    "\n",
    "        if gh2_detector == \"25um\":\n",
    "            data_stored[..., i::2] = data_corr\n",
    "            gain_stored[..., i::2] = gain\n",
    "            mask_stored[..., i::2] = mask\n",
    "        else:  # \"50um\"\n",
    "            data_stored, gain_stored, mask_stored = data_corr, gain, mask\n",
    "\n",
    "    seq_file = dc.files[0]  # FileAccess\n",
    "    with DataFile(out_file, \"w\") as ofile:\n",
    "        # Create INDEX datasets.\n",
    "        ofile.create_index(dc.train_ids, from_file=seq_file)\n",
    "        ofile.create_metadata(\n",
    "            like=dc,\n",
    "            sequence=seq_file.sequence,\n",
    "            instrument_channels=(f\"{corr_data_source}/data\",)\n",
    "        )\n",
    "\n",
    "        # Create Instrument section to later add corrected datasets.\n",
    "        outp_source = ofile.create_instrument_source(corr_data_source)\n",
    "\n",
    "        # Create count/first datasets at INDEX source.\n",
    "        outp_source.create_index(data=image_counts)\n",
    "\n",
    "        # Store uncorrected trainId in the corrected file.\n",
    "        outp_source.create_key(\n",
    "                f\"data.trainId\", data=dc.train_ids,\n",
    "                chunks=min(50, len(dc.train_ids))\n",
    "            )\n",
    "\n",
    "        # Create datasets with the available corrected data\n",
    "        for field_name, field_data in {\n",
    "            \"adc\": data_stored,\n",
    "            \"gain\": gain_stored,\n",
    "        }.items():\n",
    "            outp_source.create_key(\n",
    "                f\"data.{field_name}\", data=field_data,\n",
    "                chunks=((chunks_data,) + data_corr.shape[1:])\n",
    "        )\n",
    "\n",
    "        # For GH2 25um, the data of the second receiver is\n",
    "        # stored in the corrected file.\n",
    "        for field in [\"bunchId\", \"memoryCell\", \"frameNumber\", \"timestamp\"]:\n",
    "                outp_source.create_key(\n",
    "                    f\"data.{field}\", data=dc[src, f\"data.{field}\"].ndarray(),\n",
    "                    chunks=(chunks_data, data_corr.shape[1])\n",
    "            )\n",
    "        outp_source.create_compressed_key(f\"data.mask\", data=mask_stored)\n",
    "        step_timer.done_step(\"Storing data\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "94b8e4d2-9f8c-4c23-a509-39238dd8435c",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(f\"Total processing time {step_timer.timespan():.01f} s\")\n",
    "step_timer.print_summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0ccc7f7e-2a3f-4ac0-b854-7d505410d2fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "if skip_plots:\n",
    "    print(\"Skipping plots\")\n",
    "    import sys\n",
    "\n",
    "    sys.exit(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ff203f77-3811-46f3-bf7d-226d2dcab13f",
   "metadata": {},
   "outputs": [],
   "source": [
    "mod_dcs = {}\n",
    "first_seq_raw = seq_files[0]\n",
    "first_seq_corr = out_folder / first_seq_raw.name.replace(\"RAW\", \"CORR\")\n",
    "mod_dcs[corr_data_source] = {}\n",
    "with H5File(first_seq_corr) as out_dc:\n",
    "    tid, mod_dcs[corr_data_source][\"train_corr_data\"] = next(\n",
    "        out_dc[corr_data_source, \"data.adc\"].trains()\n",
    "    )\n",
    "\n",
    "if gh2_detector == \"25um\":\n",
    "    mod_dcs[corr_data_source][\"train_raw_data\"] = np.zeros((data_corr.shape[1], 1280 * 2), dtype=np.float32)\n",
    "    mod_dcs[corr_data_source][\"train_raw_gain\"] = np.zeros((data_corr.shape[1], 1280 * 2), dtype=np.uint8)\n",
    "\n",
    "for i, src in enumerate(data_sources):\n",
    "    with H5File(first_seq_raw) as in_dc:\n",
    "        train_dict = in_dc.train_from_id(tid)[1][src]\n",
    "        if gh2_detector == \"25um\":\n",
    "            if reverse_second_module and i == 1:\n",
    "                data = np.flip(train_dict[\"data.adc\"], axis=-1)\n",
    "                gain = np.flip(train_dict[\"data.gain\"], axis=-1)\n",
    "            else:\n",
    "                data = train_dict[\"data.adc\"]\n",
    "                gain = train_dict[\"data.gain\"]\n",
    "            mod_dcs[corr_data_source][\"train_raw_data\"][..., i::2] = data\n",
    "            mod_dcs[corr_data_source][\"train_raw_gain\"][..., i::2] = gain\n",
    "        else:\n",
    "            mod_dcs[corr_data_source][\"train_raw_data\"] = train_dict[\"data.adc\"]\n",
    "            mod_dcs[corr_data_source][\"train_raw_gain\"] = train_dict[\"data.gain\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b379438-eb1d-42b2-ac83-eb8cf88c46db",
   "metadata": {},
   "outputs": [],
   "source": [
    "display(Markdown(\"### Mean RAW and CORRECTED across pulses for one train:\"))\n",
    "display(Markdown(f\"Train: {tid}\"))\n",
    "\n",
    "if gh2_detector == \"50um\":\n",
    "    title = f\"{{}} data for {karabo_da} ({db_modules})\"\n",
    "else:\n",
    "    title = f\"Interleaved {{}} data for {karabo_da} ({db_modules})\"\n",
    "\n",
    "step_timer.start()\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(15, 15))\n",
    "raw_data = mod_dcs[corr_data_source][\"train_raw_data\"]\n",
    "im = ax.plot(np.mean(raw_data, axis=0))\n",
    "ax.set_title(title.format(\"RAW\"), fontsize=20)\n",
    "ax.set_xlabel(\"Strip #\", size=20)\n",
    "ax.set_ylabel(\"12-bit ADC output\", size=20)\n",
    "plt.xticks(fontsize=20)\n",
    "plt.yticks(fontsize=20)\n",
    "pass\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(15, 15))\n",
    "corr_data = mod_dcs[corr_data_source][\"train_corr_data\"]\n",
    "im = ax.plot(np.mean(corr_data, axis=0))\n",
    "ax.set_title(title.format(\"CORRECTED\"), fontsize=20)\n",
    "ax.set_xlabel(\"Strip #\", size=20)\n",
    "ax.set_ylabel(\"10-bit KeV. output\", size=20)\n",
    "plt.xticks(fontsize=20)\n",
    "plt.yticks(fontsize=20)\n",
    "pass\n",
    "step_timer.done_step(\"Plotting mean data\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "58a6a276",
   "metadata": {},
   "outputs": [],
   "source": [
    "display(Markdown(f\"### RAW and CORRECTED strips across pulses for train {tid}\"))\n",
    "\n",
    "step_timer.start()\n",
    "for plt_data, dname in zip(\n",
    "    [\"train_raw_data\", \"train_corr_data\"], [\"RAW\", \"CORRECTED\"]\n",
    "):\n",
    "    fig, ax = plt.subplots(figsize=(15, 15))\n",
    "    plt.rcParams.update({\"font.size\": 20})\n",
    "\n",
    "    heatmapPlot(\n",
    "        mod_dcs[corr_data_source][plt_data],\n",
    "        y_label=\"Pulses\",\n",
    "        x_label=\"Strips\",\n",
    "        title=title.format(dname),\n",
    "        use_axis=ax,\n",
    "        cb_pad=0.8,\n",
    "    )\n",
    "    pass\n",
    "step_timer.done_step(\"Plotting RAW and CORRECTED data for one train\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cd8f5e08-fcee-4bff-ba63-6452b3d892a2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Validate given \"pulse_idx_preview\"\n",
    "\n",
    "if pulse_idx_preview + 1 > data.shape[1]:\n",
    "    print(\n",
    "        f\"WARNING: selected pulse_idx_preview {pulse_idx_preview} is not available in data.\"\n",
    "        \" Previewing 1st pulse.\"\n",
    "    )\n",
    "    pulse_idx_preview = 1\n",
    "\n",
    "if data.shape[1] == 1:\n",
    "    odd_pulse = 1\n",
    "    even_pulse = None\n",
    "else:\n",
    "    odd_pulse = pulse_idx_preview if pulse_idx_preview % 2 else pulse_idx_preview + 1\n",
    "    even_pulse = (\n",
    "        pulse_idx_preview if not (pulse_idx_preview % 2) else pulse_idx_preview + 1\n",
    "    )\n",
    "\n",
    "if pulse_idx_preview + 1 > data.shape[1]:\n",
    "    pulse_idx_preview = 1\n",
    "    if data.shape[1] > 1:\n",
    "        pulse_idx_preview = 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e5f0d4d8-e32c-4f2c-8469-4ebbfd3f644c",
   "metadata": {},
   "outputs": [],
   "source": [
    "display(Markdown(\"### RAW and CORRECTED even/odd pulses for one train:\"))\n",
    "display(Markdown(f\"Train: {tid}\"))\n",
    "fig, ax = plt.subplots(figsize=(15, 15))\n",
    "raw_data = mod_dcs[corr_data_source][\"train_raw_data\"]\n",
    "corr_data = mod_dcs[corr_data_source][\"train_corr_data\"]\n",
    "\n",
    "ax.plot(raw_data[odd_pulse], label=f\"Odd Pulse {odd_pulse}\")\n",
    "if even_pulse:\n",
    "    ax.plot(raw_data[even_pulse], label=f\"Even Pulse {even_pulse}\")\n",
    "\n",
    "ax.set_title(title.format(\"RAW\"), fontsize=20)\n",
    "ax.set_xlabel(\"Strip #\", size=20)\n",
    "ax.set_ylabel(\"12-bit ADC RAW\", size=20)\n",
    "plt.xticks(fontsize=20)\n",
    "plt.yticks(fontsize=20)\n",
    "ax.legend()\n",
    "pass\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(15, 15))\n",
    "ax.plot(corr_data[odd_pulse], label=f\"Odd Pulse {odd_pulse}\")\n",
    "if even_pulse:\n",
    "    ax.plot(corr_data[even_pulse], label=f\"Even Pulse {even_pulse}\")\n",
    "ax.set_title(title.format(\"CORRECTED\"), fontsize=20)\n",
    "ax.set_xlabel(\"Strip #\", size=20)\n",
    "ax.set_ylabel(\"10-bit KeV CORRECTED\", size=20)\n",
    "plt.xticks(fontsize=20)\n",
    "plt.yticks(fontsize=20)\n",
    "ax.legend()\n",
    "pass\n",
    "step_timer.done_step(\"Plotting RAW and CORRECTED odd/even pulses.\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".cal_venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}