From 10406ed50852718ed28577dffb2c2ded90c96173 Mon Sep 17 00:00:00 2001
From: Danilo Ferreira de Lima <danilo.enoque.ferreira.de.lima@xfel.de>
Date: Fri, 16 Dec 2022 10:14:59 +0100
Subject: [PATCH] Moved more parameters to the constructor.

---
 pes_to_spec/model.py     | 21 +++++++++++++++------
 scripts/test_analysis.py |  2 +-
 2 files changed, 16 insertions(+), 7 deletions(-)

diff --git a/pes_to_spec/model.py b/pes_to_spec/model.py
index e271aab..9a45f68 100644
--- a/pes_to_spec/model.py
+++ b/pes_to_spec/model.py
@@ -24,6 +24,10 @@ class Model(object):
       channels: Selected channels to use as an input for the low resolution data.
       n_pca_lr: Number of low-resolution data PCA components.
       n_pca_hr: Number of high-resolution data PCA components.
+      high_res_sigma: Standard deviation (in eV) of the Gaussian used to smooth the high-resolution spectrometer data.
+      tof_start: Index in the low-resolution (ToF) spectrometer data at which the selected window starts.
+      delta_tof: Length of the selected window in the low-resolution spectrometer data (the window ends at tof_start + delta_tof).
+      validation_size: Fraction (number between 0 and 1) of the data to take for validation and systematic uncertainty estimate.
 
     """
     def __init__(self,
@@ -34,7 +38,11 @@ class Model(object):
                                      "channel_4_C",
                                      "channel_4_D"],
                  n_pca_lr: int=400,
-                 n_pca_hr: int=20):
+                 n_pca_hr: int=20,
+                 high_res_sigma: float=0.2,
+                 tof_start: int=31445,
+                 delta_tof: int=200,
+                 validation_size: float=0.05):
         self.channels = channels
         self.n_pca_lr = n_pca_lr
         self.n_pca_hr = n_pca_hr
@@ -50,18 +58,18 @@ class Model(object):
         self.fit_model = FitModel()
 
         # size of the test subset
-        self.test_size = 0.05
+        self.validation_size = validation_size
 
         # where to cut on the ToF PES data
-        self.tof_start = 31445
-        self.delta_tof = 200
+        self.tof_start = tof_start
+        self.delta_tof = delta_tof
         self.tof_end = self.tof_start + self.delta_tof
 
         # high-resolution photon energy axis
         self.high_res_photon_energy: Optional[np.ndarray] = None
 
         # smoothing of the SPEC data in eV
-        self.high_res_sigma = 0.2
+        self.high_res_sigma = high_res_sigma
 
     def preprocess_low_res(self, low_res_data: Dict[str, np.ndarray]) -> np.ndarray:
         """
@@ -90,6 +98,7 @@ class Model(object):
         n_features = high_res_data.shape[1]
         mu = high_res_photon_energy[0, n_features//2]
         gaussian = np.exp(-((high_res_photon_energy - mu)/self.high_res_sigma)**2/2)/np.sqrt(2*np.pi*self.high_res_sigma**2)
+        print(np.sum(gaussian))
         # 80 to match normalization (empirically taken)
         high_res_gc = fftconvolve(high_res_data, gaussian, mode="same", axes=1)/80.0
         return high_res_gc
@@ -114,7 +123,7 @@ class Model(object):
         low_pca = self.lr_pca.fit_transform(low_res)
         high_pca = self.hr_pca.fit_transform(high_res)
         # split in train and test for PCA uncertainty evaluation
-        low_pca_train, low_pca_test, high_pca_train, high_pca_test = train_test_split(low_pca, high_pca, test_size=self.test_size, random_state=42)
+        low_pca_train, low_pca_test, high_pca_train, high_pca_test = train_test_split(low_pca, high_pca, test_size=self.validation_size, random_state=42)
         # fit the linear model
         self.fit_model.fit(low_pca_train, high_pca_train, low_pca_test, high_pca_test)
 
diff --git a/scripts/test_analysis.py b/scripts/test_analysis.py
index 53e074d..0a6bc59 100755
--- a/scripts/test_analysis.py
+++ b/scripts/test_analysis.py
@@ -48,7 +48,7 @@ def main():
     run_dir = "/gpfs/exfel/exp/SA3/202121/p002935/raw"
     run_id = "r0015"
     # get run
-    run = RunDirectory(f"{run_dir}/{run_id}") 
+    run = RunDirectory(f"{run_dir}/{run_id}")
 
     # get train IDs and match them, so we are sure to have information from all needed sources
     # in this example, there is an offset of -2 in the SPEC train ID, so correct for it
-- 
GitLab