From 1dfda8ed708923b63cc38e690b6821cbac657fc7 Mon Sep 17 00:00:00 2001
From: Danilo Ferreira de Lima <danilo.enoque.ferreira.de.lima@xfel.de>
Date: Sat, 4 Feb 2023 18:03:30 +0100
Subject: [PATCH] Using channel relevance as a per-channel indicator of
 compatibility.

---
 pes_to_spec/model.py                 | 69 ++++++++++++++++++----------
 pes_to_spec/test/offline_analysis.py |  2 +
 2 files changed, 46 insertions(+), 25 deletions(-)

diff --git a/pes_to_spec/model.py b/pes_to_spec/model.py
index 613bd3a..46b5fe7 100644
--- a/pes_to_spec/model.py
+++ b/pes_to_spec/model.py
@@ -545,10 +545,8 @@ class Model(TransformerMixin, BaseEstimator):
         #self.fit_model = FitModel()
         self.fit_model = MultiOutputWithStd(ARDRegression(n_iter=30, tol=1e-4, verbose=True))
 
-        self.channel_pca_model = {channel: Pipeline([('pca', PCA(n_pca_lr, whiten=True)),
-                                                    ('unc', UncertaintyHolder()),
-                                                    ])
-                                  for channel in channels}
+        self.channel_mean = {ch: np.nan for ch in channels}
+        self.channel_relevance = {ch: np.nan for ch in channels}
 
         # size of the test subset
         self.validation_size = validation_size
@@ -629,20 +627,28 @@ class Model(TransformerMixin, BaseEstimator):
         self.x_model['unc'].set_uncertainty(low_pca_unc)
 
         # for consistency check per channel
-        print("Calculate PCA per channel on low-resolution data.")
         selection_model = self.x_model['select']
-        low_res = selection_model.transform(low_res_data, keep_dictionary_structure=True)
         for channel in self.get_channels():
-            print(f"Calculate PCA on {channel}")
-            low_pca = self.channel_pca_model[channel].named_steps["pca"].fit_transform(low_res[channel])
-            low_pca_rec = self.channel_pca_model[channel].named_steps["pca"].inverse_transform(low_pca)
-            low_pca_unc =  np.mean(np.sqrt(np.mean((low_res[channel] - low_pca_rec)**2, axis=1, keepdims=True)), axis=0, keepdims=True)
-            self.channel_pca_model[channel]['unc'].set_uncertainty(low_pca_unc)
+            self.channel_mean[channel] = np.mean(low_res_data[channel], axis=0, keepdims=True)
+            print(f"Calculate PCA relevance on {channel}")
+            # freeze only this channel: replace its data by its mean over axis 0 (repeated for every row); other channels are kept as they are
+            low_res_data_frozen = {ch: low_res_data[ch] if ch != channel
+                                       else np.repeat(self.channel_mean[channel], low_res_data[ch].shape[0], axis=0)
+                                   for ch in self.get_channels()}
+            low_res = selection_model.transform(low_res_data_frozen)
+            low_pca = pca_model.fit_transform(low_res)
+            low_pca_rec = pca_model.inverse_transform(low_pca)
+            low_pca_unc =  np.mean(np.sqrt(np.mean((low_res - low_pca_rec)**2, axis=1, keepdims=True)), axis=0, keepdims=True)
+            self.channel_relevance[channel] = low_pca_unc
         print("End of fit.")
 
         return high_res
 
-    def get_channel_quality(self, channel: str, low_res: Dict[str, np.ndarray], channel_pca_model: Dict[str, Pipeline]) -> float:
+    def get_channel_quality(self, channel: str, low_res_data: Dict[str, np.ndarray],
+                            pca_model: PCA,
+                            channel_relevance: Dict[str, float],
+                            selection_model: SelectRelevantLowResolution,
+                            channel_mean: Dict[str, np.ndarray]) -> float:
         """
         Get the compatibility for a single channel.
 
@@ -653,11 +659,15 @@ class Model(TransformerMixin, BaseEstimator):
 
         Returns: the compatibility factor.
         """
