From 753dec56ee85d70888196ba2ed65db2f815a09e2 Mon Sep 17 00:00:00 2001 From: Danilo Ferreira de Lima <danilo.enoque.ferreira.de.lima@xfel.de> Date: Tue, 31 Jan 2023 18:08:09 +0100 Subject: [PATCH] Fix and parallelize inference of the compatibility per channel. --- pes_to_spec/model.py | 30 +++++++++++++++++++++++------- 1 file changed, 23 insertions(+), 7 deletions(-) diff --git a/pes_to_spec/model.py b/pes_to_spec/model.py index 0ed420a..302d855 100644 --- a/pes_to_spec/model.py +++ b/pes_to_spec/model.py @@ -20,6 +20,7 @@ from sklearn.model_selection import train_test_split from sklearn.base import clone, MetaEstimatorMixin from joblib import Parallel, delayed +from functools import partial from typing import Any, Dict, List, Optional, Union, Tuple @@ -640,6 +641,24 @@ class Model(TransformerMixin, BaseEstimator): return high_res + def get_channel_quality(self, channel: str, low_res: Dict[str, np.ndarray], channel_pca_model: Dict[str, Pipeline]) -> np.ndarray: + """ + Get the compatibility for a single channel: the RMS PCA reconstruction error of each row of the channel data, divided by the uncertainty held by that channel's 'unc' step. + + Args: + channel: The channel. + low_res: The data in a dictionary, indexed by channel. + channel_pca_model: The PCA pipelines (with steps 'pca' and 'unc'), indexed by channel. + + Returns: the compatibility factor, one value per row of the channel data. 
+ """ + pca_model = channel_pca_model[channel].named_steps['pca'] + low_pca = pca_model.transform(low_res[channel]) + low_pca_rec = pca_model.inverse_transform(low_pca) + low_pca_unc = channel_pca_model[channel].named_steps['unc'].uncertainty() + low_pca_dev = np.sqrt(np.mean((low_res[channel] - low_pca_rec)**2, axis=1, keepdims=True)) + return low_pca_dev/low_pca_unc + def check_compatibility_per_channel(self, low_res_data: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: """ Check if a new low-resolution data source is compatible with the one used in training, by @@ -653,13 +672,10 @@ class Model(TransformerMixin, BaseEstimator): selection_model = self.x_model['select'] low_res = selection_model.transform(low_res_data, keep_dictionary_structure=True) quality = {channel: 0.0 for channel in low_res.keys()} - for channel in low_res.keys(): - pca_model = self.channel_pca_model[channel].named_steps['pca'] - low_pca = pca_model.transform(low_res[channel]) - low_pca_rec = pca_model.inverse_transform(low_pca) - low_pca_unc = self.channel_pca_model.named_steps['unc'].uncertainty() - low_pca_dev = np.sqrt(np.mean((low_res[channel] - low_pca_rec)**2, axis=1, keepdims=True)) - quality[channel] = low_pca_dev/low_pca_unc + channels = list(low_res.keys()) + with mp.Pool(max(1, len(channels))) as p: + values = p.map(partial(self.get_channel_quality, low_res=low_res, channel_pca_model=self.channel_pca_model), channels) + quality = dict(zip(channels, values)) return quality def check_compatibility(self, low_res_data: Dict[str, np.ndarray]) -> np.ndarray: -- GitLab