From e9b1e09bf051da369baa789c8732a6812ceb03b5 Mon Sep 17 00:00:00 2001
From: Philipp Schmidt <philipp.schmidt@xfel.eu>
Date: Mon, 11 Apr 2022 14:04:46 +0200
Subject: [PATCH] Add support to recast image data to a different type in place
 in AGIPD correct

---
 .../AGIPD/AGIPD_Correct_and_Verify.ipynb      |  5 ++-
 src/cal_tools/agipdlib.py                     | 11 +++++-
 src/cal_tools/agipdutils.py                   | 36 +++++++++++++++++
 tests/test_agipdutils.py                      | 39 +++++++++++++++++++
 4 files changed, 89 insertions(+), 2 deletions(-)
 create mode 100644 tests/test_agipdutils.py

diff --git a/notebooks/AGIPD/AGIPD_Correct_and_Verify.ipynb b/notebooks/AGIPD/AGIPD_Correct_and_Verify.ipynb
index 1ecab067d..678a2b663 100644
--- a/notebooks/AGIPD/AGIPD_Correct_and_Verify.ipynb
+++ b/notebooks/AGIPD/AGIPD_Correct_and_Verify.ipynb
@@ -89,6 +89,7 @@
     "energy_threshold = -1000 # The low limit for the energy (uJ) exposed by frames subject to processing. If -1000, selection by pulse energy is disabled\n",
     "\n",
     "# Output parameters\n",
+    "recast_image_data = ''  # Cast data to a different dtype before saving\n",
     "compress_fields = ['gain', 'mask']  # Datasets in image group to compress.\n",
     "\n",
     "# Plotting parameters\n",
@@ -496,7 +497,9 @@
     "agipd_corr.noisy_adc_threshold = noisy_adc_threshold\n",
     "agipd_corr.ff_gain = ff_gain\n",
     "\n",
-    "agipd_corr.compress_fields = compress_fields",
+    "agipd_corr.compress_fields = compress_fields\n",
+    "if recast_image_data:\n",
+    "    agipd_corr.recast_image_fields['data'] = np.dtype(recast_image_data)"
    ]
   },
   {
diff --git a/src/cal_tools/agipdlib.py b/src/cal_tools/agipdlib.py
index cf63be6b5..69f4d4aab 100644
--- a/src/cal_tools/agipdlib.py
+++ b/src/cal_tools/agipdlib.py
@@ -22,6 +22,7 @@ from cal_tools.agipdutils import (
     make_noisy_adc_mask,
     match_asic_borders,
     melt_snowy_pixels,
+    cast_array_inplace
 )
 from cal_tools.enums import AgipdGainMode, BadPixels, SnowResolution
 from cal_tools.h5_copy_except import h5_copy_except_paths
@@ -345,6 +346,7 @@ class AgipdCorrections:
 
         # Output parameters
         self.compress_fields = ['gain', 'mask']
+        self.recast_image_fields = {}
 
         # Shared variables for data and constants
         self.shared_dict = []
@@ -472,7 +474,10 @@ class AgipdCorrections:
         agipd_base = f'INSTRUMENT/{self.h5_data_path}/'.format(module_idx)
         idx_base = self.h5_index_path.format(module_idx)
         data_path = f'{agipd_base}/image'
-        data_dict = self.shared_dict[i_proc]
+
+        # Obtain a shallow copy of the pointer map to allow for local
+        # changes in this method.
+        data_dict = self.shared_dict[i_proc].copy()
 
         image_fields = [
             'trainId', 'pulseId', 'cellId', 'data', 'gain', 'mask', 'blShift',
@@ -483,6 +488,10 @@ class AgipdCorrections:
             return
         trains = data_dict['trainId'][:n_img]
 
+        # Re-cast fields in-place, i.e. using the same memory region.
+        for field, dtype in self.recast_image_fields.items():
+            data_dict[field] = cast_array_inplace(data_dict[field], dtype)
+
         with h5py.File(ofile_name, "w") as outfile:
             # Copy any other data from the input file.
             # This includes indexes, so it's important that the corrected data
diff --git a/src/cal_tools/agipdutils.py b/src/cal_tools/agipdutils.py
index a5d7cb628..dd5657cf9 100644
--- a/src/cal_tools/agipdutils.py
+++ b/src/cal_tools/agipdutils.py
@@ -663,3 +663,39 @@ def melt_snowy_pixels(raw, im, gain, rgain, resolution=None):
                 snow_mask[k, i * 64:(i + 1) * 64,
                 j * 64:(j + 1) * 64] = asic_msk
     return im, gain, snow_mask
+
+
+def cast_array_inplace(inp, dtype):
+    """Cast an ndarray to a different dtype in place.
+
+    The resulting array will occupy the same memory as the input array,
+    and the cast will most likely make interpretating the buffer content
+    through the input array nonsensical.
+
+    Args:
+        inp (ndarray): Input array to cast, must be contiguous and
+            castable to the target dtype without copy.
+        dtype (DTypeLike): Data type to cast to.
+    """
+
+    # Save shape to recast later
+    orig_shape = inp.shape
+
+    # Create a new view of the input and flatten it in-place. Unlike
+    # inp.reshape(-1) this operation fails if a copy is required.
+    inp = inp.view()
+    inp.shape = inp.size
+
+    # Create a new view with the target dtype and slice it to the number
+    # of elements the input array has. This accounts for smaller dtypes
+    # using less space for the same size.
+    # The output array will be contiguous but not own its data.
+    outp = inp.view(dtype)[:len(inp)]
+
+    # "Copy" over the data, performing the cast.
+    outp[:] = inp
+
+    # Reshape back to the original.
+    outp.shape = orig_shape
+
+    return outp
diff --git a/tests/test_agipdutils.py b/tests/test_agipdutils.py
new file mode 100644
index 000000000..7c0f4f287
--- /dev/null
+++ b/tests/test_agipdutils.py
@@ -0,0 +1,39 @@
+
+import pytest
+import numpy as np
+
+from cal_tools.agipdutils import cast_array_inplace
+
+
+@pytest.mark.parametrize(
+    'dtype_str', ['f8', 'f4', 'f2', 'i4', 'i2', 'i1', 'u4', 'u2', 'u1'])
+def test_downcast_array_inplace(dtype_str):
+    """Test downcasting an array in-place."""
+
+    dtype = np.dtype(dtype_str)
+
+    ref_data = (np.random.rand(2, 3, 4) * 100)
+    orig_data = ref_data.copy()
+    cast_data = cast_array_inplace(orig_data, dtype)
+
+    np.testing.assert_allclose(cast_data, ref_data.astype(dtype))
+    assert np.may_share_memory(orig_data, cast_data)
+    assert cast_data.dtype == dtype
+    assert cast_data.flags.c_contiguous
+    assert cast_data.flags.aligned
+    assert not cast_data.flags.owndata
+
+
+def test_upcast_array_inplace():
+    """Test whether upcasting an array in-place fails."""
+
+    with pytest.raises(Exception):
+        cast_array_inplace(
+            np.random.rand(4, 5, 6).astype(np.float32), np.float64)
+
+
+def test_noncontiguous_cast_inplace():
+    """Test whether casting a non-contiguous array in-place fails."""
+
+    with pytest.raises(Exception):
+        cast_array_inplace(np.random.rand(4, 5, 6).T, np.int32)
-- 
GitLab