diff --git a/pes_to_spec/model.py b/pes_to_spec/model.py index 46b5fe7e0184a79a4fc577e9c683be2d41e25274..891637ef43121b40213b97a3ba48a449cbf76afc 100644 --- a/pes_to_spec/model.py +++ b/pes_to_spec/model.py @@ -22,6 +22,7 @@ from sklearn.base import clone, MetaEstimatorMixin from joblib import Parallel, delayed from functools import partial import multiprocessing as mp +from copy import deepcopy from typing import Any, Dict, List, Optional, Union, Tuple @@ -323,6 +324,62 @@ class UncertaintyHolder(TransformerMixin, BaseEstimator): def uncertainty(self): """The uncertainty recorded.""" return self.unc + +class DataHolder(TransformerMixin, BaseEstimator): + """ + Keep track of relevance dictionaries. + + """ + def __init__(self, data: Dict[str, Any]=dict()): + self.data: Dict[str, Any] = data + + def set_data(self, data: Dict[str, Any]): + """ + Set. + + Args: + data: Data. + """ + self.data = deepcopy(data) + + def get_data(self) -> Dict[str, Any]: + """ + Get. + + """ + return self.data + + def fit(self, X, y=None) -> TransformerMixin: + """ + Does nothing. + + Args: + X: Irrelevant. + y: Irrelevant. + + Returns: Itself. + """ + return self + + def transform(self, X: np.ndarray) -> np.ndarray: + """ + Identity map. + + Args: + X: The input. + """ + return X + + def inverse_transform(self, X: np.ndarray) -> np.ndarray: + """ + Identity map. + + Args: + X: The input. + """ + return X + + class SelectRelevantLowResolution(TransformerMixin, BaseEstimator): """ @@ -771,8 +828,8 @@ class Model(TransformerMixin, BaseEstimator): joblib.dump([self.x_model, self.y_model, self.fit_model, - self.channel_mean, - self.channel_relevance + DataHolder(self.channel_mean), + DataHolder(self.channel_relevance) ], filename, compress='zlib') @staticmethod @@ -790,7 +847,7 @@ class Model(TransformerMixin, BaseEstimator): obj.x_model = x_model obj.y_model = y_model obj.fit_model = fit_model - obj.channel_mean = channel_mean - obj.channel_relevance = channel_relevance + obj.channel_mean = channel_mean.get_data() + obj.channel_relevance = channel_relevance.get_data() return obj