From 704d3f0b3442de60f3cee467e7bd4f29ad30eb49 Mon Sep 17 00:00:00 2001
From: Danilo Ferreira de Lima <danilo.enoque.ferreira.de.lima@xfel.de>
Date: Fri, 3 Feb 2023 19:44:26 +0100
Subject: [PATCH] Fixed bugs.

---
 pes_to_spec/model.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/pes_to_spec/model.py b/pes_to_spec/model.py
index 4af23c9..613bd3a 100644
--- a/pes_to_spec/model.py
+++ b/pes_to_spec/model.py
@@ -642,7 +642,7 @@ class Model(TransformerMixin, BaseEstimator):
 
         return high_res
 
-    def get_channel_quality(self, channel: str, low_res: Dict[str, np.ndarray], channel_pca_model: Pipeline) -> float:
+    def get_channel_quality(self, channel: str, low_res: Dict[str, np.ndarray], channel_pca_model: Dict[str, Pipeline]) -> float:
         """
         Get the compatibility for a single channel.
 
@@ -656,7 +656,7 @@ class Model(TransformerMixin, BaseEstimator):
         pca_model = channel_pca_model[channel].named_steps['pca']
         low_pca = pca_model.transform(low_res[channel])
         low_pca_rec = pca_model.inverse_transform(low_pca)
-        low_pca_unc = channel_pca_model.named_steps['unc'].uncertainty()
+        low_pca_unc = channel_pca_model[channel].named_steps['unc'].uncertainty()
         low_pca_dev =  np.sqrt(np.mean((low_res[channel] - low_pca_rec)**2, axis=1, keepdims=True))
         return low_pca_dev/low_pca_unc
 
@@ -674,8 +674,9 @@ class Model(TransformerMixin, BaseEstimator):
         low_res = selection_model.transform(low_res_data, keep_dictionary_structure=True)
         quality = {channel: 0.0 for channel in low_res.keys()}
         channels = list(low_res.keys())
-        with mp.Pool(len(low_res.keys())) as p:
-            values = p.map(partial(self.get_channel_quality, low_res=low_res, channel_pca_model=self.channel_pca_model), channels)
+        #with mp.Pool(len(low_res.keys())) as p:
+        #    values = p.map(partial(self.get_channel_quality, low_res=low_res, channel_pca_model=self.channel_pca_model), channels)
+        values = map(partial(self.get_channel_quality, low_res=low_res, channel_pca_model=self.channel_pca_model), channels)
         quality = dict(zip(channels, values))
         return quality
 
-- 
GitLab