Skip to content
Snippets Groups Projects

Calculate SNR and some bug fixes

Merged Danilo Enoque Ferreira de Lima requested to merge snr into main
1 file
+ 71
31
Compare changes
  • Side-by-side
  • Inline
+ 71
31
@@ -60,22 +60,61 @@ def fwhm(x: np.ndarray, y: np.ndarray) -> float:
right_idx = np.where(d < 0)[-1][-1]
return x[right_idx] - x[left_idx]
def deconv(y: np.ndarray, yhat: np.ndarray) -> Dict[str, Any]:
    """Given the grating spectrometer data and the virtual spectrometer data,
    calculate the deconvolution between them.

    Both inputs are treated as 2D arrays: axis 0 is averaged over and axis 1
    holds the spectral bins that are Fourier transformed.

    Args:
      y: The grating spectrometer data (taken as the "true" signal).
      yhat: The virtual spectrometer data, with the same shape as `y`.

    Returns: A dictionary with the keys
      `h`: the impulse response (inverse FFT of `H`, centered with fftshift),
      `H`: the estimated transfer function `Syh/Syy`,
      `H2`: `|H|^2`,
      `Syy`: the mean spectral power of the normalized `y`,
      `Syh`: the regularized mean cross-spectrum between `y` and `yhat`,
      `snr`: the estimated signal-to-noise amplitude ratio.
    """
    # subtract the mean spectra to remove the FEL bandwidth
    m_y = np.mean(y, keepdims=True, axis=0)
    m_yhat = np.mean(yhat, keepdims=True, axis=0)
    y_s = y - m_y
    yhat_s = yhat - m_yhat

    # calculate normalization factor to set mean sum y^2 = 1
    n_bins = float(y.shape[1])
    A = np.mean(np.sum(y_s**2, axis=1), axis=0)
    Ahat = np.mean(np.sum(yhat_s**2, axis=1), axis=0)
    y_s = y_s/np.sqrt(A)
    yhat_s = yhat_s/np.sqrt(Ahat)

    # Fourier transforms (along the last axis)
    Yhat = np.fft.fft(yhat_s)
    Y = np.fft.fft(y_s)

    # spectral power of the assumed "true" signal (the grating spectrometer data)
    Syy = np.mean(np.absolute(Y)**2, axis=0)
    Syh = np.mean(Y*np.conj(Yhat), axis=0)

    # regularization: drop cross-spectrum entries that are negligible
    # relative to the largest one
    beta = 1e-5
    Syh[np.absolute(Syh) < beta*np.amax(np.absolute(Syh))] = 0.0

    # approximate transfer function as the ratio of power spectrum densities
    H = Syh/Syy
    H2 = np.absolute(H)**2

    # calculate snr
    # inputs are normalized, so the normalization of h tells us how much signal there is
    # Yhat = H Y + N
    # mean sum |Yhat|^2 = sum |H|^2 mean |Y|^2 + mean sum |N|^2 (noise is uncorrelated)
    # sum mean |N|^2 = n_bins sigma_n^2, if the noise is white
    # Yhat and Y are normalized, so mean sum |Y|^2 = mean sum |Yhat|^2 = n_bins
    # => sigma_n^2 = 1 - sum |H|^2 Syy/n_bins
    signal_power = np.sum(H2*Syy)/n_bins
    # Round-off can push 1 - signal_power marginally below zero when yhat
    # reproduces y almost perfectly; clip so the square root stays real.
    noise_power = np.maximum(1.0 - signal_power, 0.0)
    sigma_s = np.sqrt(signal_power)
    sigma_n = np.sqrt(noise_power)
    # zero noise yields an infinite S/N rather than a floating point warning
    with np.errstate(divide="ignore", invalid="ignore"):
        snr = sigma_s/sigma_n

    return dict(h=np.fft.fftshift(np.fft.ifft(H)),
                H=H,
                H2=H2,
                Syy=Syy,
                Syh=Syh,
                snr=snr)
