Skip to content
Snippets Groups Projects
Commit 704d3f0b authored by Danilo Ferreira de Lima's avatar Danilo Ferreira de Lima
Browse files

Fixed bugs.

parent e7ace9db
No related branches found
No related tags found
1 merge request!5Check consistency per channel
......@@ -642,7 +642,7 @@ class Model(TransformerMixin, BaseEstimator):
return high_res
def get_channel_quality(self, channel: str, low_res: Dict[str, np.ndarray], channel_pca_model: Pipeline) -> float:
def get_channel_quality(self, channel: str, low_res: Dict[str, np.ndarray], channel_pca_model: Dict[str, Pipeline]) -> float:
"""
Get the compatibility for a single channel.
......@@ -656,7 +656,7 @@ class Model(TransformerMixin, BaseEstimator):
pca_model = channel_pca_model[channel].named_steps['pca']
low_pca = pca_model.transform(low_res[channel])
low_pca_rec = pca_model.inverse_transform(low_pca)
low_pca_unc = channel_pca_model.named_steps['unc'].uncertainty()
low_pca_unc = channel_pca_model[channel].named_steps['unc'].uncertainty()
low_pca_dev = np.sqrt(np.mean((low_res[channel] - low_pca_rec)**2, axis=1, keepdims=True))
return low_pca_dev/low_pca_unc
......@@ -674,8 +674,9 @@ class Model(TransformerMixin, BaseEstimator):
low_res = selection_model.transform(low_res_data, keep_dictionary_structure=True)
quality = {channel: 0.0 for channel in low_res.keys()}
channels = list(low_res.keys())
with mp.Pool(len(low_res.keys())) as p:
values = p.map(partial(self.get_channel_quality, low_res=low_res, channel_pca_model=self.channel_pca_model), channels)
#with mp.Pool(len(low_res.keys())) as p:
# values = p.map(partial(self.get_channel_quality, low_res=low_res, channel_pca_model=self.channel_pca_model), channels)
values = map(partial(self.get_channel_quality, low_res=low_res, channel_pca_model=self.channel_pca_model), channels)
quality = dict(zip(channels, values))
return quality
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment