diff --git a/pes_to_spec/model.py b/pes_to_spec/model.py index 0ed420a7ddc8c15e5b60995daeeb99bcb9dd6482..302d8550a2071af9544ce2044070b64c8a846eff 100644 --- a/pes_to_spec/model.py +++ b/pes_to_spec/model.py @@ -20,6 +20,7 @@ from sklearn.model_selection import train_test_split from sklearn.base import clone, MetaEstimatorMixin from joblib import Parallel, delayed +from functools import partial from typing import Any, Dict, List, Optional, Union, Tuple @@ -640,6 +641,24 @@ class Model(TransformerMixin, BaseEstimator): return high_res + def get_channel_quality(self, channel: str, low_res: Dict[str, np.ndarray], channel_pca_model: Pipeline) -> float: + """ + Get the compatibility for a single channel. + + Args: + channel: The channel. + low_res: The data in a dictionary. + pca_model: The PCA model. + + Returns: the compatibility factor. + """ + pca_model = channel_pca_model[channel].named_steps['pca'] + low_pca = pca_model.transform(low_res[channel]) + low_pca_rec = pca_model.inverse_transform(low_pca) + low_pca_unc = channel_pca_model.named_steps['unc'].uncertainty() + low_pca_dev = np.sqrt(np.mean((low_res[channel] - low_pca_rec)**2, axis=1, keepdims=True)) + return low_pca_dev/low_pca_unc + def check_compatibility_per_channel(self, low_res_data: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: """ Check if a new low-resolution data source is compatible with the one used in training, by @@ -653,13 +672,10 @@ class Model(TransformerMixin, BaseEstimator): selection_model = self.x_model['select'] low_res = selection_model.transform(low_res_data, keep_dictionary_structure=True) quality = {channel: 0.0 for channel in low_res.keys()} - for channel in low_res.keys(): - pca_model = self.channel_pca_model[channel].named_steps['pca'] - low_pca = pca_model.transform(low_res[channel]) - low_pca_rec = pca_model.inverse_transform(low_pca) - low_pca_unc = self.channel_pca_model.named_steps['unc'].uncertainty() - low_pca_dev = np.sqrt(np.mean((low_res[channel] - low_pca_rec)**2, axis=1, keepdims=True)) - quality[channel] = low_pca_dev/low_pca_unc + channels = list(low_res.keys()) + with mp.Pool(len(low_res.keys())) as p: + values = p.map(partial(self.get_channel_quality, low_res=low_res, channel_pca_model=self.channel_pca_model), channels) + quality = dict(zip(channels, values)) return quality def check_compatibility(self, low_res_data: Dict[str, np.ndarray]) -> np.ndarray: