From a1b2816a05a35e7b014bead4952ebdb712fa7080 Mon Sep 17 00:00:00 2001 From: Danilo Ferreira de Lima <danilo.enoque.ferreira.de.lima@xfel.de> Date: Wed, 13 Sep 2023 13:51:15 +0200 Subject: [PATCH] Fixed channel list in variance calculation. --- pes_to_spec/model.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/pes_to_spec/model.py b/pes_to_spec/model.py index 4d93ec8..5e72e2d 100644 --- a/pes_to_spec/model.py +++ b/pes_to_spec/model.py @@ -914,6 +914,19 @@ class Model(TransformerMixin, BaseEstimator): result = {ch: is_inlier(low_res_selected[ch], ch) for ch in channels} return result + def variance_per_channel(self) -> Dict[str, float]: + """ + Check the total variance of the channel, to see if data is available. + + Args: + low_res_data: Low resolution data as in the fit step with shape (train_id, channel, ToF channel). + pulse_spacing: The pulse spacing in multi-pulse data. + + Returns: Total variance per channel. + """ + channels = list(low_res_data.keys()) + return {ch: np.sum(self.channel_pca[ch].explained_variance_) for ch in channels} + def check_compatibility(self, low_res_data: Dict[str, np.ndarray], pulse_spacing: Optional[Dict[str, List[int]]]=None, pulse_energy: Optional[np.ndarray]=None) -> np.ndarray: """ Check if a new low-resolution data source is compatible with the one used in training, by @@ -992,7 +1005,8 @@ class Model(TransformerMixin, BaseEstimator): #print("Times") #print(dict(zip(n, t))) - return dict(expected=expected.reshape((B, P, -1)), + return dict( + expected=expected.reshape((B, P, -1)), unc=unc.reshape((B, P, -1)), pca=pca_unc, total_unc=total_unc.reshape((B, P, -1)), -- GitLab