Skip to content
Snippets Groups Projects
Commit 753dec56 authored by Danilo Ferreira de Lima's avatar Danilo Ferreira de Lima
Browse files

Fix and parallelize inference of the compatibility per channel.

parent 0c23a584
No related branches found
No related tags found
1 merge request!5Check consistency per channel
...@@ -20,6 +20,7 @@ from sklearn.model_selection import train_test_split ...@@ -20,6 +20,7 @@ from sklearn.model_selection import train_test_split
from sklearn.base import clone, MetaEstimatorMixin from sklearn.base import clone, MetaEstimatorMixin
from joblib import Parallel, delayed from joblib import Parallel, delayed
from functools import partial
from typing import Any, Dict, List, Optional, Union, Tuple from typing import Any, Dict, List, Optional, Union, Tuple
...@@ -640,6 +641,24 @@ class Model(TransformerMixin, BaseEstimator): ...@@ -640,6 +641,24 @@ class Model(TransformerMixin, BaseEstimator):
return high_res return high_res
def get_channel_quality(self, channel: str, low_res: Dict[str, np.ndarray], channel_pca_model: Pipeline) -> float:
"""
Get the compatibility for a single channel.
Args:
channel: The channel.
low_res: The data in a dictionary.
pca_model: The PCA model.
Returns: the compatibility factor.
"""
pca_model = channel_pca_model[channel].named_steps['pca']
low_pca = pca_model.transform(low_res[channel])
low_pca_rec = pca_model.inverse_transform(low_pca)
low_pca_unc = channel_pca_model.named_steps['unc'].uncertainty()
low_pca_dev = np.sqrt(np.mean((low_res[channel] - low_pca_rec)**2, axis=1, keepdims=True))
return low_pca_dev/low_pca_unc
def check_compatibility_per_channel(self, low_res_data: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: def check_compatibility_per_channel(self, low_res_data: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
""" """
Check if a new low-resolution data source is compatible with the one used in training, by Check if a new low-resolution data source is compatible with the one used in training, by
...@@ -653,13 +672,10 @@ class Model(TransformerMixin, BaseEstimator): ...@@ -653,13 +672,10 @@ class Model(TransformerMixin, BaseEstimator):
selection_model = self.x_model['select'] selection_model = self.x_model['select']
low_res = selection_model.transform(low_res_data, keep_dictionary_structure=True) low_res = selection_model.transform(low_res_data, keep_dictionary_structure=True)
quality = {channel: 0.0 for channel in low_res.keys()} quality = {channel: 0.0 for channel in low_res.keys()}
for channel in low_res.keys(): channels = list(low_res.keys())
pca_model = self.channel_pca_model[channel].named_steps['pca'] with mp.Pool(len(low_res.keys())) as p:
low_pca = pca_model.transform(low_res[channel]) values = p.map(partial(self.get_channel_quality, low_res=low_res, channel_pca_model=self.channel_pca_model), channels)
low_pca_rec = pca_model.inverse_transform(low_pca) quality = dict(zip(channels, values))
low_pca_unc = self.channel_pca_model.named_steps['unc'].uncertainty()
low_pca_dev = np.sqrt(np.mean((low_res[channel] - low_pca_rec)**2, axis=1, keepdims=True))
quality[channel] = low_pca_dev/low_pca_unc
return quality return quality
def check_compatibility(self, low_res_data: Dict[str, np.ndarray]) -> np.ndarray: def check_compatibility(self, low_res_data: Dict[str, np.ndarray]) -> np.ndarray:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment