From d179e5c75e95bf9200c0932678391f6cfb42699a Mon Sep 17 00:00:00 2001
From: Danilo Ferreira de Lima <danilo.enoque.ferreira.de.lima@xfel.de>
Date: Thu, 19 Oct 2023 14:14:28 +0200
Subject: [PATCH] Fit resolution per energy bin.

---
 pes_to_spec/model.py | 64 +++++++++++++++++++++++++++++++-------------
 1 file changed, 46 insertions(+), 18 deletions(-)

diff --git a/pes_to_spec/model.py b/pes_to_spec/model.py
index b95641d..44d7a14 100644
--- a/pes_to_spec/model.py
+++ b/pes_to_spec/model.py
@@ -86,7 +86,7 @@ def fit_gaussian(x: np.ndarray, y: np.ndarray) -> lmfit.ModelResult:
     return result
 
 def get_resolution(y: np.ndarray, y_hat: np.ndarray, e: np.ndarray,
-                   e_center: Optional[float]=None, e_width: Optional[float]=None) -> Tuple[np.ndarray, np.ndarray, lmfit.ModelResult]:
+                   e_center: Optional[float]=None, e_width: Optional[float]=None) -> Tuple[np.ndarray, np.ndarray, np.ndarray, lmfit.ModelResult]:
     """
     Given the true y and the predicted y, together with the energy axis e,
     estimate the impulse response of the system and return the Gaussian fit to it.
@@ -99,7 +99,7 @@ def get_resolution(y: np.ndarray, y_hat: np.ndarray, e: np.ndarray,
       e_center: If given the energy value, for which to probe the resolution.
       e_width: The width of the energy neighbourhood to probe if e_center is given.
 
-    Returns: The centered energy axis, the impulse response and the fit result.
+    Returns: The centered energy axis, the impulse response, the transfer function and the fit result.
     """
     e_range = e[-1] - e[0]
     e_axis = np.linspace(-0.5*e_range, 0.5*e_range, len(e))
