Skip to content
Snippets Groups Projects
Commit 1dfda8ed authored by Danilo Ferreira de Lima's avatar Danilo Ferreira de Lima
Browse files

Using channel relevance as a per channel indicator of compatibility.

parent 7d76f17b
No related branches found
No related tags found
1 merge request!6Use relevance per channel as a measurement of channel-compatibility
......@@ -545,10 +545,8 @@ class Model(TransformerMixin, BaseEstimator):
#self.fit_model = FitModel()
self.fit_model = MultiOutputWithStd(ARDRegression(n_iter=30, tol=1e-4, verbose=True))
self.channel_pca_model = {channel: Pipeline([('pca', PCA(n_pca_lr, whiten=True)),
('unc', UncertaintyHolder()),
])
for channel in channels}
self.channel_mean = {ch: np.nan for ch in channels}
self.channel_relevance = {ch: np.nan for ch in channels}
# size of the test subset
self.validation_size = validation_size
......@@ -629,20 +627,28 @@ class Model(TransformerMixin, BaseEstimator):
self.x_model['unc'].set_uncertainty(low_pca_unc)
# for consistency check per channel
print("Calculate PCA per channel on low-resolution data.")
selection_model = self.x_model['select']
low_res = selection_model.transform(low_res_data, keep_dictionary_structure=True)
for channel in self.get_channels():
print(f"Calculate PCA on {channel}")
low_pca = self.channel_pca_model[channel].named_steps["pca"].fit_transform(low_res[channel])
low_pca_rec = self.channel_pca_model[channel].named_steps["pca"].inverse_transform(low_pca)
low_pca_unc = np.mean(np.sqrt(np.mean((low_res[channel] - low_pca_rec)**2, axis=1, keepdims=True)), axis=0, keepdims=True)
self.channel_pca_model[channel]['unc'].set_uncertainty(low_pca_unc)
self.channel_mean[channel] = np.mean(low_res_data[channel], axis=0, keepdims=True)
print(f"Calculate PCA relevance on {channel}")
# freeze input data in one channel only
low_res_data_frozen = {ch: low_res_data[ch] if ch != channel
else np.repeat(self.channel_mean[channel], low_res_data[ch].shape[0], axis=0)
for ch in self.get_channels()}
low_res = selection_model.transform(low_res_data_frozen)
low_pca = pca_model.fit_transform(low_res)
low_pca_rec = pca_model.inverse_transform(low_pca)
low_pca_unc = np.mean(np.sqrt(np.mean((low_res - low_pca_rec)**2, axis=1, keepdims=True)), axis=0, keepdims=True)
self.channel_relevance[channel] = low_pca_unc
print("End of fit.")
return high_res
def get_channel_quality(self, channel: str, low_res: Dict[str, np.ndarray], channel_pca_model: Dict[str, Pipeline]) -> float:
def get_channel_quality(self, channel: str, low_res_data: Dict[str, np.ndarray],
pca_model: PCA,
channel_relevance: Dict[str, float],
selection_model: SelectRelevantLowResolution,
channel_mean: Dict[str, np.ndarray]) -> float:
"""
Get the compatibility for a single channel.
......@@ -653,11 +659,15 @@ class Model(TransformerMixin, BaseEstimator):
Returns: the compatibility factor.
"""
pca_model = channel_pca_model[channel].named_steps['pca']
low_pca = pca_model.transform(low_res[channel])
# freeze input data in one channel only
low_res_data_frozen = {ch: low_res_data[ch] if ch != channel
else np.repeat(channel_mean[channel], low_res_data[ch].shape[0], axis=0)
for ch in low_res_data.keys()}
low_res_selected = selection_model.transform(low_res_data_frozen)
low_pca = pca_model.transform(low_res_selected)
low_pca_rec = pca_model.inverse_transform(low_pca)
low_pca_unc = channel_pca_model[channel].named_steps['unc'].uncertainty()
low_pca_dev = np.sqrt(np.mean((low_res[channel] - low_pca_rec)**2, axis=1, keepdims=True))
low_pca_unc = channel_relevance[channel]
low_pca_dev = np.sqrt(np.mean((low_res_selected - low_pca_rec)**2, axis=1, keepdims=True))
return low_pca_dev/low_pca_unc
def check_compatibility_per_channel(self, low_res_data: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
......@@ -671,12 +681,19 @@ class Model(TransformerMixin, BaseEstimator):
Returns: Ratio of root-mean-squared-error of the data reconstruction using the existing PCA model and the one from the original model per channel.
"""
selection_model = self.x_model['select']
low_res = selection_model.transform(low_res_data, keep_dictionary_structure=True)
quality = {channel: 0.0 for channel in low_res.keys()}
channels = list(low_res.keys())
#with mp.Pool(len(low_res.keys())) as p:
# values = p.map(partial(self.get_channel_quality, low_res=low_res, channel_pca_model=self.channel_pca_model), channels)
values = map(partial(self.get_channel_quality, low_res=low_res, channel_pca_model=self.channel_pca_model), channels)
quality = {channel: 0.0 for channel in low_res_data.keys()}
channels = list(low_res_data.keys())
pca_model = self.x_model['pca']
if 'fex' in self.x_model.named_steps:
pca_model = self.x_model['fex'].named_steps['prepca']
#with mp.Pool(len(low_res_data.keys())) as p:
values = map(partial(self.get_channel_quality,
low_res_data=low_res_data,
pca_model=pca_model,
channel_relevance=self.channel_relevance,
selection_model=selection_model,
channel_mean=self.channel_mean
), channels)
quality = dict(zip(channels, values))
return quality
......@@ -754,7 +771,8 @@ class Model(TransformerMixin, BaseEstimator):
joblib.dump([self.x_model,
self.y_model,
self.fit_model,
self.channel_pca_model
self.channel_mean,
self.channel_relevance
], filename, compress='zlib')
@staticmethod
......@@ -767,11 +785,12 @@ class Model(TransformerMixin, BaseEstimator):
Returns: A new model object.
"""
x_model, y_model, fit_model, channel_pca_model = joblib.load(filename)
x_model, y_model, fit_model, channel_mean, channel_relevance = joblib.load(filename)
obj = Model()
obj.x_model = x_model
obj.y_model = y_model
obj.fit_model = fit_model
obj.channel_pca_model = channel_pca_model
obj.channel_mean = channel_mean
obj.channel_relevance = channel_relevance
return obj
......@@ -198,6 +198,8 @@ def main():
start = time_ns()
rmse = model.check_compatibility(pes_raw_t)
print("Consistency check RMSE ratios:", rmse)
rmse = model.check_compatibility_per_channel(pes_raw_t)
print("Consistency per channel check RMSE ratios:", rmse)
t += [time_ns() - start]
t_names += ["Consistency"]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment