From a1b2816a05a35e7b014bead4952ebdb712fa7080 Mon Sep 17 00:00:00 2001
From: Danilo Ferreira de Lima <danilo.enoque.ferreira.de.lima@xfel.de>
Date: Wed, 13 Sep 2023 13:51:15 +0200
Subject: [PATCH] Fixed channel list in variance calculation.

---
 pes_to_spec/model.py | 16 +++++++++++++++-
 1 file changed, 15 insertions(+), 1 deletion(-)

diff --git a/pes_to_spec/model.py b/pes_to_spec/model.py
index 4d93ec8..5e72e2d 100644
--- a/pes_to_spec/model.py
+++ b/pes_to_spec/model.py
@@ -914,6 +914,19 @@ class Model(TransformerMixin, BaseEstimator):
         result = {ch: is_inlier(low_res_selected[ch], ch) for ch in channels}
         return result
 
+    def variance_per_channel(self) -> Dict[str, float]:
+        """
+        Check the total variance of the channel, to see if data is available.
+
+        Args:
+          low_res_data: Low resolution data as in the fit step with shape (train_id, channel, ToF channel).
+          pulse_spacing: The pulse spacing in multi-pulse data.
+
+        Returns: Total variance per channel.
+        """
+        channels = list(low_res_data.keys())
+        return {ch: np.sum(self.channel_pca[ch].explained_variance_) for ch in channels}
+
     def check_compatibility(self, low_res_data: Dict[str, np.ndarray], pulse_spacing: Optional[Dict[str, List[int]]]=None, pulse_energy: Optional[np.ndarray]=None) -> np.ndarray:
         """
         Check if a new low-resolution data source is compatible with the one used in training, by
@@ -992,7 +1005,8 @@ class Model(TransformerMixin, BaseEstimator):
         #print("Times")
         #print(dict(zip(n, t)))
 
-        return dict(expected=expected.reshape((B, P, -1)),
+        return dict(
+                    expected=expected.reshape((B, P, -1)),
                     unc=unc.reshape((B, P, -1)),
                     pca=pca_unc,
                     total_unc=total_unc.reshape((B, P, -1)),
-- 
GitLab