diff --git a/pes_to_spec/model.py b/pes_to_spec/model.py index f4b677f1edb2b56ae9edec21715638b15ced3ffa..b95641de1a0445226822663d9c5dc1e2b4bf95cc 100644 --- a/pes_to_spec/model.py +++ b/pes_to_spec/model.py @@ -8,6 +8,7 @@ import dask.array as da import numpy as np import scipy +import lmfit from scipy.signal import fftconvolve from sklearn.covariance import EllipticEnvelope from sklearn.decomposition import IncrementalPCA, PCA @@ -76,6 +77,42 @@ def deconv(y: np.ndarray, yhat: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np. H = Syh/Syy return np.fft.fftshift(np.fft.ifft(H)), H, Syy +def fit_gaussian(x: np.ndarray, y: np.ndarray) -> lmfit.ModelResult: + """Fit Gaussian and return the fit result.""" + def gaussian(x, amplitude, centre, sigma): + return amplitude * np.exp(-0.5 * (x - centre)**2 / (sigma**2)) + gmodel = lmfit.Model(gaussian) + result = gmodel.fit(y, x=x, centre=0.0, amplitude=np.amax(y), sigma=1.0) + return result + +def get_resolution(y: np.ndarray, y_hat: np.ndarray, e: np.ndarray, + e_center: Optional[float]=None, e_width: Optional[float]=None) -> Tuple[np.ndarray, np.ndarray, lmfit.ModelResult]: + """ + Given the true y and the predicted y, together with the energy axis e, + estimate the impulse response of the system and return the Gaussian fit to it. + If e_center and e_width are given, multiply the spectra by a box function with given parameters before the resolution estimate. + + Args: + y: The true spectrum. Shape (N, K). + y_hat: The predicted spectrum. Shape (N, K). + e: The energy axis. Shape (K,). + e_center: If given the energy value, for which to probe the resolution. + e_width: The width of the energy neighbourhood to probe if e_center is given. + + Returns: The centered energy axis, the impulse response and the fit result. + """ + e_range = e[-1] - e[0] + e_axis = np.linspace(-0.5*e_range, 0.5*e_range, len(e)) + y_sel = y + y_hat_sel = y_hat + if e_center is not None and e_width is not None: + #f = ((e > e_center - e_width*0.5) & (e < e_center + e_width*0.5)).astype(np.float32) + f = np.exp(-0.5 * (e - e_center)**2 / (e_width**2)) + f /= np.sum(f) + y_sel = y_sel*f + y_hat_sel = y_hat_sel*f + h, H, S = deconv(y_sel, y_hat_sel) + return e_axis, h, fit_gaussian(e_axis, np.absolute(h)) class PromptNotFoundError(Exception): """