diff --git a/notebooks/Jungfrau/Jungfrau_Gain_Correct_and_Verify_NBC.ipynb b/notebooks/Jungfrau/Jungfrau_Gain_Correct_and_Verify_NBC.ipynb index f7ccbfff05bcdba173017293ee2e0338f8c0986c..3cff265cbb2f1314c335a23ffdad9144d947b3c8 100644 --- a/notebooks/Jungfrau/Jungfrau_Gain_Correct_and_Verify_NBC.ipynb +++ b/notebooks/Jungfrau/Jungfrau_Gain_Correct_and_Verify_NBC.ipynb @@ -38,7 +38,9 @@ "cal_db_timeout = 180000 # timeout on caldb requests\n", "\n", "# Parameters affecting corrected data.\n", - "relative_gain = True # do relative gain correction\n", + "relative_gain = True # do relative gain correction.\n", + "strixel_sensor = False # reordering for strixel detector layout.\n", + "strixel_double_norm = 2.0 # normalization to use for double-size pixels, only applied for strixel sensors.\n", "limit_trains = 0 # ONLY FOR TESTING. process only first N trains, Use 0 to process all.\n", "chunks_ids = 32 # HDF chunk size for memoryCell and frameNumber.\n", "chunks_data = 1 # HDF chunk size for pixel data in number of frames.\n", @@ -208,6 +210,22 @@ "print(f\"Number of memory cells are {memory_cells}\")" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "strixel_transform = None\n", + "output_frame_shape = None\n", + "\n", + "if strixel_sensor:\n", + " from cal_tools.jfalgs import strixel_transform, strixel_shape, strixel_double_pixels\n", + " output_frame_shape = strixel_shape()\n", + " Ydouble, Xdouble = strixel_double_pixels()\n", + " print('Strixel sensor transformation enabled')" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -340,8 +358,15 @@ "def correct_train(wid, index, d):\n", " d = d.astype(np.float32) # [cells, x, y]\n", " g = gain[index]\n", + " \n", + " # Copy gain over first to keep it at the original 3 for low gain.\n", + " if strixel_transform is not None:\n", + " strixel_transform(g, out=gain_corr[index, ...])\n", + " else:\n", + " gain_corr[index, ...] = g\n", "\n", " # Jungfrau gains 0[00], 1[01], 3[11]\n", + " # Change low gain to 2 for indexing purposes.\n", " g[g==3] = 2\n", "\n", " # Select memory cells\n", @@ -380,8 +405,14 @@ "\n", " msk = np.choose(g, np.moveaxis(mask_cell, -1, 0))\n", "\n", - " data_corr[index, ...] = d\n", - " mask_corr[index, ...] = msk" + " if strixel_transform is not None:\n", + " strixel_transform(d, out=data_corr[index, ...])\n", + " data_corr[index, :, Ydouble, Xdouble] /= strixel_double_norm\n", + "\n", + " strixel_transform(msk, out=mask_corr[index, ...])\n", + " else:\n", + " data_corr[index, ...] = d\n", + " mask_corr[index, ...] = msk" ] }, { @@ -464,8 +495,8 @@ " # and number of available trains to correct.\n", " seq_dc = H5File(sequence_file)\n", " seq_dc_adc = seq_dc[instrument_src_kda, \"data.adc\"]\n", - " dshape = seq_dc_adc.shape\n", - " corr_ntrains = seq_dc_adc.shape[0] # number of available trains to correct.\n", + " ishape = seq_dc_adc.shape # input shape.\n", + " corr_ntrains = ishape[0] # number of available trains to correct.\n", " all_train_ids = seq_dc_adc.train_ids\n", "\n", " # Raise a WARNING if this sequence has no trains to correct.\n", @@ -488,10 +519,17 @@ " # Load constants from the constants dictionary.\n", " # These arrays are used by `correct_train()` function\n", " offset_map, mask, gain_map = constants[local_karabo_da]\n", + " \n", + " # Determine total output shape.\n", + " if output_frame_shape is not None:\n", + " oshape = (*ishape[:-2], *output_frame_shape)\n", + " else:\n", + " oshape = ishape\n", "\n", " # Allocate shared arrays for corrected data. Used in `correct_train()`\n", - " data_corr = context.alloc(shape=(corr_ntrains, *dshape[1:]), dtype=np.float32)\n", - " mask_corr = context.alloc(shape=(corr_ntrains, *dshape[1:]), dtype=np.uint32)\n", + " data_corr = context.alloc(shape=oshape, dtype=np.float32)\n", + " gain_corr = context.alloc(shape=oshape, dtype=np.uint8)\n", + " mask_corr = context.alloc(shape=oshape, dtype=np.uint32)\n", "\n", " step_timer.start()\n", " # Overwrite seq_dc after eliminating empty trains or/and applying limited images.\n", @@ -551,10 +589,10 @@ " # Add main corrected `data.adc`` dataset and store corrected data.\n", " outp_source.create_key(\n", " \"data.adc\", data=data_corr,\n", - " chunks=(min(chunks_data, data_corr.shape[0]), *dshape[1:]))\n", + " chunks=(min(chunks_data, data_corr.shape[0]), *oshape[1:]))\n", "\n", " write_compressed_frames(\n", - " gain, outp_file, f\"{outp_source.name}/data/gain\", comp_threads=8)\n", + " gain_corr, outp_file, f\"{outp_source.name}/data/gain\", comp_threads=8)\n", " write_compressed_frames(\n", " mask_corr, outp_file, f\"{outp_source.name}/data/mask\", comp_threads=8)\n", "\n", @@ -739,14 +777,21 @@ "corrected_mean = np.mean(corrected, axis=1)\n", "_corrected_vmin = min(0.75*np.median(corrected_mean[corrected_mean > 0]), -0.5)\n", "_corrected_vmax = max(2.*np.median(corrected_mean[corrected_mean > 0]), 100)\n", - "geom.plot_data_fast(\n", - " corrected_mean,\n", - " ax=ax,\n", - " vmin=_corrected_vmin,\n", - " vmax=_corrected_vmax,\n", - " cmap=\"jet\",\n", - " colorbar={'shrink': 1, 'pad': 0.01},\n", + "\n", + "mean_plot_kwargs = dict(\n", + " vmin=_corrected_vmin, vmax=_corrected_vmax, cmap=\"jet\"\n", ")\n", + "\n", + "if not strixel_sensor:\n", + " geom.plot_data_fast(\n", + " corrected_mean,\n", + " ax=ax,\n", + " colorbar={'shrink': 1, 'pad': 0.01},\n", + " **mean_plot_kwargs\n", + " )\n", + "else:\n", + " ax.imshow(corrected_mean.squeeze(), aspect=10, **mean_plot_kwargs)\n", + " \n", "ax.set_title(f'{karabo_id} - Mean CORRECTED', size=18)\n", "\n", "plt.show()" @@ -763,14 +808,17 @@ "corrected_masked[mask != 0] = np.nan\n", "corrected_masked_mean = np.nanmean(corrected_masked, axis=1)\n", "del corrected_masked\n", - "geom.plot_data_fast(\n", - " corrected_masked_mean,\n", - " ax=ax,\n", - " vmin=_corrected_vmin,\n", - " vmax=_corrected_vmax,\n", - " cmap=\"jet\",\n", - " colorbar={'shrink': 1, 'pad': 0.01},\n", - ")\n", + "\n", + "if not strixel_sensor:\n", + " geom.plot_data_fast(\n", + " corrected_masked_mean,\n", + " ax=ax,\n", + " colorbar={'shrink': 1, 'pad': 0.01},\n", + " **mean_plot_kwargs\n", + " )\n", + "else:\n", + " ax.imshow(corrected_mean.squeeze(), aspect=10, **mean_plot_kwargs)\n", + "\n", "ax.set_title(f'{karabo_id} - Mean CORRECTED with mask', size=18)\n", "\n", "plt.show()" @@ -785,14 +833,23 @@ "display(Markdown((f\"#### A single image from train {tid}\")))\n", "\n", "fig, ax = plt.subplots(figsize=(18, 10))\n", - "geom.plot_data_fast(\n", - " corrected_train,\n", - " ax=ax,\n", + "\n", + "single_plot_kwargs = dict(\n", " vmin=min(0.75 * np.median(corrected_train[corrected_train > 0]), -0.5),\n", " vmax=max(2.0 * np.median(corrected_train[corrected_train > 0]), 100),\n", - " cmap=\"jet\",\n", - " colorbar={\"shrink\": 1, \"pad\": 0.01},\n", + " cmap=\"jet\"\n", ")\n", + "\n", + "if not strixel_sensor:\n", + " geom.plot_data_fast(\n", + " corrected_train,\n", + " ax=ax,\n", + " colorbar={\"shrink\": 1, \"pad\": 0.01},\n", + " **single_plot_kwargs\n", + " )\n", + "else:\n", + " ax.imshow(corrected_train.squeeze(), aspect=10, **single_plot_kwargs)\n", + "\n", "ax.set_title(f\"{karabo_id} - CORRECTED train: {tid}\", size=18)\n", "\n", "plt.show()" @@ -951,12 +1008,16 @@ "display(Markdown(f\"#### Bad pixels image for train {tid}\"))\n", "\n", "fig, ax = plt.subplots(figsize=(18, 10))\n", - "geom.plot_data_fast(\n", - " np.log2(mask_train),\n", - " ax=ax,\n", - " vmin=0, vmax=32, cmap=\"jet\",\n", - " colorbar={'shrink': 1, 'pad': 0.01},\n", - ")\n", + "if not strixel_sensor:\n", + " geom.plot_data_fast(\n", + " np.log2(mask_train),\n", + " ax=ax,\n", + " vmin=0, vmax=32, cmap=\"jet\",\n", + " colorbar={'shrink': 1, 'pad': 0.01},\n", + " )\n", + "else:\n", + " ax.imshow(np.log2(mask_train).squeeze(), vmin=0, vmax=32, cmap='jet', aspect=10)\n", + "\n", "plt.show()" ] } diff --git a/setup.py b/setup.py index d7324fd81b939b045e2404e8b7fcc839f3e94cb3..e5561837146b2b89f3cf75bcb9d62dcf07840ed4 100644 --- a/setup.py +++ b/setup.py @@ -24,6 +24,12 @@ ext_modules = [ '-ftree-vectorize', '-frename-registers'], extra_link_args=['-fopenmp'], ), + Extension( + "cal_tools.jfalgs", + ['src/cal_tools/jfalgs.pyx'], + extra_compile_args=['-O3', '-march=native', '-ftree-vectorize', + '-frename-registers'] + ), Extension( "cal_tools.gotthard2.gotthard2algs", ["src/cal_tools/gotthard2/gotthard2algs.pyx"], diff --git a/src/cal_tools/jfalgs.pyx b/src/cal_tools/jfalgs.pyx new file mode 100644 index 0000000000000000000000000000000000000000..fd1476918d8a8f165150abe1c5bbeb6175721913 --- /dev/null +++ b/src/cal_tools/jfalgs.pyx @@ -0,0 +1,93 @@ + +from cython cimport boundscheck, wraparound, cdivision +from cython.view cimport contiguous + +import numpy as np + + +ctypedef fused jf_data_t: + unsigned short # raw pixel data + float # corrected pixel data + unsigned int # mask data + unsigned char # gain data + + +DEF STRIXEL_Y = 86 +DEF STRIXEL_X = 1024 * 3 + 18 + + +def strixel_shape(): + return STRIXEL_Y, STRIXEL_X + + +def strixel_double_pixels(): + """Build index arrays for double-size pixels. + + In raw data, the entire columns 255, 256, 511, 512, 767 and 768 + are double-size pixels. After strixelation, these end up in columns + 765-776, 1539-1550 and 2313-2324 on rows 0-85 or 0-83, with a set + of four columns with 86 rows followed by a set of 84 and 86 again. + + This function builds the index arrays after strixelation. + """ + + Ydouble = [] + Xdouble = [] + + for double_col in [765, 1539, 2313]: + for col in range(double_col, double_col+12): + for row in range(84 if ((col-double_col) // 4) == 1 else 86): + Ydouble.append(row) + Xdouble.append(col) + + return np.array(Ydouble), np.array(Xdouble) + + +@boundscheck(False) +@wraparound(False) +@cdivision(True) +def strixel_transform(jf_data_t[:, :, ::contiguous] data, + jf_data_t[:, :, ::contiguous] out = None): + """Reorder raw data to physical strixel sensor layout. """ + + if data.shape[1] < 256 or data.shape[2] < 256: + raise ValueError('Pixel shape of data may not be below (256, 256)') + + if out is None: + import numpy as np + out = np.zeros((data.shape[0], STRIXEL_Y, STRIXEL_X), dtype=np.float32) + elif data.shape[0] > out.shape[0]: + raise ValueError('Cell shape of data exceeds out') + elif out.shape[1] < STRIXEL_Y or out.shape[2] < STRIXEL_X: + raise ValueError(f'Pixel shape of out may not be below ' + f'({STRIXEL_Y}, {STRIXEL_X})') + + cdef int cell, yin, xin, xout, yout, igap + + for cell in range(data.shape[0]): + # Normal pixels. + for yin in range(256): + yout = yin // 3 + + for xin in range(1024) : + xout = 774 * (xin // 256) + 3 * (xin % 256) + yin % 3 + out[cell, yout, xout] = data[cell, yin, xin] + + # Gap pixels. + for yin in range(256): + yout = 2 * (yin // 6) + + for igap in range(3) : + # Left side of the gap. + xin = igap * 256 + 255 + xout = igap * 774 + 765 + yin % 6 + out[cell, yout, xout] = data[cell, yin, xin] + out[cell, yout+1, xout] = data[cell, yin, xin] + + # Right side of the gap. + xin = igap * 256 + 255 + 1 + xout = igap * 774 + 765 + 11 - yin % 6 + out[cell, yout, xout] = data[cell, yin, xin] + out[cell, yout+1, xout] = data[cell, yin, xin] + + return out