From 69633e3e26f382301ad4ec4d24c7730426357016 Mon Sep 17 00:00:00 2001
From: Danilo Ferreira de Lima <danilo.enoque.ferreira.de.lima@xfel.de>
Date: Tue, 7 Nov 2023 15:23:54 +0100
Subject: [PATCH] Calculate SNR. Check if resolution fit fails and add nans.
 Regularize H calculation. Pay attention to deactivated channels when
 estimating pedestal.

---
 pes_to_spec/model.py | 102 ++++++++++++++++++++++++++++++-------------
 1 file changed, 71 insertions(+), 31 deletions(-)

diff --git a/pes_to_spec/model.py b/pes_to_spec/model.py
index 422a602..b76700b 100644
--- a/pes_to_spec/model.py
+++ b/pes_to_spec/model.py
@@ -60,22 +60,61 @@ def fwhm(x: np.ndarray, y: np.ndarray) -> float:
     right_idx = np.where(d < 0)[-1][-1]
     return x[right_idx] - x[left_idx]
 
-def deconv(y: np.ndarray, yhat: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+def deconv(y: np.ndarray, yhat: np.ndarray) -> Dict[str, Any]:
     """Given the grating spectrometer data and the virtual spectrometer data,
     calculate the deconvolution between them.
     """
+    # calculate the mean spectra (average over the first axis)
+    m_yhat = np.mean(yhat, keepdims=True, axis=0)
+    m_y = np.mean(y, keepdims=True, axis=0)
+
     # subtract the mean spectra to remove the FEL bandwidth
-    yhat_s = yhat - np.mean(yhat, keepdims=True, axis=0)
-    y_s = y  - np.mean(y, keepdims=True, axis=0)
+    y_s = y - m_y
+    yhat_s = yhat - m_yhat
+
+    # calculate normalization factor to set mean sum y^2 = 1
+    n_bins = float(y.shape[1])
+    A = np.mean(np.sum(y_s**2, axis=1), axis=0)
+    Ahat = np.mean(np.sum(yhat_s**2, axis=1), axis=0)
+
+    # sets mean sum y^2 = 1 (and likewise for yhat)
+    y_s = y_s/np.sqrt(A)
+    yhat_s = yhat_s/np.sqrt(Ahat)
+
     # Fourier transforms
     Yhat = np.fft.fft(yhat_s)
     Y = np.fft.fft(y_s)
     # spectral power of the assumed "true" signal (the grating spectrometer data)
     Syy = np.mean(np.absolute(Y)**2, axis=0)
     Syh = np.mean(Y*np.conj(Yhat), axis=0)
+    beta = 1e-5
+    # regularization: zero the cross-spectrum bins whose magnitude is below beta times the largest one
+    Syh[np.absolute(Syh) < beta*np.amax(np.absolute(Syh))] = 0.0
     # approximate transfer function as the ratio of power spectrum densities
     H = Syh/Syy
-    return np.fft.fftshift(np.fft.ifft(H)), H, Syy
+
+    # calculate snr
+    H2 = np.absolute(H)**2
+    # inputs are normalized, so the normalization of h tells us how much signal there is
+    # Yhat = H Y + N
+    # mean sum |Yhat|^2 = mean sum |HY|^2 + mean sum H*Y*N + mean sum HYN* + mean sum |N|^2
+    # mean sum HYN* = sum H mean YN* = 0 because noise is uncorrelated
+    # mean sum |Yhat|^2 = sum |H|^2 mean |Y|^2 + mean sum |N|^2
+    # sum mean |N|^2 = n_bins sigma_n^2, if the noise is white
+    # Yhat and Y are normalized, so mean sum |Y|^2 = mean sum |Yhat|^2 = n_bins
+    # n_bins = sum |H|^2 mean |Y|^2 + n_bins sigma_n^2
+    # n_bins = sum |H|^2 Syy + n_bins sigma_n^2
+    # sigma_n^2 = 1 - sum |H|^2 Syy/n_bins
+    sigma_n = np.real(np.sqrt(1 - np.sum(H2*Syy)/n_bins))
+    sigma_s = np.real(np.sqrt(np.sum(H2*Syy)/n_bins))
+    snr = sigma_s/sigma_n
+
+    return dict(h=np.fft.fftshift(np.fft.ifft(H)),  # impulse response, centered with fftshift
+                H=H,  # transfer function Syh/Syy
+                H2=H2,  # squared magnitude of H
+                Syy=Syy,  # spectral power of the grating data
+                Syh=Syh,  # regularized cross-spectrum
+                snr=snr)  # sigma_s/sigma_n estimate
 
 def fit_gaussian(x: np.ndarray, y: np.ndarray) -> lmfit.ModelResult:
     """Fit Gaussian and return the fit result."""
@@ -86,7 +125,7 @@ def fit_gaussian(x: np.ndarray, y: np.ndarray) -> lmfit.ModelResult:
     return result
 
 def get_resolution(y: np.ndarray, y_hat: np.ndarray, e: np.ndarray,
-                   e_center: Optional[float]=None, e_width: Optional[float]=None) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, lmfit.ModelResult]:
+        e_center: Optional[float]=None, e_width: Optional[float]=None) -> Dict[str, Any]:
     """
     Given the true y and the predicted y, together with the energy axis e,
     estimate the impulse response of the system and return the Gaussian fit to it.
@@ -99,7 +138,7 @@ def get_resolution(y: np.ndarray, y_hat: np.ndarray, e: np.ndarray,
       e_center: If given the energy value, for which to probe the resolution.
       e_width: The width of the energy neighbourhood to probe if e_center is given.
 
-    Returns: The centered energy axis, the impulse response, the transfer function, the grating spectral density and the fit result.
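+    Returns: A dictionary with the keys "e_axis" (centered energy axis), "h" (impulse response), "H" (transfer function), "H2" (its squared magnitude), "Syy" (grating spectral density), "Syh" (cross-spectral density), "snr" (signal-to-noise ratio estimate), "fit" (Gaussian fit result) and "fit_success" (whether the fit produced a covariance matrix).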
     """
     e_range = e[-1] - e[0]
     e_axis = np.linspace(-0.5*e_range, 0.5*e_range, len(e))
@@ -111,8 +150,11 @@ def get_resolution(y: np.ndarray, y_hat: np.ndarray, e: np.ndarray,
         f /= np.sum(f)
         y_sel = y_sel*f
         y_hat_sel = y_hat_sel*f
-    h, H, S = deconv(y_sel, y_hat_sel)
-    return e_axis, h, H, S, fit_gaussian(e_axis, np.absolute(h))
+    results = deconv(y_sel, y_hat_sel)
+    results["e_axis"] = e_axis
+    results["fit"] = fit_gaussian(e_axis, np.absolute(results["h"]))  # "h" is the impulse response returned by deconv
+    results["fit_success"] = results["fit"].covar is not None  # lmfit leaves covar as None when it cannot estimate the covariance
+    return results
 
 class PromptNotFoundError(Exception):
     """
@@ -380,7 +422,8 @@ class SelectRelevantLowResolution(TransformerMixin, BaseEstimator):
         Returns: The index.
         """
         # reduce on channel and on train ID
-        sum_low_res = np.mean(sum([-(v - self.pedestal[ch]) for ch, v in X.items()]), axis=0)
+        sum_low_res = np.mean(sum([-(X[ch] - self.pedestal[ch])
+                                   for ch in self.channels]), axis=0)
         # convert to Numpy if it is a Dask array
         if isinstance(sum_low_res, da.Array):
             sum_low_res = sum_low_res.compute()
@@ -797,7 +840,8 @@ class Model(TransformerMixin, BaseEstimator):
                               "pca_threshold",
                               "high_res_fwhm",
                               "resolution_per_energy", "resolution_per_energy_unc",
-                              "resolution_energy_bins"
+                              "resolution_energy_bins",
+                              "snr"
                               ]
 
     def n_pars(self) -> float:
@@ -1014,17 +1058,21 @@ class Model(TransformerMixin, BaseEstimator):
         e = high_res_photon_energy[0,:] if len(high_res_photon_energy.shape) == 2 else high_res_photon_energy
 
         # get average resolution
-        e_axis, h, H, S, result = get_resolution(y,
-                                                 y_hat,
-                                                 e
-                                                 )
+        result = get_resolution(y, y_hat, e)
+
+        self.resolution = np.exp(result["fit"].best_values["log_sigma"])*2.355 if result["fit_success"] else -1.0  # -1.0 marks a failed fit
+        self.snr = result["snr"]
+        print(f"Resolution = {self.resolution:.2f} eV, S/R = {self.snr:.2f}")
+
+        S = result["Syy"]  # grating spectral power; deconv stores it under "Syy"
+        H = result["H"]
+        H2 = result["H2"]
 
         de = e[1] - e[0]
         E = np.fft.fftfreq(len(e), de)
 
         # noise spectral power
         N = np.mean(np.absolute(n)**2)
-        H2 = np.absolute(H)**2
         nonzero = np.absolute(H) > 0.2
         Hinv = (1.0/H)*nonzero
         # Wiener filter:
@@ -1057,15 +1105,6 @@ class Model(TransformerMixin, BaseEstimator):
         except:
             self.fwhm_virt = -1.0
 
-        #self.resolution = -1.0
-        #if self.fwhm_hr > 0 and self.fwhm_virt > self.fwhm_hr:
-        #    self.resolution = np.sqrt(self.fwhm_virt**2 - self.fwhm_hr**2)
-        #if self.resolution < 0:
-        #    print("Warning: Resolution calculation failed. The model can still be used, but this is probably a red flag!")
-        #
-        self.resolution = np.exp(result.best_values["log_sigma"])*2.355
-        print("Resolution:", self.resolution)
-
         # estimate resolution using power spectra
         e_min = e.min()
         e_max = e.max()
@@ -1074,13 +1113,14 @@ class Model(TransformerMixin, BaseEstimator):
         width = list()
         width_unc = list()
         for e_ in e_probe:
-            _, _, _, _, result = get_resolution(y,
-                                                y_hat,
-                                                e,
-                                                e_center=e_,
-                                                e_width=e_width)
-            width += [np.exp(result.best_values["log_sigma"])*2.355]
-            width_unc += [np.sqrt(result.covar[2,2])*2.355]
+            result = get_resolution(y, y_hat, e,
+                                    e_center=e_, e_width=e_width)
+            if result["fit_success"]:
+                width += [np.exp(result["fit"].best_values["log_sigma"])*2.355]
+                width_unc += [np.sqrt(result["fit"].covar[2,2])*2.355]
+            else:
+                width += [np.nan]
+                width_unc += [np.nan]
         width = np.array(width)
         width_unc = np.array(width_unc)
         self.resolution_per_energy = width
-- 
GitLab