diff --git a/pes_to_spec/model.py b/pes_to_spec/model.py index 78d944c2b323e7481c11a1d78bbdddd88c039f32..7cd34852939327276b9377d533c5edc0874408c1 100644 --- a/pes_to_spec/model.py +++ b/pes_to_spec/model.py @@ -829,29 +829,94 @@ class Model(TransformerMixin, BaseEstimator): # e: energy # true signal (as far as we can get -- it is smoothened, but this is the model target) d = high_res[inliers] + D = np.fft.fft(d) + y_pred, n = self.fit_model.predict(x_t[inliers], return_std=True) z = self.y_model['pca'].inverse_transform(y_pred) - n = np.fabs(self.y_model['pca'].inverse_transform(y_pred + n) - z) + + #n = np.sqrt((self.y_model['pca'].inverse_transform(y_pred + n) - z)**2 + high_pca_unc**2) e = high_res_photon_energy[0,:] if len(high_res_photon_energy.shape) == 2 else high_res_photon_energy - D = np.fft.fft(d) + Z = np.fft.fft(z) - V = np.fft.fft(n) + #V = np.fft.fft(np.mean(n, axis=0)) + de = e[1] - e[0] E = np.fft.fftfreq(len(e), de) e_axis = np.linspace(-0.5*(e[-1] - e[0]), 0.5*(e[-1] - e[0]), len(e)) - H = np.mean(Z/D, axis=0) - N = np.mean(np.absolute(V)**2, axis=0) - S = np.mean(np.absolute(D)**2, axis=0) - # Wiener filter: - G = np.conjugate(H) * S / (np.absolute(H)**2 * S + N) # generate a gaussian gaussian = np.exp(-0.5*(e_axis)**2/self.high_res_sigma**2) gaussian /= np.sum(gaussian, axis=0, keepdims=True) - G *= gaussian - self.wiener_filter = np.fft.ifft(G) + gaussian = np.clip(gaussian, a_min=1e-6, a_max=None) + gaussian_ft = np.fft.fft(gaussian) + + H = np.mean(Z/D, axis=0) + N = np.absolute(gaussian_ft)**2 + S = np.mean(np.absolute(D)**2, axis=0) + H2 = np.absolute(H)**2 + nonzero = np.absolute(H) > 0.2 + Hinv = (1.0/H)*nonzero + # Wiener filter: + G = Hinv * (H2 * S) / (H2 * S + N) + + #import matplotlib.pyplot as plt + #from matplotlib.gridspec import GridSpec + #fig = plt.figure(figsize=(40, 40)) + #gs = GridSpec(2, 2) + #ax = fig.add_subplot(gs[0, 0]) + #ax.plot(np.fft.fftshift(np.mean(np.absolute(Z), axis=0)), c='b', lw=3, label="Prediction") + #ax.plot(np.fft.fftshift(np.mean(np.absolute(D), axis=0)), c='r', lw=3, label="True") + #ax.legend() + #ax.set(title=f"", + # xlabel="Reciprocal energy [1/eV]", + # ylabel="Intensity [a.u.]", + # yscale='log', + # ) + #ax = fig.add_subplot(gs[0, 1]) + #ax.plot(np.fft.fftshift(np.absolute(H)**2), c='b', lw=3, label=r"$|H|^2$") + #ax.plot(np.fft.fftshift(np.absolute(Hinv)), c='k', lw=3, label=r"$|H^{-1}|$") + #ax.plot(np.fft.fftshift(N), c='g', lw=3, label=r"$N$") + #ax.plot(np.fft.fftshift(S), c='r', lw=3, label=r"$S$") + #ax.legend() + #ax.set(title=f"", + # xlabel="Reciprocal energy [1/eV]", + # ylabel="Intensity [a.u.]", + # yscale='log', + # ) + #ax = fig.add_subplot(gs[1, 0]) + #ax.plot(np.fft.fftshift(np.absolute(H)), c='b', lw=3, label="H") + #ax.plot(np.fft.fftshift(np.absolute(np.mean(Z, axis=0)/np.mean(D, axis=0))), c='r', lw=3, label="mean Z/mean D") + #ax.legend() + #ax.set(title=f"", + # xlabel="Reciprocal energy [1/eV]", + # ylabel="Intensity [a.u.]", + # yscale="log", + # ) + #ax = fig.add_subplot(gs[1, 1]) + #ax.plot(np.fft.fftshift(np.absolute(G)), c='b', lw=3, label="H") + #ax.plot(np.fft.fftshift(np.absolute(np.mean(Z, axis=0)/np.mean(D, axis=0))), c='r', lw=3, label="mean Z/mean D") + #ax.legend() + #ax.set(title=f"", + # xlabel="Reciprocal energy [1/eV]", + # ylabel="Intensity [a.u.]", + # yscale="log", + # ) + #fig.savefig("tmp.png") + #plt.close(fig) + + Hmod = np.real(np.absolute(H)) + Gdir = np.fft.fftshift(np.fft.ifft(G)) + self.wiener_filter = Gdir self.wiener_filter_ft = G self.wiener_energy = e_axis self.wiener_energy_ft = E + self.transfer_function = H + h = np.fft.fftshift(np.fft.ifft(H)) + hmod = np.real(np.absolute(h)) + self.impulse_response = h + energy_mu = np.sum(e_axis*hmod)/np.sum(hmod) + energy_var = np.sum(((e_axis - energy_mu)**2)*hmod)/np.sum(hmod) + self.resolution = np.sqrt(energy_var) + print("Resolution:", self.resolution) # get intensity effect intensity = np.sum(z, axis=1) @@ -945,10 +1010,19 @@ class Model(TransformerMixin, BaseEstimator): pca_unc = self.y_model['unc'].uncertainty() total_unc = np.sqrt(pca_unc**2 + unc**2) + M = self.wiener_filter.shape[0] + B = expected.shape[0] + assert expected.shape[1] == M + deconvolved = fftconvolve(expected, + np.broadcast_to(self.wiener_filter.reshape(1, -1), (B, M)), + mode="same", + axes=1) + return dict(expected=expected, unc=unc, pca=pca_unc, total_unc=total_unc, + deconvolved=deconvolved, Z_intensity=Z_intensity ) @@ -973,6 +1047,9 @@ class Model(TransformerMixin, BaseEstimator): wiener_filter=self.wiener_filter, wiener_energy=self.wiener_energy, wiener_energy_ft=self.wiener_energy_ft, + resolution=self.resolution, + transfer_function=self.transfer_function, + impulse_response=self.impulse_response, ) ), self.ood, @@ -1019,5 +1096,8 @@ class Model(TransformerMixin, BaseEstimator): obj.wiener_filter = extra["wiener_filter"] obj.wiener_energy_ft = extra["wiener_energy_ft"] obj.wiener_energy = extra["wiener_energy"] + obj.resolution = extra["resolution"] + obj.transfer_function = extra["transfer_function"] + obj.impulse_response = extra["impulse_response"] return obj diff --git a/pes_to_spec/test/offline_analysis.py b/pes_to_spec/test/offline_analysis.py index 62be08a5f19e0d654aeed5fa1c51312bee08cd9d..0abf2249411e6c2c59748a1691f59ca0e0e68185 100755 --- a/pes_to_spec/test/offline_analysis.py +++ b/pes_to_spec/test/offline_analysis.py @@ -98,9 +98,9 @@ def plot_result(filename: str, #ax.fill_between(spec_raw_pe, spec_pred["expected"] - unc_pca, spec_pred["expected"] + unc_pca, facecolor='magenta', alpha=0.6, label="68% unc. (syst., PCA)") #if spec_raw_int is not None: # ax.plot(spec_raw_pe, spec_raw_int, c='b', lw=1, ls='--', label="High-resolution measurement") - if wiener is not None: - deconvolved = fftconvolve(spec_pred["expected"], wiener, mode="same") - ax.plot(spec_raw_pe, deconvolved, c='g', ls='-.', lw=3, label="Wiener filter result") + #if wiener is not None: + # deconvolved = fftconvolve(spec_pred["expected"], wiener, mode="same") + ax.plot(spec_raw_pe, spec_pred["deconvolved"], c='g', ls='-.', lw=3, label="Wiener filter result") Y = np.amax(spec_smooth) ax.legend(frameon=False, borderaxespad=0, loc='upper left') ax.set(title=f"", #avg(stat unc) = {unc_stat}, avg(pca unc) = {unc_pca}", @@ -262,7 +262,7 @@ def main(): fig = plt.figure(figsize=(12, 8)) gs = GridSpec(1, 1) ax = fig.add_subplot(gs[0, 0]) - plt.plot(model.wiener_energy, np.fft.fftshift(np.absolute(model.wiener_filter))) + plt.plot(model.wiener_energy, np.absolute(model.wiener_filter)) ax.set(title=f"", xlabel=r"Energy [eV]", ylabel="Filter value [a.u.]",