From eba3fc961059a20cff8124f6bbda9946f42917b4 Mon Sep 17 00:00:00 2001
From: Danilo Ferreira de Lima <danilo.enoque.ferreira.de.lima@xfel.de>
Date: Wed, 21 Dec 2022 16:50:50 +0100
Subject: [PATCH] Added data drift check

---
 pes_to_spec/model.py                 | 42 +++++++++++++++++++++++++++-
 pes_to_spec/test/offline_analysis.py | 10 +++++++
 2 files changed, 51 insertions(+), 1 deletion(-)

diff --git a/pes_to_spec/model.py b/pes_to_spec/model.py
index 67fd2d2..429a250 100644
--- a/pes_to_spec/model.py
+++ b/pes_to_spec/model.py
@@ -144,6 +144,7 @@ class Model(TransformerMixin, BaseEstimator):
 
         # PCA unc. in high resolution
         self.high_pca_unc: np.ndarray = np.zeros((1, 0), dtype=float)
+        self.low_pca_unc: np.ndarray = np.zeros((1, 0), dtype=float)
 
         # fit model
         self.fit_model = FitModel()
@@ -173,6 +174,7 @@ class Model(TransformerMixin, BaseEstimator):
                     delta_tof=self.delta_tof,
                     validation_size=self.validation_size,
                     high_pca_unc=self.high_pca_unc,
+                    low_pca_unc=self.low_pca_unc,
                     high_res_photon_energy=self.high_res_photon_energy,
                     )
 
@@ -310,8 +312,44 @@ class Model(TransformerMixin, BaseEstimator):
         high_pca_rec = self.hr_pca.inverse_transform(high_pca)
         self.high_pca_unc =  np.sqrt(np.mean((high_res - high_pca_rec)**2, axis=0, keepdims=True))
 
+        low_pca_rec = self.lr_pca.inverse_transform(low_pca)
+        self.low_pca_unc =  np.mean(np.sqrt(np.mean((low_res - low_pca_rec)**2, axis=1, keepdims=True)), axis=0, keepdims=True)
+
         return high_res
 
+    def check_compatibility(self, low_res_data: Dict[str, np.ndarray]) -> np.ndarray:
+        """
+        Check if a new low-resolution data source is compatible with the one used in training, by
+        comparing the effect of the trained PCA model on it.
+
+        The input is pre-processed with `preprocess_low_res`, projected onto the low-resolution PCA
+        and reconstructed from that projection. The root-mean-squared error (RMSE) of each
+        reconstruction is divided by the average RMSE stored in `low_pca_unc` for the data the
+        model was trained on. Ratios well above one suggest a drift in the data source.
+
+        Only meaningful on a trained model: `low_pca_unc` is initialised empty, in which case the
+        division broadcasts to an empty array.
+
+        Args:
+          low_res_data: Low resolution data as in the fit step with shape (train_id, channel, ToF channel).
+
+        Returns:
+          Array with one ratio per pre-processed input row (not a single float): the RMSE of the
+          reconstruction using the existing PCA model over the average one from the original model.
+        """
+        low_res = self.preprocess_low_res(low_res_data)
+
+        # round trip through the low-resolution PCA of the trained model
+        low_pca = self.lr_pca.transform(low_res)
+        low_pca_rec = self.lr_pca.inverse_transform(low_pca)
+
+        # RMSE along axis 1 for each row; keepdims keeps it 2D so that it broadcasts
+        # against the reference value computed in training
+        low_pca_unc = np.sqrt(np.mean((low_res - low_pca_rec)**2, axis=1, keepdims=True))
+        # a ratio close to one means the PCA describes the new data as well as the training data
+        return low_pca_unc/self.low_pca_unc
+
+
     def predict(self, low_res_data: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
         """
         Predict a high-resolution spectrum from a low resolution given one.
@@ -329,7 +367,9 @@ class Model(TransformerMixin, BaseEstimator):
         # Get high res.
         high_pca = self.fit_model.predict(low_pca)
         n_trains = low_pca.shape[0]
-        pca_y = np.concatenate((high_pca["Y"], high_pca["Y"] + high_pca["Y_eps"]), axis=0)
+        pca_y = np.concatenate((high_pca["Y"],
+                                high_pca["Y"] + high_pca["Y_eps"]),
+                               axis=0)
         high_res_predicted = self.hr_pca.inverse_transform(pca_y)
         expected = high_res_predicted[:n_trains, :]
         unc = high_res_predicted[n_trains:, :] - expected
diff --git a/pes_to_spec/test/offline_analysis.py b/pes_to_spec/test/offline_analysis.py
index 848ca88..b1a155c 100755
--- a/pes_to_spec/test/offline_analysis.py
+++ b/pes_to_spec/test/offline_analysis.py
@@ -106,6 +106,9 @@ def main():
     pes_raw = {ch: run['SA3_XTD10_PES/ADC/1:network',
                    f"digitizers.{ch}.raw.samples"].select_trains(by_id[tids]).ndarray()
                for ch in channels}
+    pes_raw_t = {ch: run['SA3_XTD10_PES/ADC/1:network',
+                   f"digitizers.{ch}.raw.samples"].select_trains(by_id[test_tids]).ndarray()
+               for ch in channels}
 
     # read the XGM information
     #xgm_pressure = run['SA3_XTD10_XGM/XGM/DOOCS', "pressure.pressureFiltered.value"].select_trains(by_id[tids]).ndarray()
@@ -151,6 +154,13 @@ def main():
     t += [time_ns() - start]
     t_names += ["Load"]
 
+    print("Check consistency")
+    start = time_ns()
+    rmse = model.check_compatibility(pes_raw_t)
+    print("Consistency check RMSE ratios:", rmse)
+    t += [time_ns() - start]
+    t_names += ["Consistency"]
+
     # test
     print("Predict")
     start = time_ns()
-- 
GitLab