Skip to content
Snippets Groups Projects
Commit 69633e3e authored by Danilo Ferreira de Lima's avatar Danilo Ferreira de Lima
Browse files

Calculate SNR. Check if resolution fit fails and add nans. Regularize H...

Calculate SNR. Check if resolution fit fails and add nans. Regularize H calculation. Pay attention to deactivated channels when estimating pedestal.
parent 17fbd15f
No related branches found
No related tags found
1 merge request!21Calculate SNR and some bug fixes
This commit is part of merge request !21. Comments created here will be created in the context of that merge request.
...@@ -60,22 +60,61 @@ def fwhm(x: np.ndarray, y: np.ndarray) -> float: ...@@ -60,22 +60,61 @@ def fwhm(x: np.ndarray, y: np.ndarray) -> float:
right_idx = np.where(d < 0)[-1][-1] right_idx = np.where(d < 0)[-1][-1]
return x[right_idx] - x[left_idx] return x[right_idx] - x[left_idx]
def deconv(y: np.ndarray, yhat: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: def deconv(y: np.ndarray, yhat: np.ndarray) -> Dict[str, Any]
"""Given the grating spectrometer data and the virtual spectrometer data, """Given the grating spectrometer data and the virtual spectrometer data,
calculate the deconvolution between them. calculate the deconvolution between them.
""" """
# calculate means
m_yhat = np.mean(yhat, keepdims=True, axis=0)
m_y = np.mean(y, keepdims=True, axis=0)
# subtract the mean spectra to remove the FEL bandwidth # subtract the mean spectra to remove the FEL bandwidth
yhat_s = yhat - np.mean(yhat, keepdims=True, axis=0) y_s = y - m_y
y_s = y - np.mean(y, keepdims=True, axis=0) yhat_s = yhat - m_yhat
# calculate normalization factor to set mean sum y^2 = 1
n_bins = float(y.shape[1])
A = np.mean(np.sum(y_s**2, axis=1), axis=0)
Ahat = np.mean(np.sum(yhat_s**2, axis=1), axis=0)
# sets mean sum y^2 = 1
y_s = y_s/np.sqrt(A)
yhat_s = yhat_s/np.sqrt(Ahat)
# Fourier transforms # Fourier transforms
Yhat = np.fft.fft(yhat_s) Yhat = np.fft.fft(yhat_s)
Y = np.fft.fft(y_s) Y = np.fft.fft(y_s)
# spectral power of the assumed "true" signal (the grating spectrometer data) # spectral power of the assumed "true" signal (the grating spectrometer data)
Syy = np.mean(np.absolute(Y)**2, axis=0) Syy = np.mean(np.absolute(Y)**2, axis=0)
Syh = np.mean(Y*np.conj(Yhat), axis=0) Syh = np.mean(Y*np.conj(Yhat), axis=0)
beta = 1e-5
# regularization
Syh[np.absolute(Syh) < beta*np.amax(np.absolute(Syh))] = 0.0
# approximate transfer function as the ratio of power spectrum densities # approximate transfer function as the ratio of power spectrum densities
H = Syh/Syy H = Syh/Syy
return np.fft.fftshift(np.fft.ifft(H)), H, Syy
# calculate snr
H2 = np.absolute(H)**2
# inputs are normalized, so the normalization of h tells us how much signal there is
# Yhat = H Y + N
# mean sum |Yhat|^2 = mean sum |HY|^2 + mean sum H*Y*N + mean sum HYN* + mean sum |N|^2
# mean sum HYN* = sum H mean YN* = 0 because noise is uncorrelated
# mean sum |Yhat|^2 = sum |H|^2 mean |Y|^2 + mean sum |N|^2
# sum mean |N|^2 = n_bins sigma_n^2, if the noise is white
# Yhat and Y are normalized, so mean sum |Y|^2 = mean sum |Yhat|^2 = n_bins
# n_bins = sum |H|^2 mean |Y|^2 + n_bins sigma_n^2
# n_bins = sum |H|^2 Syy + n_bins sigma_n^2
# sigma_n^2 = 1 - sum |H|^2 Syy/n_bins
sigma_n = np.real(np.sqrt(1 - np.sum(H2*Syy)/n_bins))
sigma_s = np.real(np.sqrt(np.sum(H2*Syy)/n_bins))
snr = sigma_s/sigma_n
return dict(h=np.fft.fftshift(np.fft.ifft(H)),
H=H,
H2=H2,
Syy=Syy,
Syh=Syh,
snr=snr)
def fit_gaussian(x: np.ndarray, y: np.ndarray) -> lmfit.ModelResult: def fit_gaussian(x: np.ndarray, y: np.ndarray) -> lmfit.ModelResult:
"""Fit Gaussian and return the fit result.""" """Fit Gaussian and return the fit result."""
...@@ -86,7 +125,7 @@ def fit_gaussian(x: np.ndarray, y: np.ndarray) -> lmfit.ModelResult: ...@@ -86,7 +125,7 @@ def fit_gaussian(x: np.ndarray, y: np.ndarray) -> lmfit.ModelResult:
return result return result
def get_resolution(y: np.ndarray, y_hat: np.ndarray, e: np.ndarray, def get_resolution(y: np.ndarray, y_hat: np.ndarray, e: np.ndarray,
e_center: Optional[float]=None, e_width: Optional[float]=None) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, lmfit.ModelResult]: e_center: Optional[float]=None, e_width: Optional[float]=None) -> Dict[str, Any]:
""" """
Given the true y and the predicted y, together with the energy axis e, Given the true y and the predicted y, together with the energy axis e,
estimate the impulse response of the system and return the Gaussian fit to it. estimate the impulse response of the system and return the Gaussian fit to it.
...@@ -99,7 +138,7 @@ def get_resolution(y: np.ndarray, y_hat: np.ndarray, e: np.ndarray, ...@@ -99,7 +138,7 @@ def get_resolution(y: np.ndarray, y_hat: np.ndarray, e: np.ndarray,
e_center: If given the energy value, for which to probe the resolution. e_center: If given the energy value, for which to probe the resolution.
e_width: The width of the energy neighbourhood to probe if e_center is given. e_width: The width of the energy neighbourhood to probe if e_center is given.
Returns: The centered energy axis, the impulse response, the transfer function, the grating spectral density and the fit result. Returns: A dictionary with the centered energy axis, the impulse response, the transfer function, the grating spectral density and the fit result.
""" """
e_range = e[-1] - e[0] e_range = e[-1] - e[0]
e_axis = np.linspace(-0.5*e_range, 0.5*e_range, len(e)) e_axis = np.linspace(-0.5*e_range, 0.5*e_range, len(e))
...@@ -111,8 +150,11 @@ def get_resolution(y: np.ndarray, y_hat: np.ndarray, e: np.ndarray, ...@@ -111,8 +150,11 @@ def get_resolution(y: np.ndarray, y_hat: np.ndarray, e: np.ndarray,
f /= np.sum(f) f /= np.sum(f)
y_sel = y_sel*f y_sel = y_sel*f
y_hat_sel = y_hat_sel*f y_hat_sel = y_hat_sel*f
h, H, S = deconv(y_sel, y_hat_sel) results = deconv(y_sel, y_hat_sel)
return e_axis, h, H, S, fit_gaussian(e_axis, np.absolute(h)) results["e_axis"] = e_axis
results["fit"] = fit_gaussian(e_axis, np.absolute(h))
results["fit_success"] = results["fit"].covar_ is not None
return results
class PromptNotFoundError(Exception): class PromptNotFoundError(Exception):
""" """
...@@ -380,7 +422,8 @@ class SelectRelevantLowResolution(TransformerMixin, BaseEstimator): ...@@ -380,7 +422,8 @@ class SelectRelevantLowResolution(TransformerMixin, BaseEstimator):
Returns: The index. Returns: The index.
""" """
# reduce on channel and on train ID # reduce on channel and on train ID
sum_low_res = np.mean(sum([-(v - self.pedestal[ch]) for ch, v in X.items()]), axis=0) sum_low_res = np.mean(sum([-(X[ch] - self.pedestal[ch])
for ch in self.channels]), axis=0)
# convert to Numpy if it is a Dask array # convert to Numpy if it is a Dask array
if isinstance(sum_low_res, da.Array): if isinstance(sum_low_res, da.Array):
sum_low_res = sum_low_res.compute() sum_low_res = sum_low_res.compute()
...@@ -797,7 +840,8 @@ class Model(TransformerMixin, BaseEstimator): ...@@ -797,7 +840,8 @@ class Model(TransformerMixin, BaseEstimator):
"pca_threshold", "pca_threshold",
"high_res_fwhm", "high_res_fwhm",
"resolution_per_energy", "resolution_per_energy_unc", "resolution_per_energy", "resolution_per_energy_unc",
"resolution_energy_bins" "resolution_energy_bins",
"snr"
] ]
def n_pars(self) -> float: def n_pars(self) -> float:
...@@ -1014,17 +1058,21 @@ class Model(TransformerMixin, BaseEstimator): ...@@ -1014,17 +1058,21 @@ class Model(TransformerMixin, BaseEstimator):
e = high_res_photon_energy[0,:] if len(high_res_photon_energy.shape) == 2 else high_res_photon_energy e = high_res_photon_energy[0,:] if len(high_res_photon_energy.shape) == 2 else high_res_photon_energy
# get average resolution # get average resolution
e_axis, h, H, S, result = get_resolution(y, result = get_resolution(y, y_hat, e)
y_hat,
e self.resolution = np.exp(result["fit"].best_values["log_sigma"])*2.355 if result["fit_success"] else -1.0
) self.snr = result["snr"]
print("Resolution = {self.resolution:.2f} eV, S/R = {self.snr:.2f}")
S = result["S"]
H = result["H"]
H2 = result["H2"]
de = e[1] - e[0] de = e[1] - e[0]
E = np.fft.fftfreq(len(e), de) E = np.fft.fftfreq(len(e), de)
# noise spectral power # noise spectral power
N = np.mean(np.absolute(n)**2) N = np.mean(np.absolute(n)**2)
H2 = np.absolute(H)**2
nonzero = np.absolute(H) > 0.2 nonzero = np.absolute(H) > 0.2
Hinv = (1.0/H)*nonzero Hinv = (1.0/H)*nonzero
# Wiener filter: # Wiener filter:
...@@ -1057,15 +1105,6 @@ class Model(TransformerMixin, BaseEstimator): ...@@ -1057,15 +1105,6 @@ class Model(TransformerMixin, BaseEstimator):
except: except:
self.fwhm_virt = -1.0 self.fwhm_virt = -1.0
#self.resolution = -1.0
#if self.fwhm_hr > 0 and self.fwhm_virt > self.fwhm_hr:
# self.resolution = np.sqrt(self.fwhm_virt**2 - self.fwhm_hr**2)
#if self.resolution < 0:
# print("Warning: Resolution calculation failed. The model can still be used, but this is probably a red flag!")
#
self.resolution = np.exp(result.best_values["log_sigma"])*2.355
print("Resolution:", self.resolution)
# estimate resolution using power spectra # estimate resolution using power spectra
e_min = e.min() e_min = e.min()
e_max = e.max() e_max = e.max()
...@@ -1074,13 +1113,14 @@ class Model(TransformerMixin, BaseEstimator): ...@@ -1074,13 +1113,14 @@ class Model(TransformerMixin, BaseEstimator):
width = list() width = list()
width_unc = list() width_unc = list()
for e_ in e_probe: for e_ in e_probe:
_, _, _, _, result = get_resolution(y, result = get_resolution(y, y_hat, e,
y_hat, e_center=e_, e_width=e_width)
e, if result["fit_success"]:
e_center=e_, width += [np.exp(result["fit"].best_values["log_sigma"])*2.355]
e_width=e_width) width_unc += [np.sqrt(result["fit"].covar[2,2])*2.355]
width += [np.exp(result.best_values["log_sigma"])*2.355] else:
width_unc += [np.sqrt(result.covar[2,2])*2.355] width += [np.nan]
width_unc += [np.nan]
width = np.array(width) width = np.array(width)
width_unc = np.array(width_unc) width_unc = np.array(width_unc)
self.resolution_per_energy = width self.resolution_per_energy = width
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment