diff --git a/notebooks/LPD/LPD_Correct_Fast.ipynb b/notebooks/LPD/LPD_Correct_Fast.ipynb index 04d35400aafeaa9197d79ff6a11439d697c84df0..728b3817420190b4540a72b8a06c7f0e92167aea 100644 --- a/notebooks/LPD/LPD_Correct_Fast.ipynb +++ b/notebooks/LPD/LPD_Correct_Fast.ipynb @@ -33,6 +33,7 @@ "karabo_id = 'FXE_DET_LPD1M-1' # Karabo domain for detector.\n", "input_source = '{karabo_id}/DET/{module_index}CH0:xtdf' # Input fast data source.\n", "output_source = '{karabo_id}/CORR/{module_index}CH0:output' # Output fast data source, empty to use same as input.\n", + "control_source = '{karabo_id}/COMP/FEM_MDL_COMP' # Control data source.\n", "xgm_source = 'SA1_XTD2_XGM/DOOCS/MAIN'\n", "xgm_pulse_count_key = 'pulseEnergy.numberOfSa1BunchesActual'\n", "\n", @@ -55,6 +56,9 @@ "rel_gain = True # Gain correction based on RelativeGain constant.\n", "ff_map = True # Gain correction based on FFMap constant.\n", "gain_amp_map = True # Gain correction based on GainAmpMap constant.\n", + "combine_parallel_gain = True # Combine parallel gain images into a single frame.\n", + "threshold_sigma_high = 5.0 # Sigma level for threshold between high and medium gain.\n", + "threshold_sigma_mid = 100.0 # Sigma level for threshold between medium and low gain.\n", "\n", "# Output options\n", "ignore_no_frames_no_pulses = False # Whether to run without SA1 pulses AND frames.\n", @@ -85,9 +89,9 @@ }, "outputs": [], "source": [ - "from logging import warning\n", "from pathlib import Path\n", "from time import perf_counter\n", + "from warnings import warn\n", "import gc\n", "import re\n", "\n", @@ -204,16 +208,16 @@ " try:\n", " pulse_count = int(inp_dc[xgm_source, xgm_pulse_count_key].ndarray().sum())\n", " except xd.SourceNameError:\n", - " warning(f'Missing XGM source `{xgm_source}`')\n", + " warn(f'Missing XGM source `{xgm_source}`')\n", " pulse_count = None\n", " except xd.PropertyNameError:\n", - " warning(f'Missing XGM pulse count key `{xgm_pulse_count_key}`')\n", + " warn(f'Missing XGM pulse count key `{xgm_pulse_count_key}`')\n", " pulse_count = None\n", " \n", " if pulse_count == 0 and not ignore_no_frames_no_pulses:\n", - " warning(f'Affected files contain neither LPD frames nor SA1 pulses '\n", - " f'according to {xgm_source}, processing is skipped. If this '\n", - " f'incorrect, please contact da-support@xfel.eu')\n", + " warn(f'Affected files contain neither LPD frames nor SA1 pulses '\n", + " f'according to {xgm_source}, processing is skipped. If this '\n", + " f'incorrect, please contact da-support@xfel.eu')\n", " from sys import exit\n", " exit(0)\n", " elif pulse_count is None:\n", @@ -244,13 +248,24 @@ "source": [ "start = perf_counter()\n", "\n", + "raw_data = xd.RunDirectory(run_folder)\n", + "\n", + "try:\n", + " parallel_gain = bool(raw_data[control_source.format(karabo_id=karabo_id)].run_value('femAsicGainOverride'))\n", + "except KeyError:\n", + " warn('Missing femAsicGainOverride property FEM control device, assuming auto gain')\n", + " parallel_gain = False\n", + "print('Parallel gain mode:', parallel_gain)\n", + "\n", "cell_ids_pattern_s = None\n", "if use_cell_order != 'never':\n", + " mem_cell_pattern = get_mem_cell_pattern(raw_data, det_inp_sources)\n", + " \n", + " if parallel_gain:\n", + " mem_cell_pattern = mem_cell_pattern[:len(mem_cell_pattern) // 3]\n", + " \n", " # Read the order of memory cells used\n", - " raw_data = xd.DataCollection.from_paths([e[1] for e in data_to_process])\n", - " cell_ids_pattern_s = make_cell_order_condition(\n", - " use_cell_order, get_mem_cell_pattern(raw_data, det_inp_sources)\n", - " )\n", + " cell_ids_pattern_s = make_cell_order_condition(use_cell_order, mem_cell_pattern)\n", "print(\"Memory cells order:\", cell_ids_pattern_s)\n", "\n", "conditions = LPDConditions(\n", @@ -259,6 +274,7 @@ " feedback_capacitor=capacitor,\n", " source_energy=photon_energy,\n", " memory_cell_order=cell_ids_pattern_s,\n", + " parallel_gain=parallel_gain,\n", " category=category,\n", ")\n", "\n", @@ -269,6 +285,8 @@ " expected_constants.update(['FFMap', 'BadPixelsFF'])\n", "if gain_amp_map:\n", " expected_constants.add('GainAmpMap')\n", + "if parallel_gain and combine_parallel_gain:\n", + " expected_constants.add('Noise')\n", "\n", "lpd_consts = CalibrationData.from_condition(\n", " conditions,\n", @@ -300,13 +318,13 @@ " \n", "for mod in karabo_da.copy():\n", " if mod not in lpd_consts[\"Offset\"].aggregator_names:\n", - " warning(f\"Offset constant is not available to correct {mod}.\")\n", + " warn(f\"Offset constant is not available to correct {mod}.\")\n", " karabo_da.remove(mod)\n", " \n", " missing_constants = {c for c in expected_constants\n", " if (c not in lpd_consts) or (mod not in lpd_consts[c].aggregator_names)}\n", " if missing_constants:\n", - " warning(f\"Constants {sorted(missing_constants)} were not retrieved for {mod}.\")\n", + " warn(f\"Constants {sorted(missing_constants)} were not retrieved for {mod}.\")\n", "\n", "# Remove skipped correction modules from data_to_process\n", "data_to_process = [(mod, in_f, out_f) for mod, in_f, out_f in data_to_process if mod in karabo_da]" @@ -345,6 +363,7 @@ "source": [ "# These are intended in order cell, X, Y, gain\n", "ccv_offsets = {}\n", + "ccv_noise = {}\n", "ccv_gains = {}\n", "ccv_masks = {}\n", "\n", @@ -352,6 +371,7 @@ "\n", "constant_order = {\n", " 'Offset': (2, 1, 0, 3),\n", + " 'Noise': (2, 1, 0, 3),\n", " 'BadPixelsDark': (2, 1, 0, 3),\n", " 'RelativeGain': (2, 0, 1, 3),\n", " 'FFMap': (2, 0, 1, 3),\n", @@ -378,6 +398,14 @@ " ccv_offsets[aggregator] = np.zeros(ccv_shape, dtype=np.float32)\n", " \n", " ccv_gains[aggregator] = np.ones(ccv_shape, dtype=np.float32)\n", + "\n", + " if parallel_gain and combine_parallel_gain:\n", + " if 'Noise' in consts:\n", + " ccv_noise[aggregator] = _prepare_data('Noise', np.float32)\n", + " else:\n", + " raise RuntimeError('parallel gain combination requires available noise constant')\n", + " else:\n", + " ccv_noise[aggregator] = None\n", " \n", " if 'BadPixelsDark' in consts:\n", " ccv_masks[aggregator] = _prepare_data('BadPixelsDark', np.uint32)\n", @@ -396,15 +424,12 @@ " \n", " if 'GainAmpMap' in consts:\n", " ccv_gains[aggregator] *= _prepare_data('GainAmpMap', np.float32)\n", - " \n", - " print('.', end='', flush=True)\n", " \n", "\n", - "print('Preparing constants', end='', flush=True)\n", "start = perf_counter()\n", "psh.ThreadContext(num_workers=len(karabo_da)).map(prepare_constants, karabo_da)\n", "total_time = perf_counter() - start\n", - "print(f'{total_time:.1f}s')\n", + "print(f'Preparing constants {total_time:.1f}s')\n", "\n", "const_data.clear() # Clear raw constants data now to save memory.\n", "gc.collect();" @@ -416,6 +441,56 @@ "metadata": {}, "outputs": [], "source": [ + "def iter_count_slices(offset_counts, len_counts=None, step=None):\n", + " \"\"\"Generate slices to index another array based on counts.\n", + "\n", + " Given an array of counts C dividing another flat array A into\n", + " different parts such that C.sum() == A.size, this generates the\n", + " necessary slices to iterate over each part defined by C:\n", + "\n", + " ```\n", + " A = np.arange(15)\n", + " C = np.array([5, 5, 5])\n", + " list(iter_count_slices(C))\n", + " > [slice(0, 5, None), slice(5, 10, None), slice(10, 15, None)]\n", + " ```\n", + "\n", + " The counts used to compute the slice starts, i.e. the offsets\n", + " into A, can be chosen independently of the length of each slice:\n", + "\n", + " ```\n", + " list(iter_count_slices([15, 15, 15], [5, 5, 5]))\n", + " > [slice(15, 20, None), slice(30, 35, None), slice(45, 50, None)]\n", + " ```\n", + "\n", + " Args:\n", + " offset_counts (ArrayLike): Counts used to compute slice starts.\n", + " len_counts (ArrayLike, optional): Counts used to compute slice\n", + " lengths, offset_counts used if omitted.\n", + " step (int, optional): Slice step, None if omitted.\n", + "\n", + " Yields:\n", + " s (slice): Count-based slices for indexing.\n", + " \"\"\"\n", + " \n", + " offset_counts = np.asarray(offset_counts)\n", + " \n", + " if offset_counts.size == 0:\n", + " return\n", + " elif len_counts is None:\n", + " len_counts = offset_counts\n", + " else:\n", + " len_counts = np.asarray(len_counts)\n", + " \n", + " if offset_counts.size != len_counts.size:\n", + " raise ValueError('size of count arrays must match')\n", + "\n", + " yield np.s_[0:len_counts[0]:step]\n", + "\n", + " for offset, count in zip(np.cumsum(offset_counts)[:-1], len_counts[1:]):\n", + " yield np.s_[offset:offset+count:step]\n", + "\n", + "\n", "def correct_file(wid, index, work):\n", " aggregator, inp_path, outp_path = work\n", " module_index = int(aggregator[-2:])\n", @@ -433,24 +508,71 @@ " in_raw = inp_source['image.data'].ndarray().reshape(-1, 256, 256)\n", " in_cell = inp_source['image.cellId'].ndarray().reshape(-1)\n", " in_pulse = inp_source['image.pulseId'].ndarray().reshape(-1)\n", + " frame_counts = inp_source['image.data'].data_counts(labelled=False).astype(np.int32)\n", " read_time = perf_counter() - start\n", + "\n", + " parallel_gain_indices = None\n", + " \n", + " if parallel_gain:\n", + " assert (frame_counts % 3 == 0).all(), 'frame count not divisible by 3 in parallel gain mode'\n", + " actual_frame_counts = frame_counts // 3\n", + "\n", + " # Indices map where to find each of the high/medium/low gain images for each actual\n", + " # frame event.\n", + " parallel_gain_indices = np.zeros((actual_frame_counts.sum(), 3), dtype=np.int32)\n", + "\n", + " # Build indices for high gain as a range in each train, running from the cumulative sum\n", + " # of apparent frames from all trains before to the actual number of frames in this train.\n", + " np.concatenate([np.r_[s] for s in iter_count_slices(frame_counts, actual_frame_counts)],\n", + " out=parallel_gain_indices[:, 0])\n", + "\n", + " # The delta between the gain stages is the number of actual frames.\n", + " gain_index_deltas = np.repeat(actual_frame_counts, actual_frame_counts)\n", + "\n", + " # Build indices for medium gain and high gain by adding the gain index deltas in between\n", + " # each of them.\n", + " np.add(parallel_gain_indices[:, 0], gain_index_deltas, out=parallel_gain_indices[:, 1])\n", + " np.add(parallel_gain_indices[:, 1], gain_index_deltas, out=parallel_gain_indices[:, 2])\n", + "\n", + " assert parallel_gain_indices.max() <= in_raw.shape[0], 'gain image indices exceed raw data size'\n", + "\n", + " # Pick cell and pulse IDs from high gain. This is also done if frames are not combined\n", + " # in order to correct corrupt tables in medium and low gain, and if needed brought back\n", + " # to the original shape further below.\n", + " in_cell = np.take(in_cell, parallel_gain_indices[:, 0])\n", + " in_pulse = np.take(in_pulse, parallel_gain_indices[:, 0])\n", + "\n", + " if combine_parallel_gain:\n", + " # Replace supposed frame counts by actual frame counts.\n", + " frame_counts = actual_frame_counts\n", + " else:\n", + " # Replicate corrected cell and pulse IDs from high gain to other gains.\n", + " in_cell = np.concatenate([\n", + " np.tile(in_cell[s], 3) for s \n", + " in iter_count_slices(actual_frame_counts)])\n", + " in_pulse = np.concatenate([\n", + " np.tile(in_pulse[s], 3) for s\n", + " in iter_count_slices(actual_frame_counts)])\n", + " \n", + " # Disable gain indices to not combine.\n", + " parallel_gain_indices = None\n", " \n", " # Allocate output arrays.\n", - " out_data = np.zeros((in_raw.shape[0], 256, 256), dtype=np.float32)\n", - " out_gain = np.zeros((in_raw.shape[0], 256, 256), dtype=np.uint8)\n", - " out_mask = np.zeros((in_raw.shape[0], 256, 256), dtype=np.uint32)\n", + " num_frames = frame_counts.sum()\n", + " out_data = np.zeros((num_frames, 256, 256), dtype=np.float32)\n", + " out_gain = np.zeros((num_frames, 256, 256), dtype=np.uint8)\n", + " out_mask = np.zeros((num_frames, 256, 256), dtype=np.uint32)\n", " \n", " start = perf_counter()\n", " correct_lpd_frames(in_raw, in_cell,\n", " out_data, out_gain, out_mask,\n", - " ccv_offsets[aggregator], ccv_gains[aggregator], ccv_masks[aggregator],\n", - " num_threads=num_threads_per_worker)\n", + " ccv_offsets[aggregator], ccv_noise[aggregator], ccv_gains[aggregator], ccv_masks[aggregator],\n", + " parallel_gain_indices, threshold_sigma_high, threshold_sigma_mid,\n", + " num_threads=16)\n", " correct_time = perf_counter() - start\n", " \n", - " image_counts = inp_source['image.data'].data_counts(labelled=False)\n", - " \n", " start = perf_counter()\n", - " if (not outp_path.exists() or overwrite) and image_counts.sum() > 0:\n", + " if (not outp_path.exists() or overwrite) and num_frames > 0:\n", " outp_source_name = output_source.format(karabo_id=karabo_id, module_index=module_index)\n", "\n", " with DataFile(outp_path, 'w') as outp_file: \n", @@ -461,7 +583,7 @@ " \n", " outp_source = outp_file.create_instrument_source(outp_source_name)\n", " \n", - " outp_source.create_index(image=image_counts)\n", + " outp_source.create_index(image=frame_counts)\n", " outp_source.create_key('image.cellId', data=in_cell,\n", " chunks=(min(chunks_ids, in_cell.shape[0]),))\n", " outp_source.create_key('image.pulseId', data=in_pulse,\n", @@ -477,11 +599,13 @@ " write_time = perf_counter() - start\n", " \n", " total_time = open_time + read_time + correct_time + write_time\n", - " frame_rate = in_raw.shape[0] / total_time\n", + " frame_rate = num_frames / total_time\n", " \n", " print('{}\\t{}\\t{:.3f}\\t{:.3f}\\t{:.3f}\\t{:.3f}\\t{:.3f}\\t{}\\t{:.1f}'.format(\n", " wid, aggregator, open_time, read_time, correct_time, write_time, total_time,\n", - " in_raw.shape[0], frame_rate))\n", + " num_frames, frame_rate))\n", + "\n", + " worker_frame_counts[wid] += num_frames\n", " \n", " in_raw = None\n", " in_cell = None\n", @@ -492,10 +616,15 @@ " gc.collect()\n", "\n", "print('worker\\tDA\\topen\\tread\\tcorrect\\twrite\\ttotal\\tframes\\trate')\n", + "ctx = psh.ProcessContext(num_workers=num_workers)\n", + "\n", + "worker_frame_counts = ctx.alloc(shape=(), dtype=np.int32, per_worker=True)\n", "start = perf_counter()\n", - "psh.ProcessContext(num_workers=num_workers).map(correct_file, data_to_process)\n", + "ctx.map(correct_file, data_to_process)\n", "total_time = perf_counter() - start\n", - "print(f'Total time: {total_time:.1f}s')" + "total_frames = worker_frame_counts.sum()\n", + "\n", + "print(f'Total time: {total_time:.1f}s, Mean rate: {(total_frames / total_time):.1f}sâ»Â¹')" ] }, { @@ -517,14 +646,14 @@ "output_paths = [outp_path for _, _, outp_path in data_to_process if outp_path.exists()]\n", "\n", "if not output_paths:\n", - " warning('Data preview is skipped as there are no existing output paths')\n", + " warn('Data preview is skipped as there are no existing output paths')\n", " from sys import exit\n", " exit(0)\n", "\n", "dc = xd.DataCollection.from_paths(output_paths).select_trains(np.s_[0])\n", "\n", "det = LPD1M(dc, detector_name=karabo_id)\n", - "data = det.get_array('image.data')" + "data = det.get_array('image.data', unstack_pulses=False)" ] }, { @@ -581,7 +710,7 @@ "outputs": [], "source": [ "fig, ax = plt.subplots(num=2, figsize=(15, 15), clear=True, nrows=1, ncols=1)\n", - "geom.plot_data_fast(data[:, 0, 0], ax=ax, vmin=vmin, vmax=vmax)\n", + "geom.plot_data_fast(data[:, 0], ax=ax, vmin=vmin, vmax=vmax)\n", "pass" ] }, @@ -599,13 +728,12 @@ "ExecuteTime": { "end_time": "2018-11-13T18:24:57.547563Z", "start_time": "2018-11-13T18:24:56.995005Z" - }, - "scrolled": false + } }, "outputs": [], "source": [ "fig, ax = plt.subplots(num=3, figsize=(15, 15), clear=True, nrows=1, ncols=1)\n", - "geom.plot_data_fast(data[:, 0].mean(axis=1), ax=ax, vmin=vmin, vmax=vmax)\n", + "geom.plot_data_fast(data.mean(axis=1), ax=ax, vmin=vmin, vmax=vmax)\n", "pass" ] }, @@ -622,7 +750,7 @@ "metadata": {}, "outputs": [], "source": [ - "highest_gain_stage = det.get_array('image.gain', pulses=np.s_[:]).max(axis=(1, 2))\n", + "highest_gain_stage = det.get_array('image.gain', unstack_pulses=False).max(axis=1)\n", "\n", "fig, ax = plt.subplots(num=4, figsize=(15, 15), clear=True, nrows=1, ncols=1)\n", "p = geom.plot_data_fast(highest_gain_stage, ax=ax, vmin=0, vmax=2);\n", @@ -645,7 +773,7 @@ "metadata": {}, "outputs": [], "source": [ - "if create_virtual_cxi_in:\n", + "if create_virtual_cxi_in and not (parallel_gain and not combine_parallel_gain):\n", " vcxi_folder = Path(create_virtual_cxi_in.format(\n", " run=run, proposal_folder=str(Path(in_folder).parent)))\n", " vcxi_folder.mkdir(parents=True, exist_ok=True)\n", diff --git a/src/cal_tools/calcat_interface2.py b/src/cal_tools/calcat_interface2.py index 5643ce76814fd407b5028dd3307fd9d5b74c33e6..d662f2672841257ae9787331cca8075823067959 100644 --- a/src/cal_tools/calcat_interface2.py +++ b/src/cal_tools/calcat_interface2.py @@ -1032,10 +1032,9 @@ class LPDConditions(ConditionsBase): "Pixels X", "Pixels Y", "Feedback capacitor", - "Parallel gain", ] _dark_parameters = _base_params + [ - "Memory cell order", + "Memory cell order", "Parallel gain" ] _illuminated_parameters = _base_params + ["Source Energy", "category"] diff --git a/src/cal_tools/lpdalgs.pyx b/src/cal_tools/lpdalgs.pyx index 66b8b097c71a4a7cb6b1681e4678e0f04c2f43e7..6c746172de4cfc109d20c0bf3e3c780ee35e9d3d 100644 --- a/src/cal_tools/lpdalgs.pyx +++ b/src/cal_tools/lpdalgs.pyx @@ -1,13 +1,14 @@ + from cython cimport boundscheck, wraparound, cdivision from cython.view cimport contiguous from cython.parallel cimport prange from cal_tools.enums import BadPixels -ctypedef unsigned short cell_t ctypedef unsigned short raw_t ctypedef float data_t ctypedef unsigned char gain_t +ctypedef unsigned short cell_t ctypedef unsigned int mask_t cdef mask_t WRONG_GAIN_VALUE = BadPixels.WRONG_GAIN_VALUE, \ @@ -17,7 +18,6 @@ cdef mask_t WRONG_GAIN_VALUE = BadPixels.WRONG_GAIN_VALUE, \ @boundscheck(False) @wraparound(False) -@cdivision(True) def correct_lpd_frames( # (frame, x, y) raw_t[:, :, ::contiguous] in_raw, @@ -29,10 +29,17 @@ def correct_lpd_frames( mask_t[:, :, ::contiguous] out_mask, # (cell, x, y, gain) - float[:, :, :, ::contiguous] ccv_offset, - float[:, :, :, ::contiguous] ccv_gain, + data_t[:, :, :, ::contiguous] ccv_offset, + data_t[:, :, :, ::contiguous] ccv_noise, + data_t[:, :, :, ::contiguous] ccv_gain, mask_t[:, :, :, ::contiguous] ccv_mask, + # (frame, gain) + int[:, ::contiguous] parallel_gain_indices, + + data_t threshold_sigma_high, + data_t threshold_sigma_mid, + int num_threads=1, ): cdef int frame, ss, fs @@ -41,25 +48,34 @@ def correct_lpd_frames( cdef gain_t gain cdef mask_t mask - for frame in prange(in_raw.shape[0], nogil=True, num_threads=num_threads): + cdef bint adaptive_gain = parallel_gain_indices is None + + for frame in prange(out_data.shape[0], nogil=True, num_threads=num_threads): cell = in_cell[frame] for ss in range(in_raw.shape[1]): for fs in range(in_raw.shape[2]): - # Decode intensity and gain from raw data. - data = <data_t>(in_raw[frame, ss, fs] & 0xFFF) - gain = <gain_t>((in_raw[frame, ss, fs] & 0x3000) >> 12) + if adaptive_gain: + # Decode intensity and gain from raw data. + data = <data_t>(in_raw[frame, ss, fs] & 0xFFF) + gain = <gain_t>((in_raw[frame, ss, fs] & 0x3000) >> 12) + + else: + _parallel_gain_thresholding( + in_raw, parallel_gain_indices, ccv_noise, + threshold_sigma_high, threshold_sigma_mid, + frame, cell, ss, fs, + &data, &gain) if gain <= 2: + data = data - ccv_offset[cell, ss, fs, gain] + data = data * ccv_gain[cell, ss, fs, gain] mask = ccv_mask[cell, ss, fs, gain] else: data = 0.0 gain = 0 mask = WRONG_GAIN_VALUE - data = data - ccv_offset[cell, ss, fs, gain] - data = data * ccv_gain[cell, ss, fs, gain] - if data > 1e7 or data < -1e7: data = 0.0 mask = mask | VALUE_OUT_OF_RANGE @@ -67,3 +83,58 @@ def correct_lpd_frames( out_data[frame, ss, fs] = data out_gain[frame, ss, fs] = gain out_mask[frame, ss, fs] = mask + + +@boundscheck(False) +@wraparound(False) +cdef inline void _parallel_gain_thresholding( + raw_t[:, :, ::contiguous] in_raw, + int[:, ::contiguous] parallel_gain_indices, + data_t[:, :, :, ::contiguous] ccv_noise, + data_t sigma_high, data_t sigma_mid, + int frame, int cell, int ss, int fs, + data_t* data_ptr, gain_t* gain_ptr +) noexcept nogil: + cdef int frame_high, frame_mid, frame_low + cdef data_t data_high, data_mid, data_low + cdef data_t threshold_high, threshold_mid + + # Obtain indices to this pixel in each of three gain images. + frame_high = parallel_gain_indices[frame, 0] + frame_mid = parallel_gain_indices[frame, 1] + frame_low = parallel_gain_indices[frame, 2] + + if ( + ((in_raw[frame_high, ss, fs] & 0x3000) >> 12) == 0 and + ((in_raw[frame_mid, ss, fs] & 0x3000) >> 12) == 1 and + ((in_raw[frame_low, ss, fs] & 0x3000) >> 12) == 2 + ): + # Verify that this pixel is recorded in the correct gain stage + # in each of the gain images. Memory cells in the transition + # regions between the gains can sometimes end up in the wrong + # one. + + # Decode intensity in every gain stage. + data_high = <data_t>(in_raw[frame_high, ss, fs] & 0xFFF) + data_mid = <data_t>(in_raw[frame_mid, ss, fs] & 0xFFF) + data_low = <data_t>(in_raw[frame_low, ss, fs] & 0xFFF) + + # Compute thresholds based on noise level. + threshold_high = 4096 - sigma_high * ccv_noise[cell, ss, fs, 0] + threshold_mid = 4096 - sigma_mid * ccv_noise[cell, ss, fs, 1] + + # Pick the optimal gain stage for this pixel. + if data_mid > threshold_mid: + data_ptr[0] = data_low + gain_ptr[0] = 2 + elif data_high > threshold_high: + data_ptr[0] = data_mid + gain_ptr[0] = 1 + else: + data_ptr[0] = data_high + gain_ptr[0] = 0 + + else: + # Using an invalid gain stage triggers bad pixel masking later + # in the correction kernel. + gain_ptr[0] = 4