@@ -112,7 +112,7 @@ def get_resolution(y: np.ndarray, y_hat: np.ndarray, e: np.ndarray,
         y_sel = y_sel*f
         y_hat_sel = y_hat_sel*f
     h, H, S = deconv(y_sel, y_hat_sel)
-    return e_axis, h, fit_gaussian(e_axis, np.absolute(h))
+    return e_axis, h, H, fit_gaussian(e_axis, np.absolute(h))
 
 class PromptNotFoundError(Exception):
     """
@@ -782,6 +782,10 @@ class Model(TransformerMixin, BaseEstimator):
         self.transfer_function = None
         self.impulse_response = None
 
+        self.resolution_per_energy = None
+        self.resolution_per_energy_unc = None
+        self.resolution_energy_bins = None
+
         self.extra_options = ["mu_xgm", "sigma_xgm",
                               "wiener_filter_ft", "wiener_filter",
                               "wiener_energy_ft", "wiener_energy",
@@ -792,6 +796,8 @@ class Model(TransformerMixin, BaseEstimator):
                               "n_obs",
                               "pca_threshold",
                               "high_res_fwhm",
+                              "resolution_per_energy", "resolution_per_energy_unc"
+                              "resolution_energy_bins"
                               ]
 
     def n_pars(self) -> float:
@@ -1007,13 +1013,15 @@ class Model(TransformerMixin, BaseEstimator):
         n = np.sqrt((self.y_model['pca'].inverse_transform(y_pred + n) - y_hat)**2 + high_pca_unc**2)
         e = high_res_photon_energy[0,:] if len(high_res_photon_energy.shape) == 2 else high_res_photon_energy
 
+        # get average resolution
+        e_axis, h, H, result = get_resolution(y,
+                                              y_hat,
+                                              e
+                                              )
+
         de = e[1] - e[0]
         E = np.fft.fftfreq(len(e), de)
-        e_range = e[-1] - e[0]
-        e_axis = np.linspace(-0.5*e_range, 0.5*e_range, len(e))
 
-        # transfer function estimate, signal spectral power
-        h, H, S = deconv(y, y_hat)
         # noise spectral power
         N = np.mean(np.absolute(n)**2)
         H2 = np.absolute(H)**2
@@ -1021,17 +1029,15 @@ class Model(TransformerMixin, BaseEstimator):
         Hinv = (1.0/H)*nonzero
         # Wiener filter:
         G = Hinv * (H2 * S) / (H2 * S + N)
-
         Gdir = np.fft.fftshift(np.fft.ifft(G))
         self.wiener_filter = Gdir
         self.wiener_filter_ft = G
         self.wiener_energy = e_axis
         self.wiener_energy_ft = E
         self.transfer_function = H
-        h = np.fft.fftshift(np.fft.ifft(H))
         self.impulse_response = h
 
-        # get grating spec. resolution
+        # get grating spec. autocorr. width
         mean_y = np.mean(y, keepdims=True, axis=0)
         self.auto_corr_hr = np.mean(np.fft.fftshift(np.fft.ifft(np.absolute(np.fft.fft(y - mean_y))**2), axes=(-1,)), axis=0)
         self.auto_corr_hr = np.real(self.auto_corr_hr)
@@ -1041,7 +1047,7 @@ class Model(TransformerMixin, BaseEstimator):
         except:
             self.fwhm_hr = -1.0
 
-        # get virtual spectrometer resolution
+        # get virtual spectrometer autocorr. width
         mean_y_hat = np.mean(y_hat, keepdims=True, axis=0)
         self.auto_corr_virt = np.mean(np.fft.fftshift(np.fft.ifft(np.absolute(np.fft.fft(y_hat - mean_y_hat))**2), axes=(-1,)), axis=0)
         self.auto_corr_virt = np.real(self.auto_corr_virt)
@@ -1051,13 +1057,35 @@ class Model(TransformerMixin, BaseEstimator):
         except:
             self.fwhm_virt = -1.0
 
-        self.resolution = -1.0
-        if self.fwhm_hr > 0 and self.fwhm_virt > self.fwhm_hr:
-            self.resolution = np.sqrt(self.fwhm_virt**2 - self.fwhm_hr**2)
-        if self.resolution < 0:
-            print("Warning: Resolution calculation failed. The model can still be used, but this is probably a red flag!")
-        else:
-            print("Resolution:", self.resolution)
+        #self.resolution = -1.0
+        #if self.fwhm_hr > 0 and self.fwhm_virt > self.fwhm_hr:
+        #    self.resolution = np.sqrt(self.fwhm_virt**2 - self.fwhm_hr**2)
+        #if self.resolution < 0:
+        #    print("Warning: Resolution calculation failed. The model can still be used, but this is probably a red flag!")
+        #
+        self.resolution = result.best_values["sigma"]*2.355
+        print("Resolution:", self.resolution)
+
+        # estimate resolution using power spectra
+        e_min = e.min()
+        e_max = e.max()
+        e_probe = np.linspace(e_min, e_max, 5)
+        e_width = (e_max - e_min)/(len(e_probe[key])-1)
+        width = list()
+        width_unc = list()
+        for e_ in e_probe:
+            e_axis_centered, h, H, result = get_resolution(y,
+                                                           y_hat,
+                                                           e_axis,
+                                                           e_center=e_,
+                                                           e_width=e_width)
+            width += [result.best_values["sigma"]*2.355]
+            width_unc += [np.sqrt(result.covar[2,2])*2.355]
+        width = np.array(width)
+        width_unc = np.array(width_unc)
+        self.resolution_per_energy = width
+        self.resolution_per_energy_unc = width_unc
+        self.resolution_energy_bins = e_probe
 
         # this speeds things up considerably, when we do not care about that
         if not ood:
-- 
GitLab