-        pca_model = channel_pca_model[channel].named_steps['pca']
-        low_pca = pca_model.transform(low_res[channel])
+        # freeze only `channel`: replace its data by channel_mean[channel] repeated along axis 0; other channels are kept as they are
+        low_res_data_frozen = {ch: low_res_data[ch] if ch != channel
+                                   else np.repeat(channel_mean[channel], low_res_data[ch].shape[0], axis=0)
+                               for ch in low_res_data.keys()}
+        low_res_selected = selection_model.transform(low_res_data_frozen)
+        low_pca = pca_model.transform(low_res_selected)
         low_pca_rec = pca_model.inverse_transform(low_pca)
-        low_pca_unc = channel_pca_model[channel].named_steps['unc'].uncertainty()
-        low_pca_dev =  np.sqrt(np.mean((low_res[channel] - low_pca_rec)**2, axis=1, keepdims=True))
+        low_pca_unc = channel_relevance[channel]
+        low_pca_dev =  np.sqrt(np.mean((low_res_selected - low_pca_rec)**2, axis=1, keepdims=True))
         return low_pca_dev/low_pca_unc
 
     def check_compatibility_per_channel(self, low_res_data: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
@@ -671,12 +681,19 @@ class Model(TransformerMixin, BaseEstimator):
         Returns: Ratio of root-mean-squared-error of the data reconstruction using the existing PCA model and the one from the original model per channel.
         """
         selection_model = self.x_model['select']
-        low_res = selection_model.transform(low_res_data, keep_dictionary_structure=True)
-        quality = {channel: 0.0 for channel in low_res.keys()}
-        channels = list(low_res.keys())
-        #with mp.Pool(len(low_res.keys())) as p:
-        #    values = p.map(partial(self.get_channel_quality, low_res=low_res, channel_pca_model=self.channel_pca_model), channels)
-        values = map(partial(self.get_channel_quality, low_res=low_res, channel_pca_model=self.channel_pca_model), channels)
+        quality = {channel: 0.0 for channel in low_res_data.keys()}
+        channels = list(low_res_data.keys())
+        pca_model = self.x_model['pca']
+        if 'fex' in self.x_model.named_steps:
+            pca_model = self.x_model['fex'].named_steps['prepca']
+        #with mp.Pool(len(low_res_data.keys())) as p:
+        values = map(partial(self.get_channel_quality,
+                             low_res_data=low_res_data,
+                             pca_model=pca_model,
+                             channel_relevance=self.channel_relevance,
+                             selection_model=selection_model,
+                             channel_mean=self.channel_mean
+                            ), channels)
         quality = dict(zip(channels, values))
         return quality
 
@@ -754,7 +771,8 @@ class Model(TransformerMixin, BaseEstimator):
         joblib.dump([self.x_model,
                      self.y_model,
                      self.fit_model,
-                     self.channel_pca_model
+                     self.channel_mean,
+                     self.channel_relevance
                      ], filename, compress='zlib')
 
     @staticmethod
@@ -767,11 +785,12 @@ class Model(TransformerMixin, BaseEstimator):
 
         Returns: A new model object.
         """
-        x_model, y_model, fit_model, channel_pca_model = joblib.load(filename)
+        x_model, y_model, fit_model, channel_mean, channel_relevance = joblib.load(filename)
         obj = Model()
         obj.x_model = x_model
         obj.y_model = y_model
         obj.fit_model = fit_model
-        obj.channel_pca_model = channel_pca_model
+        obj.channel_mean = channel_mean
+        obj.channel_relevance = channel_relevance
         return obj
 
diff --git a/pes_to_spec/test/offline_analysis.py b/pes_to_spec/test/offline_analysis.py
index 782aab6..6db5415 100755
--- a/pes_to_spec/test/offline_analysis.py
+++ b/pes_to_spec/test/offline_analysis.py
@@ -198,6 +198,8 @@ def main():
     start = time_ns()
     rmse = model.check_compatibility(pes_raw_t)
     print("Consistency check RMSE ratios:", rmse)
+    rmse = model.check_compatibility_per_channel(pes_raw_t)
+    print("Consistency per channel check RMSE ratios:", rmse)
     t += [time_ns() - start]
     t_names += ["Consistency"]
 
-- 
GitLab