{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# LPD Offline Correction #\n",
    "\n",
    "Author: European XFEL Data Analysis Group"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-12-03T15:19:56.056417Z",
     "start_time": "2018-12-03T15:19:56.003012Z"
    }
   },
   "outputs": [],
   "source": [
    "# Input parameters\n",
    "in_folder = \"/gpfs/exfel/exp/FXE/202201/p003073/raw/\"  # the folder to read data from, required\n",
    "out_folder = \"/gpfs/exfel/data/scratch/schmidtp/random/LPD_test\"  # the folder to output to, required\n",
    "metadata_folder = ''  # Directory containing calibration_metadata.yml when run by xfel-calibrate.\n",
    "sequences = [-1]  # Sequences to correct, use [-1] for all\n",
    "modules = [-1]  # Modules indices to correct, use [-1] for all, only used when karabo_da is empty\n",
    "karabo_da = ['']  # Data aggregators names to correct, use [''] for all\n",
    "run = 10  # run to process, required\n",
    "\n",
    "# Source parameters\n",
    "karabo_id = 'FXE_DET_LPD1M-1'  # Karabo domain for detector.\n",
    "input_source = '{karabo_id}/DET/{module_index}CH0:xtdf'  # Input fast data source.\n",
    "output_source = ''  # Output fast data source, empty to use same as input.\n",
    "xgm_source = 'SA1_XTD2_XGM/DOOCS/MAIN'\n",
    "xgm_pulse_count_key = 'pulseEnergy.numberOfSa1BunchesActual'\n",
    "\n",
    "# CalCat parameters\n",
    "creation_time = \"\"  # The timestamp to use with Calibration DB. Required Format: \"YYYY-MM-DD hh:mm:ss\" e.g. 2019-07-04 11:02:41\n",
    "cal_db_interface = ''  # Not needed, compatibility with current webservice.\n",
    "cal_db_timeout = 0  # Not needed, compatbility with current webservice.\n",
    "cal_db_root = '/gpfs/exfel/d/cal/caldb_store'  # The calibration database root path to access constant files. For example accessing constants from the test database.\n",
    "\n",
    "# Operating conditions\n",
    "mem_cells = 512  # Memory cells, LPD constants are always taken with 512 cells.\n",
    "bias_voltage = 250.0  # Detector bias voltage.\n",
    "capacitor = '5pF'  # Capacitor setting: 5pF or 50pF\n",
    "photon_energy = 9.2  # Photon energy in keV.\n",
    "category = 0  # Whom to blame.\n",
    "use_cell_order = 'auto'  # Whether to use memory cell order as a detector condition; auto/always/never\n",
    "\n",
    "# Correction parameters\n",
    "offset_corr = True  # Offset correction.\n",
    "rel_gain = True  # Gain correction based on RelativeGain constant.\n",
    "ff_map = True  # Gain correction based on FFMap constant.\n",
    "gain_amp_map = True  # Gain correction based on GainAmpMap constant.\n",
    "\n",
    "# Output options\n",
    "ignore_no_frames_no_pulses = False  # Whether to run without SA1 pulses AND frames.\n",
    "overwrite = True  # set to True if existing data should be overwritten\n",
    "chunks_data = 1  # HDF chunk size for pixel data in number of frames.\n",
    "chunks_ids = 32  # HDF chunk size for cellId and pulseId datasets.\n",
    "create_virtual_cxi_in = ''  # Folder to create virtual CXI files in (for each sequence).\n",
    "\n",
    "# Parallelization options\n",
    "sequences_per_node = 1  # Sequence files to process per node\n",
    "max_nodes = 8  # Maximum number of SLURM jobs to split correction work into\n",
    "num_workers = 8  # Worker processes per node, 8 is safe on 768G nodes but won't work on 512G.\n",
    "num_threads_per_worker = 32  # Number of threads per worker.\n",
    "\n",
    "def balance_sequences(in_folder, run, sequences, sequences_per_node, karabo_da, max_nodes):\n",
    "    from xfel_calibrate.calibrate import balance_sequences as bs\n",
    "    return bs(in_folder, run, sequences, sequences_per_node, karabo_da, max_nodes=max_nodes)"
   ]
  },
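  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As an illustrative aside (not used by the correction itself), the cell below shows how a data aggregator name maps to a module index and an expanded `input_source` name, mirroring the logic in the environment setup further down. The aggregator `LPD00` is a hypothetical example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative sketch only: expand the input_source template for one module.\n",
    "example_da = 'LPD00'  # Hypothetical aggregator name.\n",
    "example_index = int(example_da[-2:])  # Aggregator suffix -> module index.\n",
    "# With the defaults above this prints 'FXE_DET_LPD1M-1/DET/0CH0:xtdf'.\n",
    "print(input_source.format(karabo_id=karabo_id, module_index=example_index))"
   ]
  },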
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-12-03T15:19:56.990566Z",
     "start_time": "2018-12-03T15:19:56.058378Z"
    }
   },
   "outputs": [],
   "source": [
    "from logging import warning\n",
    "from pathlib import Path\n",
    "from time import perf_counter\n",
    "import gc\n",
    "import re\n",
    "\n",
    "import numpy as np\n",
    "import h5py\n",
    "\n",
    "import matplotlib\n",
    "matplotlib.use('agg')\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "\n",
    "import extra_data as xd\n",
    "import extra_geom as xg\n",
    "import pasha as psh\n",
    "from extra_data.components import LPD1M\n",
    "\n",
    "import cal_tools.restful_config as rest_cfg\n",
    "from cal_tools.calcat_interface import CalCatError, LPD_CalibrationData\n",
    "from cal_tools.lpdalgs import correct_lpd_frames\n",
    "from cal_tools.lpdlib import get_mem_cell_pattern, make_cell_order_condition\n",
    "from cal_tools.tools import (\n",
    "    CalibrationMetadata,\n",
    "    calcat_creation_time,\n",
    "    write_constants_fragment,\n",
    ")\n",
    "from cal_tools.files import DataFile"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Prepare environment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "file_re = re.compile(r'^RAW-R(\\d{4})-(\\w+\\d+)-S(\\d{5})$')  # This should probably move to cal_tools\n",
    "\n",
    "run_folder = Path(in_folder) / f'r{run:04d}'\n",
    "out_folder = Path(out_folder)\n",
    "out_folder.mkdir(exist_ok=True)\n",
    "\n",
    "output_source = output_source or input_source\n",
    "\n",
    "creation_time = calcat_creation_time(in_folder, run, creation_time)\n",
    "print(f'Using {creation_time.isoformat()} as creation time')\n",
    "\n",
    "# Pick all modules/aggregators or those selected.\n",
    "if karabo_da == ['']:\n",
    "    if modules == [-1]:\n",
    "        modules = list(range(16))\n",
    "    karabo_da = [f'LPD{i:02d}' for i in modules]\n",
    "else:\n",
    "    modules = [int(x[-2:]) for x in karabo_da]\n",
    "    \n",
    "# Pick all sequences or those selected.\n",
    "if not sequences or sequences == [-1]:\n",
    "    do_sequence = lambda seq: True\n",
    "else:\n",
    "    do_sequence = [int(x) for x in sequences].__contains__    \n",
    "    \n",
    "# List of detector sources.\n",
    "det_inp_sources = [input_source.format(karabo_id=karabo_id, module_index=int(da[-2:])) for da in karabo_da]\n",
    "\n",
    "if use_cell_order not in {'auto', 'always', 'never'}:\n",
    "    raise ValueError(\"use_cell_order must be auto/always/never\")"
   ]
  },
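  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick, purely informational illustration of the filename pattern handled by `file_re`; the stem below is a made-up example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: decompose a sample raw file stem with file_re.\n",
    "m = file_re.match('RAW-R0010-LPD00-S00000')  # Hypothetical stem.\n",
    "if m is not None:\n",
    "    print('run:', m[1], 'aggregator:', m[2], 'sequence:', m[3])"
   ]
  },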
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Select data to process"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_to_process = []\n",
    "\n",
    "for inp_path in run_folder.glob('RAW-*.h5'):\n",
    "    match = file_re.match(inp_path.stem)\n",
    "    \n",
    "    if match[2] not in karabo_da or not do_sequence(int(match[3])):\n",
    "        continue\n",
    "        \n",
    "    outp_path = out_folder / 'CORR-R{run:04d}-{aggregator}-S{seq:05d}.h5'.format(\n",
    "        run=int(match[1]), aggregator=match[2], seq=int(match[3]))\n",
    "\n",
    "    data_to_process.append((match[2], inp_path, outp_path))\n",
    "\n",
    "print('Files to process:')\n",
    "for data_descr in sorted(data_to_process, key=lambda x: f'{x[0]}{x[1]}'):\n",
    "    print(f'{data_descr[0]}\\t{data_descr[1]}')\n",
    "    \n",
    "# Collect the train ID contained in the input LPD files.\n",
    "inp_lpd_dc = xd.DataCollection.from_paths([x[1] for x in data_to_process])\n",
    "\n",
    "frame_count = sum([\n",
    "    int(inp_lpd_dc[source, 'image.data'].data_counts(labelled=False).sum())\n",
    "    for source in inp_lpd_dc.all_sources], 0)\n",
    "\n",
    "if frame_count == 0:\n",
    "    inp_dc = xd.RunDirectory(run_folder) \\\n",
    "        .select_trains(xd.by_id[inp_lpd_dc.train_ids])\n",
    "    \n",
    "    try:\n",
    "        pulse_count = int(inp_dc[xgm_source, xgm_pulse_count_key].ndarray().sum())\n",
    "    except xd.SourceNameError:\n",
    "        warning(f'Missing XGM source `{xgm_source}`')\n",
    "        pulse_count = None\n",
    "    except xd.PropertyNameError:\n",
    "        warning(f'Missing XGM pulse count key `{xgm_pulse_count_key}`')\n",
    "        pulse_count = None\n",
    "    \n",
    "    if pulse_count == 0 and not ignore_no_frames_no_pulses:\n",
    "        warning(f'Affected files contain neither LPD frames nor SA1 pulses '\n",
    "                f'according to {xgm_source}, processing is skipped. If this '\n",
    "                f'incorrect, please contact da-support@xfel.eu')\n",
    "        from sys import exit\n",
    "        exit(0)\n",
    "    elif pulse_count is None:\n",
    "        raise ValueError('Affected files contain no LPD frames and SA1 pulses '\n",
    "                         'could not be inferred from XGM data')\n",
    "    else:\n",
    "        raise ValueError('Affected files contain no LPD frames but SA1 pulses')\n",
    "        \n",
    "else:\n",
    "    print(f'Total number of LPD pulses across all modules: {frame_count}')"
   ]
  },
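  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For orientation, an optional sketch summarising how many files were selected per aggregator, based only on the `data_to_process` list built above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: count selected files per aggregator.\n",
    "from collections import Counter\n",
    "\n",
    "file_counts = Counter(mod for mod, _, _ in data_to_process)\n",
    "for mod, n in sorted(file_counts.items()):\n",
    "    print(f'{mod}: {n} file(s)')"
   ]
  },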
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Obtain and prepare calibration constants"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "start = perf_counter()\n",
    "\n",
    "cell_ids_pattern_s = None\n",
    "if use_cell_order != 'never':\n",
    "    # Read the order of memory cells used\n",
    "    raw_data = xd.DataCollection.from_paths([e[1] for e in data_to_process])\n",
    "    cell_ids_pattern_s = make_cell_order_condition(\n",
    "        use_cell_order, get_mem_cell_pattern(raw_data, det_inp_sources)\n",
    "    )\n",
    "print(\"Memory cells order:\", cell_ids_pattern_s)\n",
    "\n",
    "lpd_cal = LPD_CalibrationData(\n",
    "    detector_name=karabo_id,\n",
    "    modules=karabo_da,\n",
    "    sensor_bias_voltage=bias_voltage,\n",
    "    memory_cells=mem_cells,\n",
    "    feedback_capacitor=capacitor,\n",
    "    source_energy=photon_energy,\n",
    "    memory_cell_order=cell_ids_pattern_s,\n",
    "    category=category,\n",
    "    event_at=creation_time,\n",
    "    client=rest_cfg.calibration_client(),\n",
    "    caldb_root=Path(cal_db_root),\n",
    ")\n",
    "\n",
    "lpd_metadata = lpd_cal.metadata([\"Offset\", \"BadPixelsDark\"])\n",
    "try:\n",
    "    illum_metadata = lpd_cal.metadata(lpd_cal.illuminated_calibrations)\n",
    "    for key, value in illum_metadata.items():\n",
    "        lpd_metadata.setdefault(key, {}).update(value)\n",
    "except CalCatError as e:  # TODO: replace when API errors are improved.\n",
    "    warning(f\"CalCatError: {e}\")\n",
    "\n",
    "total_time = perf_counter() - start\n",
    "print(f'Looking up constants {total_time:.1f}s')"
   ]
  },
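  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "An optional overview of which calibrations were retrieved per module. This sketch relies only on `lpd_metadata` mapping aggregator names to dictionaries keyed by calibration name, as the validation cell below does."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: list the retrieved calibrations per module.\n",
    "for mod, calibrations in sorted(lpd_metadata.items()):\n",
    "    print(mod + ':', ', '.join(sorted(calibrations)))"
   ]
  },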
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Validate the constants availability and raise/warn accordingly.\n",
    "for mod, calibrations in lpd_metadata.items():\n",
    "    missing_offset = {\"Offset\"} - set(calibrations)\n",
    "    warn_missing_constants = {\n",
    "        \"BadPixelsDark\", \"BadPixelsFF\", \"GainAmpMap\",\n",
    "        \"FFMap\", \"RelativeGain\"} - set(calibrations)\n",
    "    if missing_offset:\n",
    "        warning(f\"Offset constant is not available to correct {mod}.\")\n",
    "        karabo_da.remove(mod)\n",
    "    if warn_missing_constants:\n",
    "        warning(f\"Constants {warn_missing_constants} were not retrieved for {mod}.\")\n",
    "if not karabo_da:  # Offsets are missing for all modules.\n",
    "    raise Exception(\"Could not find offset constants for any modules, will not correct data.\")\n",
    "\n",
    "# Remove skipped correction modules from data_to_process\n",
    "data_to_process = [(mod, in_f, out_f) for mod, in_f, out_f in data_to_process if mod in karabo_da]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# write constants metadata to fragment YAML\n",
    "write_constants_fragment(\n",
    "    out_folder=(metadata_folder or out_folder),\n",
    "    det_metadata=lpd_metadata,\n",
    "    caldb_root=lpd_cal.caldb_root,\n",
    ")\n",
    "\n",
    "# Load constants data for all constants\n",
    "const_data = lpd_cal.ndarray_map(metadata=lpd_metadata)"
   ]
  },
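  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see what was actually loaded, an optional sketch printing shape and dtype of each constant array. It assumes only the `aggregator -> {calibration: ndarray}` layout consumed by `prepare_constants()` below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: inspect raw constant arrays before they are transposed.\n",
    "for agg in sorted(const_data):\n",
    "    for name in sorted(const_data[agg]):\n",
    "        arr = const_data[agg][name]\n",
    "        print(agg, name, arr.shape, arr.dtype)"
   ]
  },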
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# These are intended in order cell, X, Y, gain\n",
    "ccv_offsets = {}\n",
    "ccv_gains = {}\n",
    "ccv_masks = {}\n",
    "\n",
    "ccv_shape = (mem_cells, 256, 256, 3)\n",
    "\n",
    "constant_order = {\n",
    "    'Offset':        (2, 1, 0, 3),\n",
    "    'BadPixelsDark': (2, 1, 0, 3),\n",
    "    'RelativeGain':  (2, 0, 1, 3),\n",
    "    'FFMap':         (2, 0, 1, 3),\n",
    "    'BadPixelsFF':   (2, 0, 1, 3),\n",
    "    'GainAmpMap':    (2, 0, 1, 3),\n",
    "}\n",
    "\n",
    "def prepare_constants(wid, index, aggregator):\n",
    "    consts = const_data.get(aggregator, {})\n",
    "    def _prepare_data(calibration_name, dtype):\n",
    "        # Some old BadPixels constants have <f8 dtype.\n",
    "        # Convert nan to float 0 to avoid having 2147483648 after\n",
    "        # converting float64 to uint32.\n",
    "        if \"BadPixels\" in calibration_name and consts[calibration_name].dtype != np.uint32:\n",
    "            consts[calibration_name] = np.nan_to_num(\n",
    "                consts[calibration_name], nan=0.0)\n",
    "        return consts[calibration_name] \\\n",
    "            .transpose(constant_order[calibration_name]) \\\n",
    "            .astype(dtype, copy=True)  # Make sure array is contiguous.\n",
    "    \n",
    "    if offset_corr and 'Offset' in consts:\n",
    "        ccv_offsets[aggregator] = _prepare_data('Offset', np.float32)\n",
    "    else:\n",
    "        ccv_offsets[aggregator] = np.zeros(ccv_shape, dtype=np.float32)\n",
    "        \n",
    "    ccv_gains[aggregator] = np.ones(ccv_shape, dtype=np.float32)\n",
    "    \n",
    "    if 'BadPixelsDark' in consts:\n",
    "        ccv_masks[aggregator] = _prepare_data('BadPixelsDark', np.uint32)\n",
    "    else:\n",
    "        ccv_masks[aggregator] = np.zeros(ccv_shape, dtype=np.uint32)\n",
    "    \n",
    "    if rel_gain and 'RelativeGain' in consts:\n",
    "        ccv_gains[aggregator] *= _prepare_data('RelativeGain', np.float32)\n",
    "        \n",
    "    if ff_map and 'FFMap' in consts:\n",
    "        ccv_gains[aggregator] *= _prepare_data('FFMap', np.float32)\n",
    "        \n",
    "        if 'BadPixelsFF' in consts:\n",
    "            np.bitwise_or(ccv_masks[aggregator], _prepare_data('BadPixelsFF', np.uint32),\n",
    "                          out=ccv_masks[aggregator])\n",
    "        \n",
    "    if gain_amp_map and 'GainAmpMap' in consts:\n",
    "        ccv_gains[aggregator] *= _prepare_data('GainAmpMap', np.float32)\n",
    "        \n",
    "    print('.', end='', flush=True)\n",
    "    \n",
    "\n",
    "print('Preparing constants', end='', flush=True)\n",
    "start = perf_counter()\n",
    "psh.ThreadContext(num_workers=len(karabo_da)).map(prepare_constants, karabo_da)\n",
    "total_time = perf_counter() - start\n",
    "print(f'{total_time:.1f}s')\n",
    "\n",
    "const_data.clear()  # Clear raw constants data now to save memory.\n",
    "gc.collect();"
   ]
  },
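  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "An optional sanity check, a sketch not required by the correction, verifying that every prepared constant matches `ccv_shape` and the dtypes passed to `correct_lpd_frames`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: verify prepared constants have the expected shape and dtype.\n",
    "for agg in karabo_da:\n",
    "    assert ccv_offsets[agg].shape == ccv_shape and ccv_offsets[agg].dtype == np.float32\n",
    "    assert ccv_gains[agg].shape == ccv_shape and ccv_gains[agg].dtype == np.float32\n",
    "    assert ccv_masks[agg].shape == ccv_shape and ccv_masks[agg].dtype == np.uint32\n",
    "print('All prepared constants match', ccv_shape)"
   ]
  },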
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def correct_file(wid, index, work):\n",
    "    aggregator, inp_path, outp_path = work\n",
    "    module_index = int(aggregator[-2:])\n",
    "    \n",
    "    start = perf_counter()\n",
    "    dc = xd.H5File(inp_path, inc_suspect_trains=False).select('*', 'image.*', require_all=True)\n",
    "    inp_source = dc[input_source.format(karabo_id=karabo_id, module_index=module_index)]\n",
    "    open_time = perf_counter() - start\n",
    "    \n",
    "    # Load raw data for this file.\n",
    "    # Reshaping gets rid of the extra 1-len dimensions without\n",
    "    # mangling the frame axis for an actual frame count of 1.\n",
    "    start = perf_counter()\n",
    "    in_raw = inp_source['image.data'].ndarray().reshape(-1, 256, 256)\n",
    "    in_cell = inp_source['image.cellId'].ndarray().reshape(-1)\n",
    "    in_pulse = inp_source['image.pulseId'].ndarray().reshape(-1)\n",
    "    read_time = perf_counter() - start\n",
    "    \n",
    "    # Allocate output arrays.\n",
    "    out_data = np.zeros((in_raw.shape[0], 256, 256), dtype=np.float32)\n",
    "    out_gain = np.zeros((in_raw.shape[0], 256, 256), dtype=np.uint8)\n",
    "    out_mask = np.zeros((in_raw.shape[0], 256, 256), dtype=np.uint32)\n",
    "            \n",
    "    start = perf_counter()\n",
    "    correct_lpd_frames(in_raw, in_cell,\n",
    "                       out_data, out_gain, out_mask,\n",
    "                       ccv_offsets[aggregator], ccv_gains[aggregator], ccv_masks[aggregator],\n",
    "                       num_threads=num_threads_per_worker)\n",
    "    correct_time = perf_counter() - start\n",
    "    \n",
    "    image_counts = inp_source['image.data'].data_counts(labelled=False)\n",
    "    \n",
    "    start = perf_counter()\n",
    "    if (not outp_path.exists() or overwrite) and image_counts.sum() > 0:\n",
    "        outp_source_name = output_source.format(karabo_id=karabo_id, module_index=module_index)\n",
    "\n",
    "        with DataFile(outp_path, 'w') as outp_file:            \n",
    "            outp_file.create_index(dc.train_ids, from_file=dc.files[0])\n",
    "            outp_file.create_metadata(like=dc, instrument_channels=(f'{outp_source_name}/image',))\n",
    "            \n",
    "            outp_source = outp_file.create_instrument_source(outp_source_name)\n",
    "            \n",
    "            outp_source.create_index(image=image_counts)\n",
    "            outp_source.create_key('image.cellId', data=in_cell,\n",
    "                                   chunks=(min(chunks_ids, in_cell.shape[0]),))\n",
    "            outp_source.create_key('image.pulseId', data=in_pulse,\n",
    "                                   chunks=(min(chunks_ids, in_pulse.shape[0]),))\n",
    "            outp_source.create_key('image.data', data=out_data,\n",
    "                                   chunks=(min(chunks_data, out_data.shape[0]), 256, 256))\n",
    "            outp_source.create_compressed_key('image.gain', data=out_gain)\n",
    "            outp_source.create_compressed_key('image.mask', data=out_mask)\n",
    "    write_time = perf_counter() - start\n",
    "    \n",
    "    total_time = open_time + read_time + correct_time + write_time\n",
    "    frame_rate = in_raw.shape[0] / total_time\n",
    "    \n",
    "    print('{}\\t{}\\t{:.3f}\\t{:.3f}\\t{:.3f}\\t{:.3f}\\t{:.3f}\\t{}\\t{:.1f}'.format(\n",
    "        wid, aggregator, open_time, read_time, correct_time, write_time, total_time,\n",
    "        in_raw.shape[0], frame_rate))\n",
    "    \n",
    "    in_raw = None\n",
    "    in_cell = None\n",
    "    in_pulse = None\n",
    "    out_data = None\n",
    "    out_gain = None\n",
    "    out_mask = None\n",
    "    gc.collect()\n",
    "\n",
    "print('worker\\tDA\\topen\\tread\\tcorrect\\twrite\\ttotal\\tframes\\trate')\n",
    "start = perf_counter()\n",
    "psh.ProcessContext(num_workers=num_workers).map(correct_file, data_to_process)\n",
    "total_time = perf_counter() - start\n",
    "print(f'Total time: {total_time:.1f}s')"
   ]
  },
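  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To confirm what was written, an optional sketch re-opening the first existing corrected file with EXtra-data and listing its sources and keys."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: list sources and keys of the first corrected file, if any exists.\n",
    "existing_outputs = [outp for _, _, outp in data_to_process if outp.exists()]\n",
    "\n",
    "if existing_outputs:\n",
    "    sample_dc = xd.H5File(existing_outputs[0])\n",
    "    for source in sorted(sample_dc.all_sources):\n",
    "        print(source, sorted(sample_dc[source].keys()))"
   ]
  },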
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Data preview for first train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "geom = xg.LPD_1MGeometry.from_quad_positions(\n",
    "    [(11.4, 299), (-11.5, 8), (254.5, -16), (278.5, 275)])\n",
    "\n",
    "output_paths = [outp_path for _, _, outp_path in data_to_process if outp_path.exists()]\n",
    "\n",
    "if not output_paths:\n",
    "    warning('Data preview is skipped as there are no existing output paths')\n",
    "    from sys import exit\n",
    "    exit(0)\n",
    "\n",
    "dc = xd.DataCollection.from_paths(output_paths).select_trains(np.s_[0])\n",
    "\n",
    "det = LPD1M(dc, detector_name=karabo_id)\n",
    "data = det.get_array('image.data')"
   ]
  },
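  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`LPD1M.get_array()` returns a labelled array whose leading dimensions are module, train and pulse; printing them clarifies the indexing used in the plots below (informational sketch)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: show dimension names and sizes of the preview data.\n",
    "print(data.dims, data.shape)"
   ]
  },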
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Intensity histogram across all cells"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "left_edge_ratio = 0.01\n",
    "right_edge_ratio = 0.99\n",
    "\n",
    "fig, ax = plt.subplots(num=1, clear=True, figsize=(15, 6))\n",
    "values, bins, _ = ax.hist(np.ravel(data.data), bins=2000, range=(-1500, 2000))\n",
    "\n",
    "def find_nearest_index(array, value):\n",
    "    return (np.abs(array - value)).argmin()\n",
    "\n",
    "cum_values = np.cumsum(values)\n",
    "vmin = bins[find_nearest_index(cum_values, cum_values[-1]*left_edge_ratio)]\n",
    "vmax = bins[find_nearest_index(cum_values, cum_values[-1]*right_edge_ratio)]\n",
    "\n",
    "max_value = values.max()\n",
    "ax.vlines([vmin, vmax], 0, max_value, color='red', linewidth=5, alpha=0.2)\n",
    "ax.text(vmin, max_value, f'{left_edge_ratio*100:.0f}%',\n",
    "        color='red', ha='center', va='bottom', size='large')\n",
    "ax.text(vmax, max_value, f'{right_edge_ratio*100:.0f}%',\n",
    "        color='red', ha='center', va='bottom', size='large')\n",
    "ax.text(vmax+(vmax-vmin)*0.01, max_value/2, 'Colormap interval',\n",
    "        color='red', rotation=90, ha='left', va='center', size='x-large')\n",
    "\n",
    "ax.set_xlim(vmin-(vmax-vmin)*0.1, vmax+(vmax-vmin)*0.1)\n",
    "ax.set_ylim(0, max_value*1.1)\n",
    "pass"
   ]
  },
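  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `vmin`/`vmax` above approximate the 1% and 99% points of the intensity distribution via the cumulative histogram. For comparison, a sketch of the direct percentiles; values can differ slightly, since the histogram clips to its range."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: direct percentiles for comparison with the histogram-based limits.\n",
    "flat = np.ravel(data.data)\n",
    "print(np.nanpercentile(flat, [left_edge_ratio * 100, right_edge_ratio * 100]))"
   ]
  },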
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### First memory cell"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(num=2, figsize=(15, 15), clear=True, nrows=1, ncols=1)\n",
    "geom.plot_data_fast(data[:, 0, 0], ax=ax, vmin=vmin, vmax=vmax)\n",
    "pass"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Train average"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2018-11-13T18:24:57.547563Z",
     "start_time": "2018-11-13T18:24:56.995005Z"
    },
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(num=3, figsize=(15, 15), clear=True, nrows=1, ncols=1)\n",
    "geom.plot_data_fast(data[:, 0].mean(axis=1), ax=ax, vmin=vmin, vmax=vmax)\n",
    "pass"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Lowest gain stage per pixel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "highest_gain_stage = det.get_array('image.gain', pulses=np.s_[:]).max(axis=(1, 2))\n",
    "\n",
    "fig, ax = plt.subplots(num=4, figsize=(15, 15), clear=True, nrows=1, ncols=1)\n",
    "p = geom.plot_data_fast(highest_gain_stage, ax=ax, vmin=0, vmax=2);\n",
    "\n",
    "cb = ax.images[0].colorbar\n",
    "cb.set_ticks([0, 1, 2])\n",
    "cb.set_ticklabels(['High gain', 'Medium gain', 'Low gain'])"
   ]
  },
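  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To quantify the map above, an optional sketch counting the fraction of pixels per lowest reached gain stage, assuming the gain codes 0/1/2 labelled on the colorbar."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: fraction of pixels per lowest reached gain stage.\n",
    "stage_counts = np.bincount(\n",
    "    lowest_gain_stage.data.ravel().astype(np.int64), minlength=3)\n",
    "for code, label in enumerate(['High gain', 'Medium gain', 'Low gain']):\n",
    "    print(f'{label}: {stage_counts[code] / stage_counts.sum():.1%}')"
   ]
  },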
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create virtual CXI file"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if create_virtual_cxi_in:\n",
    "    vcxi_folder = Path(create_virtual_cxi_in.format(\n",
    "        run=run, proposal_folder=str(Path(in_folder).parent)))\n",
    "    vcxi_folder.mkdir(parents=True, exist_ok=True)\n",
    "    \n",
    "    def sort_files_by_seq(by_seq, outp_path):\n",
    "        by_seq.setdefault(int(outp_path.stem[-5:]), []).append(outp_path)\n",
    "        return by_seq\n",
    "    \n",
    "    from functools import reduce\n",
    "    reduce(sort_files_by_seq, output_paths, output_by_seq := {})\n",
    "        \n",
    "    for seq_number, seq_output_paths in output_by_seq.items():\n",
    "        # Create data collection and detector components only for this sequence.\n",
    "        try:\n",
    "            det = LPD1M(xd.DataCollection.from_paths(seq_output_paths), detector_name=karabo_id, min_modules=4)\n",
    "        except ValueError:  # Couldn't find enough data for min_modules\n",
    "            continue\n",
    "        det.write_virtual_cxi(vcxi_folder / f'VCXI-LPD-R{run:04d}-S{seq_number:05d}.cxi')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pycal",
   "language": "python",
   "name": "pycal"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}