diff --git a/README.md b/README.md
index 123f9b3194d25b83b1dabb760e8fef8f87182d4f..bc8544882d9dc1db6687820ee67b63c46ef30961 100644
--- a/README.md
+++ b/README.md
@@ -32,12 +32,11 @@
 model.fit(low_resolution_raw_data, high_resolution_photon_energy)
 
 # save it for later usage:
-model.save("model.h5")
+model.save("model.joblib")
 
 # when performing inference:
 # load a model:
-model = Model()
-model.load("model.h5")
+model = Model.load("model.joblib")
 
 # and use it to map a low-resolution spectrum to a high-resolution one
 # as before, the low_resolution_raw_data refers to a dictionary mapping the channel name
diff --git a/pes_to_spec/model.py b/pes_to_spec/model.py
index 429a250c003a3ce1eca5f921d2985d0e651f936e..d33806a6986f01cd7d3a1bd3cd4f6d601d98ff74 100644
--- a/pes_to_spec/model.py
+++ b/pes_to_spec/model.py
@@ -6,11 +6,16 @@
 from scipy.signal import fftconvolve
 from scipy.signal import find_peaks_cwt
 from scipy.optimize import fmin_l_bfgs_b
 from sklearn.decomposition import PCA, IncrementalPCA
-from sklearn.model_selection import train_test_split
+from sklearn.pipeline import Pipeline
 from sklearn.base import TransformerMixin, BaseEstimator
+from sklearn.base import RegressorMixin
+from sklearn.compose import TransformedTargetRegressor
 from itertools import product
+from sklearn.model_selection import train_test_split
 from time import time_ns
+import joblib
+
 import matplotlib.pyplot as plt
 
 from typing import Union, Any, Dict, List, Optional
@@ -107,89 +112,155 @@ class PromptNotFoundError(Exception):
     def __str__(self) -> str:
         return "No prompt peak has been detected."
 
-class Model(TransformerMixin, BaseEstimator):
+class HighResolutionSmoother(TransformerMixin, BaseEstimator):
     """
-    Object representing a previous fit of the model to be used to predict high-resolution
-    spectrum from a low-resolution one.
+    Smoothens out the high resolution data.
+
+    Args:
+      high_res_sigma: Energy resolution in eV.
+    """
+    def __init__(self,
+                 high_res_sigma: float=0.2
+                ):
+        self.high_res_sigma = high_res_sigma
+        self.energy = None
+
+    def fit(self, X, y=None, **fit_params) -> TransformerMixin:
+        """
+        Fit records the energy axis.
+
+        Args:
+          X: Irrelevant.
+          y: Irrelevant.
+          fit_params: Contains the energy axis in the key "energy" with shape (any, features).
+
+        Returns: The object itself.
+        """
+        self.energy = fit_params["energy"]
+        if len(self.energy.shape) == 2:
+            self.energy = self.energy[0,:]
+        return self
+
+    def transform(self, X: np.ndarray) -> np.ndarray:
+        """
+        Apply smoothing in X using the energy axis.
+
+        Args:
+          X: Input to smoothen with shape (train id, features).
+
+        Returns: Smoothened out spectrum.
+        """
+        # broadcast the energy axis recorded during fit
+        # so that it matches the shape of the input data
+        energy = np.broadcast_to(self.energy, X.shape)
+
+        # Apply smoothing
+        n_features = X.shape[1]
+        # get the centre value of the energy axis
+        mu = energy[:, n_features//2, np.newaxis]
+        # generate a gaussian
+        gaussian = np.exp(-0.5*(energy - mu)**2/self.high_res_sigma**2)
+        gaussian /= np.sum(gaussian, axis=1, keepdims=True)
+        # apply it to the data
+        high_res_gc = fftconvolve(X, gaussian, mode="same", axes=1)
+        return high_res_gc
+
+    def inverse_transform(self, X: np.ndarray) -> np.ndarray:
+        """
+        Identity map, since the smoothing cannot be inverted.
+        Needed so that Pipeline.inverse_transform can pass through this step.
+
+        Args:
+          X: The input.
+        """
+        return X
+
+
+class UncertaintyHolder(TransformerMixin, BaseEstimator):
+    """
+    Keep track of uncertainties.
+
+    """
+    def __init__(self):
+        self.unc: np.ndarray = np.zeros((1, 0), dtype=float)
+
+    def set_uncertainty(self, unc: np.ndarray):
+        """
+        Set the uncertainty.
+
+        Args:
+          unc: The uncertainty.
+        """
+        self.unc = np.copy(unc)
+
+    def fit(self, X, y=None) -> TransformerMixin:
+        """
+        Does nothing.
+
+        Args:
+          X: Irrelevant.
+          y: Irrelevant.
+
+        Returns: Itself.
+        """
+        return self
+
+    def transform(self, X: np.ndarray) -> np.ndarray:
+        """
+        Identity map.
+
+        Args:
+          X: The input.
+        """
+        return X
+
+    def inverse_transform(self, X: np.ndarray) -> np.ndarray:
+        """
+        Identity map.
+
+        Args:
+          X: The input.
+        """
+        return X
+
+    def uncertainty(self):
+        """The uncertainty recorded."""
+        return self.unc
+
+class SelectRelevantLowResolution(TransformerMixin, BaseEstimator):
+    """
+    Select only relevant entries in the low-resolution data.
 
     Args:
       channels: Selected channels to use as an input for the low resolution data.
-      n_pca_lr: Number of low-resolution data PCA components.
-      n_pca_hr: Number of high-resolution data PCA components.
-      high_res_sigma: Resolution of the high-resolution spectrometer in electron-Volts.
      tof_start: Start looking at this index from the low-resolution spectrometer data.
        Set to None to perform no selection
      delta_tof: Number of components to take from the low-resolution spectrometer.
        Set to None to perform no selection.
-      validation_size: Fraction (number between 0 and 1) of the data to take for
-        validation and systematic uncertainty estimate.
-
    """
    def __init__(self,
                 channels:List[str]=[f"channel_{j}_{k}"
                                     for j, k in product(range(1, 5), ["A", "B", "C", "D"])],
-                 n_pca_lr: int=600,
-                 n_pca_hr: int=20,
-                 high_res_sigma: float=0.2,
                 tof_start: Optional[int]=None,
                 delta_tof: Optional[int]=300,
-                 validation_size: float=0.05):
+                ):
        self.channels = channels
-        self.n_pca_lr = n_pca_lr
-        self.n_pca_hr = n_pca_hr
-
-        # PCA models
-        self.lr_pca = PCA(n_pca_lr, whiten=True)
-        self.hr_pca = PCA(n_pca_hr, whiten=True)
-
-        # PCA unc. in high resolution
-        self.high_pca_unc: np.ndarray = np.zeros((1, 0), dtype=float)
-        self.low_pca_unc: np.ndarray = np.zeros((1, 0), dtype=float)
-
-        # fit model
-        self.fit_model = FitModel()
-
-        # size of the test subset
-        self.validation_size = validation_size
-
-        # where to cut on the ToF PES data
        self.tof_start = tof_start
        self.delta_tof = delta_tof
 
-        # high-resolution photon energy axis
-        self.high_res_photon_energy: Optional[np.ndarray] = None
-
-        # smoothing of the SPEC data in eV
-        self.high_res_sigma = high_res_sigma
-
-    def parameters(self) -> Dict[str, Any]:
-        """
-        Dump parameters as a dictionary.
-        """
-        return dict(channels=self.channels,
-                    n_pca_lr=self.n_pca_lr,
-                    n_pca_hr=self.n_pca_hr,
-                    high_res_sigma=self.high_res_sigma,
-                    tof_start=self.tof_start,
-                    delta_tof=self.delta_tof,
-                    validation_size=self.validation_size,
-                    high_pca_unc=self.high_pca_unc,
-                    low_pca_unc=self.low_pca_unc,
-                    high_res_photon_energy=self.high_res_photon_energy,
-                    )
-
-    def preprocess_low_res(self, low_res_data: Dict[str, np.ndarray]) -> np.ndarray:
+    def transform(self, X: Dict[str, np.ndarray]) -> np.ndarray:
        """
        Get a dictionary with the channel names for the input low resolution data
        and output only the relevant input data in an array.
 
        Args:
-          low_res_data: Dictionary with keys named channel_{i}_{k},
-            where i is a number between 1 and 4 and k is a letter between A and D.
+          X: Dictionary with keys named channel_{i}_{k},
+            where i is a number between 1 and 4 and k is a letter between A and D.
 
        Returns: Concatenated and pre-processed low-resolution data of shape (train_id, features).
""" - items = [low_res_data[k] for k in self.channels] + if self.tof_start is None: + raise NotImplementedError("The low-resolution data cannot be transformed before the prompt has been identified. Call the fit function first.") + items = [X[k] for k in self.channels] if self.delta_tof is not None: items = [item[:, self.tof_start:(self.tof_start + self.delta_tof)] for item in items] else: @@ -197,37 +259,17 @@ class Model(TransformerMixin, BaseEstimator): cat = np.concatenate(items, axis=1) return cat - def preprocess_high_res(self, high_res_data: np.ndarray, high_res_photon_energy: np.ndarray) -> np.ndarray: - """ - Get the high resolution data and preprocess it. - - Args: - high_res_data: High resolution data with shape (train_id, features). - high_res_photon_energy: High resolution photon energy values - (the "x"-axis of the high resolution data) with - shape (train_id, features). - - Returns: Pre-processed high-resolution data of shape (train_id, features) before. - """ - # Apply smoothing - n_features = high_res_data.shape[1] - mu = high_res_photon_energy[:, n_features//2, np.newaxis] - gaussian = np.exp(-0.5*(high_res_photon_energy - mu)**2/self.high_res_sigma**2) - gaussian /= np.sum(gaussian, axis=1, keepdims=True) - high_res_gc = fftconvolve(high_res_data, gaussian, mode="same", axes=1) - return high_res_gc - - def estimate_prompt_peak(self, low_res_data: Dict[str, np.ndarray]) -> int: + def estimate_prompt_peak(self, X: Dict[str, np.ndarray]) -> int: """ Estimate the prompt peak index. Args: - low_res_data: Low resolution data with a dictionary containing the channel names. + X: Low resolution data with a dictionary containing the channel names. - Returns: The prompt peak index. + Returns: The index. """ # reduce on channel and on train ID - sum_low_res = - np.mean(sum(list(low_res_data.values())), axis=0) + sum_low_res = - np.mean(sum(list(X.values())), axis=0) widths = np.arange(10, 50, step=1) peak_idx = find_peaks_cwt(sum_low_res, widths) if len(peak_idx) < 1: @@ -242,17 +284,30 @@ class Model(TransformerMixin, BaseEstimator): improved_guess = min_search + int(np.argmax(restricted_arr)) return improved_guess - def debug_peak_finding(self, low_res_data: Dict[str, np.ndarray], filename: str): + def fit(self, X: Dict[str, np.ndarray], y: Optional[Any]=None) -> TransformerMixin: + """ + Estimate the prompt peak index. + + Args: + X: Low resolution data with a dictionary containing the channel names. + y: Ignored. + + Returns: The object itself. + """ + self.tof_start = self.estimate_prompt_peak(X) + return self + + def debug_peak_finding(self, X: Dict[str, np.ndarray], filename: str): """ Produce image to understand if the peak finding step worked well. Args: - low_res_data: Low resolution data with a dictionary containing the channel names. + X: Low resolution data with a dictionary containing the channel names. filename: The file name where to save the plot. """ - sum_low_res = - np.mean(sum(list(low_res_data.values())), axis=0) - peak_idx = self.estimate_prompt_peak(low_res_data) + sum_low_res = - np.mean(sum(list(X.values())), axis=0) + peak_idx = self.estimate_prompt_peak(X) fig = plt.figure(figsize=(8, 16)) ax = plt.gca() ax.plot(np.arange(peak_idx-100, peak_idx+300), @@ -271,178 +326,7 @@ class Model(TransformerMixin, BaseEstimator): plt.savefig(filename) plt.close(fig) - def fit(self, low_res_data: Dict[str, np.ndarray], high_res_data: np.ndarray, high_res_photon_energy: np.ndarray) -> np.ndarray: - """ - Train the model. 
-
-        Args:
-          low_res_data: Low resolution data as a dictionary with the key set to `channel_{i}_{k}`,
-            where i is a number between 1 and 4 and k is a letter between A and D.
-            For each dictionary entry, a numpy array is expected with shape
-            (train_id, ToF channel).
-          high_res_data: Reference high resolution data with a one-to-one match to the
-            low resolution data in the train_id dimension. Shape (train_id, ToF channel).
-          high_res_photon_energy: Photon energy axis for the high-resolution data.
-
-        Returns: Smoothened high resolution spectrum.
-        """
-
-        self.high_res_photon_energy = high_res_photon_energy[0, np.newaxis, :]
-
-        # if the prompt peak has not been given, guess it
-        if self.tof_start is None:
-            self.tof_start = self.estimate_prompt_peak(low_res_data)
-
-        low_res = self.preprocess_low_res(low_res_data)
-        high_res = self.preprocess_high_res(high_res_data, high_res_photon_energy)
-        # fit PCA
-        low_pca = self.lr_pca.fit_transform(low_res)
-        high_pca = self.hr_pca.fit_transform(high_res)
-        # split in train and test for PCA uncertainty evaluation
-        (low_pca_train, low_pca_test,
-         high_pca_train, high_pca_test) = train_test_split(low_pca, high_pca,
-                                                           test_size=self.validation_size,
-                                                           random_state=42)
-        # fit the linear model
-        self.fit_model.fit(low_pca_train,
-                           high_pca_train,
-                           low_pca_test,
-                           high_pca_test)
-
-        high_pca_rec = self.hr_pca.inverse_transform(high_pca)
-        self.high_pca_unc = np.sqrt(np.mean((high_res - high_pca_rec)**2, axis=0, keepdims=True))
-
-        low_pca_rec = self.lr_pca.inverse_transform(low_pca)
-        self.low_pca_unc = np.mean(np.sqrt(np.mean((low_res - low_pca_rec)**2, axis=1, keepdims=True)), axis=0, keepdims=True)
-
-        return high_res
-
-    def check_compatibility(self, low_res_data: Dict[str, np.ndarray]) -> float:
-        """
-        Check if a new low-resolution data source is compatible with the one used in training, by
-        comparing the effect of the trained PCA model on it.
-
-        Args:
-          low_res_data: Low resolution data as in the fit step with shape (train_id, channel, ToF channel).
-
-        Returns: Ratio of root-mean-squared-error of the data reconstruction using the existing PCA model and the one from the original model.
-        """
-        low_res = self.preprocess_low_res(low_res_data)
-        low_pca = self.lr_pca.transform(low_res)
-        low_pca_rec = self.lr_pca.inverse_transform(low_pca)
-
-        #fig = plt.figure(figsize=(8, 16))
-        #ax = plt.gca()
-        #ax.plot(low_res[0,...],
-        #        c="b",
-        #        label="LR")
-        #ax.plot(low_pca_rec[0,...],
-        #        c="r",
-        #        label="LR rec.")
-        #ax.set(title="",
-        #       xlabel="Photon Spectrometer channel",
-        #       ylabel="Low resolution spectrometer intensity")
-        #ax.legend()
-        #plt.savefig("check.png")
-        #plt.close(fig)
-
-        low_pca_unc = np.sqrt(np.mean((low_res - low_pca_rec)**2, axis=1, keepdims=True))
-        return low_pca_unc/self.low_pca_unc
-
-
-    def predict(self, low_res_data: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
-        """
-        Predict a high-resolution spectrum from a low resolution given one.
-        The output includes the uncertainty in its second and third entries of the first dimension.
-
-        Args:
-          low_res_data: Low resolution data as in the fit step with shape (train_id, channel, ToF channel).
-
-        Returns: High resolution data with shape (train_id, energy channel) in a dictionary containing
-          the expected prediction in key "expected", the stat. uncertainty in key "unc" and
-          a (1, energy channel) array for the PCA syst. uncertainty in key "pca".
-        """
-        low_res = self.preprocess_low_res(low_res_data)
-        low_pca = self.lr_pca.transform(low_res)
-        # Get high res.
-        high_pca = self.fit_model.predict(low_pca)
-        n_trains = low_pca.shape[0]
-        pca_y = np.concatenate((high_pca["Y"],
-                                high_pca["Y"] + high_pca["Y_eps"]),
-                               axis=0)
-        high_res_predicted = self.hr_pca.inverse_transform(pca_y)
-        expected = high_res_predicted[:n_trains, :]
-        unc = high_res_predicted[n_trains:, :] - expected
-        return dict(expected=expected,
-                    unc=unc,
-                    pca=self.high_pca_unc)
-
-    def save(self, filename: str):
-        """
-        Save the fit model in a file.
-
-        Args:
-          filename: H5 file name where to save this.
-        """
-        #joblib.dump(self, filename)
-        with h5py.File(filename, 'w') as hf:
-            # transform parameters into a dict
-            d = self.fit_model.as_dict()
-            d.update(self.parameters())
-            # dump them in the file
-            dump_in_group(d, hf)
-            # this is not ideal, because it depends on the knowledge of the PCA
-            # object structure, but saving to a joblib file would mean creating several
-            # files
-            # create a group
-            lr_pca = hf.create_group("lr_pca")
-            # get PCA properties
-            lr_pca_props = get_pca_props(self.lr_pca)
-            # create the HR group
-            hr_pca = hf.create_group("hr_pca")
-            # get PCA properties
-            hr_pca_props = get_pca_props(self.hr_pca)
-            # dump them
-            dump_in_group(lr_pca_props, lr_pca)
-            dump_in_group(hr_pca_props, hr_pca)
-
-
-    def load(self, filename: str):
-        """
-        Load model from a file.
-
-        Args:
-          filename: Name of the file where to read the model from.
-
-        """
-        with h5py.File(filename, 'r') as hf:
-            # read from file
-            d = read_from_group(hf)
-            # load fit_model parameters
-            self.fit_model.from_dict(d)
-            # load parameters of this class
-            for key in self.parameters().keys():
-                value = d[key]
-                if key == 'channels':
-                    value = [item.decode() if isinstance(item, bytes)
-                             else item
-                             for item in value]
-                setattr(self, key, value)
-            # this is not ideal, because it depends on the knowledge of the PCA
-            # object structure, but saving to a joblib file would mean creating several
-            # files
-            lr_pca = hf["/lr_pca/"]
-            hr_pca = hf["/hr_pca/"]
-            self.lr_pca = PCA(self.n_pca_lr, whiten=True)
-            self.hr_pca = PCA(self.n_pca_hr, whiten=True)
-            # read properties in dictionaries
-            lr_pca_props = read_from_group(lr_pca)
-            hr_pca_props = read_from_group(hr_pca)
-            # set them
-            self.lr_pca = set_pca_props(self.lr_pca, lr_pca_props)
-            self.hr_pca = set_pca_props(self.hr_pca, hr_pca_props)
-
-class FitModel(object):
+class FitModel(RegressorMixin, BaseEstimator):
    """
    Linear regression model with uncertainties.
    """
@@ -462,14 +355,27 @@ class FitModel(object):
 
        self.input_data = None
 
-    def fit(self, X_train: np.ndarray, Y_train: np.ndarray, X_test: np.ndarray, Y_test: np.ndarray):
+    def fit(self, X: np.ndarray, y: np.ndarray, **fit_params) -> RegressorMixin:
        """
        Perform the fit and evaluate uncertainties with the test set.
+
+        Args:
+          X: The input.
+          y: The target.
+          fit_params: If it contains X_test and y_test, they are used to validate the model.
+
+        Returns: The object itself.
        """
+        if 'X_test' in fit_params and 'y_test' in fit_params:
+            X_test = fit_params['X_test']
+            y_test = fit_params['y_test']
+        else:
+            X_test = X
+            y_test = y
        # model parameter sizes
-        self.Nx: int = int(X_train.shape[1])
-        self.Ny: int = int(Y_train.shape[1])
+        self.Nx: int = int(X.shape[1])
+        self.Ny: int = int(y.shape[1])
 
        # initial parameter values
        A0: np.ndarray = np.eye(self.Nx, self.Ny).reshape(self.Nx*self.Ny)
 
@@ -515,8 +421,8 @@
 
        Returns: The loss value.
""" - l_train = loss(x, X_train, Y_train) - l_test = loss(x, X_test, Y_test) + l_train = loss(x, X, y) + l_test = loss(x, X_test, y_test) self.loss_train += [l_train] self.loss_test += [l_test] @@ -531,7 +428,7 @@ class FitModel(object): Returns: The loss value. """ - l_train = loss(x, X_train, Y_train) + l_train = loss(x, X, y) return l_train grad_loss = grad(loss_train) @@ -603,3 +500,188 @@ class FitModel(object): result["Y_eps"] = np.exp(X @ self.A_eps + result["Y_unc"]) return result + +class Model(TransformerMixin, BaseEstimator): + """ + Object representing a previous fit of the model to be used to predict high-resolution + spectrum from a low-resolution one. + + Args: + channels: Selected channels to use as an input for the low resolution data. + n_pca_lr: Number of low-resolution data PCA components. + n_pca_hr: Number of high-resolution data PCA components. + high_res_sigma: Resolution of the high-resolution spectrometer in electron-Volts. + tof_start: Start looking at this index from the low-resolution spectrometer data. + Set to None to perform no selection + delta_tof: Number of components to take from the low-resolution spectrometer. + Set to None to perform no selection. + validation_size: Fraction (number between 0 and 1) of the data to take for + validation and systematic uncertainty estimate. + + """ + def __init__(self, + channels:List[str]=[f"channel_{j}_{k}" + for j, k in product(range(1, 5), ["A", "B", "C", "D"])], + n_pca_lr: int=600, + n_pca_hr: int=20, + high_res_sigma: float=0.2, + tof_start: Optional[int]=None, + delta_tof: Optional[int]=300, + validation_size: float=0.05): + # models + self.x_model = Pipeline([ + ('select', SelectRelevantLowResolution(channels, tof_start, delta_tof)), + ('pca', PCA(n_pca_lr, whiten=True)), + ('unc', UncertaintyHolder()), + ]) + self.y_model = Pipeline([('smoothen', HighResolutionSmoother(high_res_sigma)), + ('pca', PCA(n_pca_hr, whiten=True)), + ('unc', UncertaintyHolder()), + ]) + self.fit_model = FitModel() + + # size of the test subset + self.validation_size = validation_size + + def debug_peak_finding(self, low_res_data: Dict[str, np.ndarray], filename: str): + """ + Produce image to understand if the peak finding step worked well. + + Args: + low_res_data: Low resolution data with a dictionary containing the channel names. + filename: The file name where to save the plot. + + """ + self.x_model['select'].debug_peak_finding(low_res_data, filename) + + def preprocess_high_res(self, high_res_data: np.ndarray) -> np.ndarray: + """ + Preprocess high-resolution data to remove high requency components. + + Args: + high_res_data: The high-resolution data. + + Returns: Smoothened spectrum. + """ + return self.y_model['smoothen'].transform(high_res_data) + + def fit(self, low_res_data: Dict[str, np.ndarray], high_res_data: np.ndarray, high_res_photon_energy: np.ndarray) -> np.ndarray: + """ + Train the model. + + Args: + low_res_data: Low resolution data as a dictionary with the key set to `channel_{i}_{k}`, + where i is a number between 1 and 4 and k is a letter between A and D. + For each dictionary entry, a numpy array is expected with shape + (train_id, ToF channel). + high_res_data: Reference high resolution data with a one-to-one match to the + low resolution data in the train_id dimension. Shape (train_id, ToF channel). + high_res_photon_energy: Photon energy axis for the high-resolution data. + + Returns: Smoothened high resolution spectrum. 
+        """
+        x_t = self.x_model.fit_transform(low_res_data)
+        y_t = self.y_model.fit_transform(high_res_data, smoothen__energy=high_res_photon_energy)
+        self.fit_model.fit(x_t, y_t)
+
+        # calculate the effect of the PCA
+        high_res = self.y_model['smoothen'].transform(high_res_data)
+        high_pca = self.y_model.transform(high_res_data)
+        high_pca_rec = self.y_model['pca'].inverse_transform(high_pca)
+        high_pca_unc = np.sqrt(np.mean((high_res - high_pca_rec)**2, axis=0, keepdims=True))
+        self.y_model['unc'].set_uncertainty(high_pca_unc)
+
+        low_res = self.x_model['select'].transform(low_res_data)
+        low_pca = self.x_model['pca'].transform(low_res)
+        low_pca_rec = self.x_model['pca'].inverse_transform(low_pca)
+        low_pca_unc = np.mean(np.sqrt(np.mean((low_res - low_pca_rec)**2, axis=1, keepdims=True)), axis=0, keepdims=True)
+        self.x_model['unc'].set_uncertainty(low_pca_unc)
+
+        return high_res
+
+    def check_compatibility(self, low_res_data: Dict[str, np.ndarray]) -> float:
+        """
+        Check if a new low-resolution data source is compatible with the one used in training, by
+        comparing the effect of the trained PCA model on it.
+
+        Args:
+          low_res_data: Low resolution data as in the fit step with shape (train_id, channel, ToF channel).
+
+        Returns: Ratio of root-mean-squared-error of the data reconstruction using the existing PCA model and the one from the original model.
+        """
+        low_res = self.x_model['select'].transform(low_res_data)
+        low_pca = self.x_model['pca'].transform(low_res)
+        low_pca_rec = self.x_model['pca'].inverse_transform(low_pca)
+        low_pca_unc_train = self.x_model['unc'].uncertainty()
+
+        #fig = plt.figure(figsize=(8, 16))
+        #ax = plt.gca()
+        #ax.plot(low_res[0,...],
+        #        c="b",
+        #        label="LR")
+        #ax.plot(low_pca_rec[0,...],
+        #        c="r",
+        #        label="LR rec.")
+        #ax.set(title="",
+        #       xlabel="Photon Spectrometer channel",
+        #       ylabel="Low resolution spectrometer intensity")
+        #ax.legend()
+        #plt.savefig("check.png")
+        #plt.close(fig)
+
+        low_pca_unc = np.sqrt(np.mean((low_res - low_pca_rec)**2, axis=1, keepdims=True))
+        return low_pca_unc/low_pca_unc_train
+
+
+    def predict(self, low_res_data: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
+        """
+        Predict a high-resolution spectrum from a given low-resolution one.
+        The output includes the uncertainty in its second and third entries of the first dimension.
+
+        Args:
+          low_res_data: Low resolution data as in the fit step with shape (train_id, channel, ToF channel).
+
+        Returns: High resolution data with shape (train_id, energy channel) in a dictionary containing
+          the expected prediction in key "expected", the stat. uncertainty in key "unc" and
+          a (1, energy channel) array for the PCA syst. uncertainty in key "pca".
+        """
+        low_pca = self.x_model.transform(low_res_data)
+        high_pca = self.fit_model.predict(low_pca)
+        n_trains = high_pca["Y"].shape[0]
+        pca_y = np.concatenate((high_pca["Y"],
+                                high_pca["Y"] + high_pca["Y_eps"]),
+                               axis=0)
+        high_res_predicted = self.y_model.inverse_transform(pca_y)
+        expected = high_res_predicted[:n_trains, :]
+        unc = high_res_predicted[n_trains:, :] - expected
+        return dict(expected=expected,
+                    unc=unc,
+                    pca=self.y_model['unc'].uncertainty())
+
+    def save(self, filename: str):
+        """
+        Save the fit model in a file.
+
+        Args:
+          filename: File name where to save this.
+        """
+        joblib.dump([self.x_model,
+                     self.y_model,
+                     self.fit_model],
+                    filename)
+
+    @staticmethod
+    def load(filename: str) -> "Model":
+        """
+        Load model from a file.
+
+        Args:
+          filename: Name of the file where to read the model from.
+
+        Returns: The loaded model.
+        """
+        x_model, y_model, fit_model = joblib.load(filename)
+        obj = Model()
+        obj.x_model = x_model
+        obj.y_model = y_model
+        obj.fit_model = fit_model
+        return obj
+
diff --git a/pes_to_spec/test/offline_analysis.py b/pes_to_spec/test/offline_analysis.py
index b1a155c87c3025092dcd64c5d742afec5080d034..66cc7782a815e9f49a13465bd024ae8cadf2c896 100755
--- a/pes_to_spec/test/offline_analysis.py
+++ b/pes_to_spec/test/offline_analysis.py
@@ -139,18 +139,17 @@ def main():
                  spec_raw_pe[train_idx, :])
        t += [time_ns() - start]
        t_names += ["Fit"]
-    spec_smooth = model.preprocess_high_res(spec_raw_int, spec_raw_pe)
+    spec_smooth = model.preprocess_high_res(spec_raw_int)
 
    print("Saving the model")
    start = time_ns()
-    model.save("model.h5")
+    model.save("model.joblib")
    t += [time_ns() - start]
    t_names += ["Save"]
 
    print("Loading the model")
    start = time_ns()
-    model = Model()
-    model.load("model.h5")
+    model = Model.load("model.joblib")
    t += [time_ns() - start]
    t_names += ["Load"]
 
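A minimal round-trip sketch of the new `save`/`load` API, under stated assumptions: the arrays below are synthetic placeholders shaped like the real data, the PCA component counts are reduced from the defaults (600/20) so that they fit this small synthetic set, and with pure noise the prompt-peak search inside `fit` may raise `PromptNotFoundError`, so real PES traces are needed in practice:

```python
import numpy as np
from itertools import product
from pes_to_spec.model import Model

# synthetic stand-ins: 100 train IDs, one ToF trace per channel,
# one high-resolution spectrum per train ID (placeholder shapes)
rng = np.random.default_rng(0)
channels = [f"channel_{j}_{k}" for j, k in product(range(1, 5), "ABCD")]
low_res = {ch: rng.normal(size=(100, 5000)) for ch in channels}
high_res = rng.normal(size=(100, 1000))
energy = np.broadcast_to(np.linspace(900.0, 920.0, 1000), (100, 1000)).copy()

# reduced PCA sizes: PCA needs n_components <= number of training spectra
model = Model(n_pca_lr=50, n_pca_hr=10)
model.fit(low_res, high_res, energy)

model.save("model.joblib")          # a single joblib file instead of HDF5
model = Model.load("model.joblib")  # static method returning the new instance

prediction = model.predict(low_res)
print(prediction["expected"].shape)  # (100, 1000)
print(prediction["unc"].shape)       # (100, 1000)
print(prediction["pca"].shape)       # (1, 1000)
```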