def fit_gaussian(x: np.ndarray, y: np.ndarray) -> lmfit.ModelResult:
"""Fit Gaussian and return the fit result."""
@@ -86,7 +125,7 @@ def fit_gaussian(x: np.ndarray, y: np.ndarray) -> lmfit.ModelResult:
return result
def get_resolution(y: np.ndarray, y_hat: np.ndarray, e: np.ndarray,
e_center: Optional[float]=None, e_width: Optional[float]=None) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, lmfit.ModelResult]:
e_center: Optional[float]=None, e_width: Optional[float]=None) -> Dict[str, Any]:
"""
Given the true y and the predicted y, together with the energy axis e,
estimate the impulse response of the system and return the Gaussian fit to it.
@@ -99,7 +138,7 @@ def get_resolution(y: np.ndarray, y_hat: np.ndarray, e: np.ndarray,
e_center: If given the energy value, for which to probe the resolution.
e_width: The width of the energy neighbourhood to probe if e_center is given.
Returns: The centered energy axis, the impulse response, the transfer function, the grating spectral density and the fit result.
Returns: A dictionary with the centered energy axis, the impulse response, the transfer function, the grating spectral density and the fit result.
"""
e_range = e[-1] - e[0]
e_axis = np.linspace(-0.5*e_range, 0.5*e_range, len(e))
@@ -111,8 +150,11 @@ def get_resolution(y: np.ndarray, y_hat: np.ndarray, e: np.ndarray,
f /= np.sum(f)
y_sel = y_sel*f
y_hat_sel = y_hat_sel*f
h, H, S = deconv(y_sel, y_hat_sel)
return e_axis, h, H, S, fit_gaussian(e_axis, np.absolute(h))
results = deconv(y_sel, y_hat_sel)
results["e_axis"] = e_axis
results["fit"] = fit_gaussian(e_axis, np.absolute(h))
results["fit_success"] = results["fit"].covar_ is not None
return results
class PromptNotFoundError(Exception):
"""
@@ -380,7 +422,8 @@ class SelectRelevantLowResolution(TransformerMixin, BaseEstimator):
Returns: The index.
"""
# reduce on channel and on train ID
sum_low_res = np.mean(sum([-(v - self.pedestal[ch]) for ch, v in X.items()]), axis=0)
sum_low_res = np.mean(sum([-(X[ch] - self.pedestal[ch])
for ch in self.channels]), axis=0)
# convert to Numpy if it is a Dask array
if isinstance(sum_low_res, da.Array):
sum_low_res = sum_low_res.compute()
@@ -797,7 +840,8 @@ class Model(TransformerMixin, BaseEstimator):
"pca_threshold",
"high_res_fwhm",
"resolution_per_energy", "resolution_per_energy_unc",
"resolution_energy_bins"
"resolution_energy_bins",
"snr"
]
def n_pars(self) -> float:
@@ -1014,17 +1058,21 @@ class Model(TransformerMixin, BaseEstimator):
e = high_res_photon_energy[0,:] if len(high_res_photon_energy.shape) == 2 else high_res_photon_energy
# get average resolution
e_axis, h, H, S, result = get_resolution(y,
y_hat,
e
)
result = get_resolution(y, y_hat, e)
# -1.0 marks a failed Gaussian fit
self.resolution = np.exp(result["fit"].best_values["log_sigma"])*2.355 if result["fit_success"] else -1.0
self.snr = result["snr"]
print(f"Resolution = {self.resolution:.2f} eV, S/R = {self.snr:.2f}")
# the grating spectral density is returned under the "Syy" key
S = result["Syy"]
H = result["H"]
H2 = result["H2"]
de = e[1] - e[0]
E = np.fft.fftfreq(len(e), de)
# noise spectral power
N = np.mean(np.absolute(n)**2)
H2 = np.absolute(H)**2
nonzero = np.absolute(H) > 0.2
Hinv = (1.0/H)*nonzero
# Wiener filter:
@@ -1057,15 +1105,6 @@ class Model(TransformerMixin, BaseEstimator):
except:
self.fwhm_virt = -1.0
#self.resolution = -1.0
#if self.fwhm_hr > 0 and self.fwhm_virt > self.fwhm_hr:
# self.resolution = np.sqrt(self.fwhm_virt**2 - self.fwhm_hr**2)
#if self.resolution < 0:
# print("Warning: Resolution calculation failed. The model can still be used, but this is probably a red flag!")
#
self.resolution = np.exp(result.best_values["log_sigma"])*2.355
print("Resolution:", self.resolution)
# estimate resolution using power spectra
e_min = e.min()
e_max = e.max()
@@ -1074,13 +1113,14 @@ class Model(TransformerMixin, BaseEstimator):
width = list()
width_unc = list()
for e_ in e_probe:
_, _, _, _, result = get_resolution(y,
y_hat,
e,
e_center=e_,
e_width=e_width)
width += [np.exp(result.best_values["log_sigma"])*2.355]
width_unc += [np.sqrt(result.covar[2,2])*2.355]
result = get_resolution(y, y_hat, e,
e_center=e_, e_width=e_width)
if result["fit_success"]:
width += [np.exp(result["fit"].best_values["log_sigma"])*2.355]
width_unc += [np.sqrt(result["fit"].covar[2,2])*2.355]
else:
width += [np.nan]
width_unc += [np.nan]
width = np.array(width)
width_unc = np.array(width_unc)
self.resolution_per_energy = width
Loading