From 753dec56ee85d70888196ba2ed65db2f815a09e2 Mon Sep 17 00:00:00 2001
From: Danilo Ferreira de Lima <danilo.enoque.ferreira.de.lima@xfel.de>
Date: Tue, 31 Jan 2023 18:08:09 +0100
Subject: [PATCH] Fix and parallelize inference of the compatibility per
 channel.

---
 pes_to_spec/model.py | 30 +++++++++++++++++++++++-------
 1 file changed, 23 insertions(+), 7 deletions(-)

diff --git a/pes_to_spec/model.py b/pes_to_spec/model.py
index 0ed420a..302d855 100644
--- a/pes_to_spec/model.py
+++ b/pes_to_spec/model.py
@@ -20,6 +20,7 @@ from sklearn.model_selection import train_test_split
 
 from sklearn.base import clone, MetaEstimatorMixin
 from joblib import Parallel, delayed
+from functools import partial
 
 from typing import Any, Dict, List, Optional, Union, Tuple
 
@@ -640,6 +641,24 @@ class Model(TransformerMixin, BaseEstimator):
 
         return high_res
 
+    def get_channel_quality(self, channel: str, low_res: Dict[str, np.ndarray], channel_pca_model: Dict[str, Pipeline]) -> np.ndarray:
+        """
+        Get the compatibility for a single channel from its PCA reconstruction error.
+
+        Args:
+          channel: The channel, used as key into both `low_res` and `channel_pca_model`.
+          low_res: The data in a dictionary, indexed by channel.
+          channel_pca_model: The pipelines (with 'pca' and 'unc' steps), indexed by channel.
+
+        Returns: the compatibility factor (RMS deviation over axis 1, divided by the uncertainty).
+        """
+        channel_model = channel_pca_model[channel]  # both steps must come from this channel's own pipeline
+        pca_model = channel_model.named_steps['pca']
+        low_pca_rec = pca_model.inverse_transform(pca_model.transform(low_res[channel]))
+        low_pca_unc = channel_model.named_steps['unc'].uncertainty()
+        low_pca_dev = np.sqrt(np.mean((low_res[channel] - low_pca_rec)**2, axis=1, keepdims=True))
+        return low_pca_dev/low_pca_unc
+
     def check_compatibility_per_channel(self, low_res_data: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
         """
         Check if a new low-resolution data source is compatible with the one used in training, by
@@ -653,13 +672,10 @@ class Model(TransformerMixin, BaseEstimator):
         selection_model = self.x_model['select']
         low_res = selection_model.transform(low_res_data, keep_dictionary_structure=True)
         quality = {channel: 0.0 for channel in low_res.keys()}
-        for channel in low_res.keys():
-            pca_model = self.channel_pca_model[channel].named_steps['pca']
-            low_pca = pca_model.transform(low_res[channel])
-            low_pca_rec = pca_model.inverse_transform(low_pca)
-            low_pca_unc = self.channel_pca_model.named_steps['unc'].uncertainty()
-            low_pca_dev =  np.sqrt(np.mean((low_res[channel] - low_pca_rec)**2, axis=1, keepdims=True))
-            quality[channel] = low_pca_dev/low_pca_unc
+        channels = list(low_res.keys())
+        with mp.Pool(len(low_res.keys())) as p:
+            values = p.map(partial(self.get_channel_quality, low_res=low_res, channel_pca_model=self.channel_pca_model), channels)
+        quality = dict(zip(channels, values))
         return quality
 
     def check_compatibility(self, low_res_data: Dict[str, np.ndarray]) -> np.ndarray:
-- 
GitLab