diff --git a/pes_to_spec/model.py b/pes_to_spec/model.py index 4d93ec82f883e304963e96c4c698b2cd100c2a1d..5e72e2d8fe03ee47466dc44d68cf8f19d98d92de 100644 --- a/pes_to_spec/model.py +++ b/pes_to_spec/model.py @@ -914,6 +914,19 @@ class Model(TransformerMixin, BaseEstimator): result = {ch: is_inlier(low_res_selected[ch], ch) for ch in channels} return result + def variance_per_channel(self) -> Dict[str, float]: + """ + Check the total variance of the channel, to see if data is available. + + Args: + low_res_data: Low resolution data as in the fit step with shape (train_id, channel, ToF channel). + pulse_spacing: The pulse spacing in multi-pulse data. + + Returns: Total variance per channel. + """ + channels = list(low_res_data.keys()) + return {ch: np.sum(self.channel_pca[ch].explained_variance_) for ch in channels} + def check_compatibility(self, low_res_data: Dict[str, np.ndarray], pulse_spacing: Optional[Dict[str, List[int]]]=None, pulse_energy: Optional[np.ndarray]=None) -> np.ndarray: """ Check if a new low-resolution data source is compatible with the one used in training, by @@ -992,7 +1005,8 @@ class Model(TransformerMixin, BaseEstimator): #print("Times") #print(dict(zip(n, t))) - return dict(expected=expected.reshape((B, P, -1)), + return dict( + expected=expected.reshape((B, P, -1)), unc=unc.reshape((B, P, -1)), pca=pca_unc, total_unc=total_unc.reshape((B, P, -1)),