diff --git a/pes_to_spec/model.py b/pes_to_spec/model.py index 54b8d4a1ae0f99626a115f860e0ac087fc0c990f..79f7449264d35d07f5c8133b5e0319e8320bc716 100644 --- a/pes_to_spec/model.py +++ b/pes_to_spec/model.py @@ -531,7 +531,7 @@ class Model(TransformerMixin, BaseEstimator): return high_res - def check_compatibility(self, low_res_data: Dict[str, np.ndarray]) -> float: + def check_compatibility(self, low_res_data: Dict[str, np.ndarray]) -> np.ndarray: """ Check if a new low-resolution data source is compatible with the one used in training, by comparing the effect of the trained PCA model on it. @@ -561,8 +561,8 @@ class Model(TransformerMixin, BaseEstimator): #plt.savefig("check.png") #plt.close(fig) - low_pca_unc = np.sqrt(np.mean((low_res - low_pca_rec)**2, axis=1, keepdims=True)) - return low_pca_unc/low_pca_unc + low_pca_dev = np.sqrt(np.mean((low_res - low_pca_rec)**2, axis=1, keepdims=True)) + return low_pca_dev/low_pca_unc def predict(self, low_res_data: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: