From e9b9aacf21c042bc46c66ee7eacddfdabc3b52ed Mon Sep 17 00:00:00 2001
From: Philipp Schmidt <philipp.schmidt@xfel.eu>
Date: Tue, 1 Nov 2022 13:52:57 +0100
Subject: [PATCH] Add parallel compression via write_compressed_frame to
 DataFile

---
 ...Jungfrau_Gain_Correct_and_Verify_NBC.ipynb | 10 +++----
 notebooks/LPD/LPD_Correct_Fast.ipynb          | 12 ++------
 src/cal_tools/files.py                        | 28 +++++++++++++++++++
 3 files changed, 35 insertions(+), 15 deletions(-)

diff --git a/notebooks/Jungfrau/Jungfrau_Gain_Correct_and_Verify_NBC.ipynb b/notebooks/Jungfrau/Jungfrau_Gain_Correct_and_Verify_NBC.ipynb
index 3cff265cb..0ed635173 100644
--- a/notebooks/Jungfrau/Jungfrau_Gain_Correct_and_Verify_NBC.ipynb
+++ b/notebooks/Jungfrau/Jungfrau_Gain_Correct_and_Verify_NBC.ipynb
@@ -99,7 +99,6 @@
     "    get_dir_creation_date,\n",
     "    get_pdu_from_db,\n",
     "    map_seq_files,\n",
-    "    write_compressed_frames,\n",
     "    CalibrationMetadata,\n",
     ")\n",
     "from iCalibrationDB import Conditions, Constants\n",
@@ -590,11 +589,10 @@
     "            outp_source.create_key(\n",
     "                \"data.adc\", data=data_corr,\n",
     "                chunks=(min(chunks_data, data_corr.shape[0]), *oshape[1:]))\n",
-    "\n",
-    "            write_compressed_frames(\n",
-    "                gain_corr, outp_file, f\"{outp_source.name}/data/gain\", comp_threads=8)\n",
-    "            write_compressed_frames(\n",
-    "                mask_corr, outp_file, f\"{outp_source.name}/data/mask\", comp_threads=8)\n",
+    "            outp_source.create_compressed_key(\n",
+    "                \"data.gain\", data=gain_corr)\n",
+    "            outp_source.create_compressed_key(\n",
+    "                \"data.mask\", data=mask_corr)\n",
     "\n",
     "            save_reduced_rois(outp_file, data_corr, mask_corr, local_karabo_da)\n",
     "\n",
diff --git a/notebooks/LPD/LPD_Correct_Fast.ipynb b/notebooks/LPD/LPD_Correct_Fast.ipynb
index 73906ff03..1bb9f2773 100644
--- a/notebooks/LPD/LPD_Correct_Fast.ipynb
+++ b/notebooks/LPD/LPD_Correct_Fast.ipynb
@@ -105,11 +105,7 @@
     "from extra_data.components import LPD1M\n",
     "\n",
     "from cal_tools.lpdalgs import correct_lpd_frames\n",
-    "from cal_tools.tools import (\n",
-    "    CalibrationMetadata,\n",
-    "    calcat_creation_time,\n",
-    "    write_compressed_frames,\n",
-    "    )\n",
+    "from cal_tools.tools import CalibrationMetadata, calcat_creation_time\n",
     "from cal_tools.files import DataFile\n",
     "from cal_tools.restful_config import restful_config"
    ]
@@ -463,10 +459,8 @@
     "                                   chunks=(min(chunks_ids, in_pulse.shape[0]),))\n",
     "            outp_source.create_key('image.data', data=out_data,\n",
     "                                   chunks=(min(chunks_data, out_data.shape[0]), 256, 256))\n",
-    "            write_compressed_frames(\n",
-    "                out_gain, outp_file, f'INSTRUMENT/{outp_source_name}/image/gain', comp_threads=8)\n",
-    "            write_compressed_frames(\n",
-    "                out_mask, outp_file, f'INSTRUMENT/{outp_source_name}/image/mask', comp_threads=8)\n",
+    "            outp_source.create_compressed_key('image.gain', data=out_gain)\n",
+    "            outp_source.create_compressed_key('image.mask', data=out_mask)\n",
     "    write_time = perf_counter() - start\n",
     "    \n",
     "    total_time = open_time + read_time + correct_time + write_time\n",
diff --git a/src/cal_tools/files.py b/src/cal_tools/files.py
index 21fb85520..f1e13d6f5 100644
--- a/src/cal_tools/files.py
+++ b/src/cal_tools/files.py
@@ -517,6 +517,34 @@ class InstrumentSource(h5py.Group):
 
         return self.create_dataset(key, data=data, **kwargs)
 
+    def create_compressed_key(self, key, data, comp_threads=8):
+        """Create a compressed dataset for a key.
+
+        This method makes use of lower-level access in h5py to compress
+        the data separately in multiple threads and write it directly to
+        file rather than go through HDF's compression filters.
+
+        Args:
+            key (str): Source key, dots are automatically replaced by
+                slashes.
+            data (np.ndarray): Key data.
+            comp_threads (int, optional): Number of threads to use for
+                compression, 8 by default.
+
+        Returns:
+            (h5py.Dataset) Created dataset
+        """
+
+        key = escape_key(key)
+
+        if not self.key_pattern.match(key):
+            raise ValueError(f'invalid key format, must satisfy '
+                             f'{self.key_pattern.pattern}')
+
+        from cal_tools.tools import write_compressed_frames
+        return write_compressed_frames(data, self, key,
+                                       comp_threads=comp_threads)
+
     def create_index(self, *args, **channels):
         """Create source-specific INDEX datasets.
 
-- 
GitLab