From eba3fc961059a20cff8124f6bbda9946f42917b4 Mon Sep 17 00:00:00 2001
From: Danilo Ferreira de Lima <danilo.enoque.ferreira.de.lima@xfel.de>
Date: Wed, 21 Dec 2022 16:50:50 +0100
Subject: [PATCH] Added data drift check

---
Reviewer notes (text in this area is ignored by git-am):
- The line structure and indentation of this patch were reconstructed from a
  whitespace-collapsed copy. The position of blank context lines and the
  indentation of continuation lines are a best guess that satisfies the
  original hunk line counts; verify against the tree before applying
  (or apply with --ignore-whitespace).
- check_compatibility: the return annotation was corrected from float to
  np.ndarray (the method returns an array, not a scalar), the Returns text
  was updated to match, the commented-out plotting code was dropped, and the
  extra blank line before predict() was removed. The hunk headers and the
  diffstat below were recomputed accordingly; the commit id in the "From"
  line and the blob ids in the "index" lines still refer to the original
  commit.

 pes_to_spec/model.py                 | 29 ++++++++++++++++++++++++++-
 pes_to_spec/test/offline_analysis.py | 10 ++++++++++
 2 files changed, 38 insertions(+), 1 deletion(-)

diff --git a/pes_to_spec/model.py b/pes_to_spec/model.py
index 67fd2d2..429a250 100644
--- a/pes_to_spec/model.py
+++ b/pes_to_spec/model.py
@@ -144,6 +144,7 @@ class Model(TransformerMixin, BaseEstimator):
 
         # PCA unc. in high resolution
         self.high_pca_unc: np.ndarray = np.zeros((1, 0), dtype=float)
+        self.low_pca_unc: np.ndarray = np.zeros((1, 0), dtype=float)
 
         # fit model
         self.fit_model = FitModel()
@@ -173,6 +174,7 @@ class Model(TransformerMixin, BaseEstimator):
                     delta_tof=self.delta_tof,
                     validation_size=self.validation_size,
                     high_pca_unc=self.high_pca_unc,
+                    low_pca_unc=self.low_pca_unc,
                     high_res_photon_energy=self.high_res_photon_energy,
                     )
 
@@ -310,8 +312,31 @@ class Model(TransformerMixin, BaseEstimator):
         high_pca_rec = self.hr_pca.inverse_transform(high_pca)
         self.high_pca_unc = np.sqrt(np.mean((high_res - high_pca_rec)**2, axis=0, keepdims=True))
 
+        low_pca_rec = self.lr_pca.inverse_transform(low_pca)
+        self.low_pca_unc = np.mean(np.sqrt(np.mean((low_res - low_pca_rec)**2, axis=1, keepdims=True)), axis=0, keepdims=True)
+
         return high_res
 
+    def check_compatibility(self, low_res_data: Dict[str, np.ndarray]) -> np.ndarray:
+        """
+        Check if a new low-resolution data source is compatible with the one used in training, by
+        comparing the effect of the trained PCA model on it.
+
+        Args:
+          low_res_data: Low resolution data as in the fit step with shape (train_id, channel, ToF channel).
+
+        Returns: Array with one row per pre-processed sample (presumably one per train) and a single column,
+                 holding the root-mean-squared error of the PCA reconstruction of that sample divided by
+                 the baseline stored in `low_pca_unc` during the fit step.
+        """
+        low_res = self.preprocess_low_res(low_res_data)
+        low_pca = self.lr_pca.transform(low_res)
+        low_pca_rec = self.lr_pca.inverse_transform(low_pca)
+
+        # RMSE along axis 1 of each row, kept as a column so that it broadcasts against the stored baseline
+        low_pca_unc = np.sqrt(np.mean((low_res - low_pca_rec)**2, axis=1, keepdims=True))
+        return low_pca_unc/self.low_pca_unc
+
     def predict(self, low_res_data: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
         """
         Predict a high-resolution spectrum from a low resolution given one.
@@ -329,7 +354,9 @@ class Model(TransformerMixin, BaseEstimator):
         # Get high res.
         high_pca = self.fit_model.predict(low_pca)
         n_trains = low_pca.shape[0]
-        pca_y = np.concatenate((high_pca["Y"], high_pca["Y"] + high_pca["Y_eps"]), axis=0)
+        pca_y = np.concatenate((high_pca["Y"],
+                                high_pca["Y"] + high_pca["Y_eps"]),
+                               axis=0)
         high_res_predicted = self.hr_pca.inverse_transform(pca_y)
         expected = high_res_predicted[:n_trains, :]
         unc = high_res_predicted[n_trains:, :] - expected
diff --git a/pes_to_spec/test/offline_analysis.py b/pes_to_spec/test/offline_analysis.py
index 848ca88..b1a155c 100755
--- a/pes_to_spec/test/offline_analysis.py
+++ b/pes_to_spec/test/offline_analysis.py
@@ -106,6 +106,9 @@ def main():
     pes_raw = {ch: run['SA3_XTD10_PES/ADC/1:network',
                        f"digitizers.{ch}.raw.samples"].select_trains(by_id[tids]).ndarray()
                for ch in channels}
+    pes_raw_t = {ch: run['SA3_XTD10_PES/ADC/1:network',
+                         f"digitizers.{ch}.raw.samples"].select_trains(by_id[test_tids]).ndarray()
+                 for ch in channels}
 
     # read the XGM information
     #xgm_pressure = run['SA3_XTD10_XGM/XGM/DOOCS', "pressure.pressureFiltered.value"].select_trains(by_id[tids]).ndarray()
@@ -151,6 +154,13 @@ def main():
     t += [time_ns() - start]
     t_names += ["Load"]
 
+    print("Check consistency")
+    start = time_ns()
+    rmse = model.check_compatibility(pes_raw_t)
+    print("Consistency check RMSE ratios:", rmse)
+    t += [time_ns() - start]
+    t_names += ["Consistency"]
+
     # test
     print("Predict")
     start = time_ns()
-- 
GitLab