From 9cc4630c26602db8f2cbab8d82ec155b7fea2013 Mon Sep 17 00:00:00 2001
From: Danilo Ferreira de Lima <danilo.enoque.ferreira.de.lima@xfel.de>
Date: Sat, 4 Feb 2023 20:40:55 +0100
Subject: [PATCH] Using DataHolder object to save dictionaries.

---
 pes_to_spec/model.py | 65 +++++++++++++++++++++++++++++++++++++++++---
 1 file changed, 61 insertions(+), 4 deletions(-)

diff --git a/pes_to_spec/model.py b/pes_to_spec/model.py
index 46b5fe7..891637e 100644
--- a/pes_to_spec/model.py
+++ b/pes_to_spec/model.py
@@ -22,6 +22,7 @@ from sklearn.base import clone, MetaEstimatorMixin
 from joblib import Parallel, delayed
 from functools import partial
 import multiprocessing as mp
+from copy import deepcopy
 
 from typing import Any, Dict, List, Optional, Union, Tuple
 
@@ -323,6 +324,62 @@ class UncertaintyHolder(TransformerMixin, BaseEstimator):
     def uncertainty(self):
         """The uncertainty recorded."""
         return self.unc
+    
+class DataHolder(TransformerMixin, BaseEstimator):
+    """
+    Keep track of relevance dictionaries.
+
+    """
+    def __init__(self, data: Dict[str, Any]=dict()):
+        self.data: Dict[str, Any] = data
+
+    def set_data(self, data: Dict[str, Any]):
+        """
+        Set.
+
+        Args:
+          data: Data.
+        """
+        self.data = deepcopy(data)
+        
+    def get_data(self) -> Dict[str, Any]:
+        """
+        Get.
+
+        """
+        return self.data
+
+    def fit(self, X, y=None) -> TransformerMixin:
+        """
+        Does nothing.
+
+        Args:
+          X: Irrelevant.
+          y: Irrelevant.
+
+        Returns: Itself.
+        """
+        return self
+
+    def transform(self, X: np.ndarray) -> np.ndarray:
+        """
+        Identity map.
+
+        Args:
+          X: The input.
+        """
+        return X
+
+    def inverse_transform(self, X: np.ndarray) -> np.ndarray:
+        """
+        Identity map.
+
+        Args:
+          X: The input.
+        """
+        return X
+
+
 
 class SelectRelevantLowResolution(TransformerMixin, BaseEstimator):
     """
@@ -771,8 +828,8 @@ class Model(TransformerMixin, BaseEstimator):
         joblib.dump([self.x_model,
                      self.y_model,
                      self.fit_model,
-                     self.channel_mean,
-                     self.channel_relevance
+                     DataHolder(self.channel_mean),
+                     DataHolder(self.channel_relevance)
                      ], filename, compress='zlib')
 
     @staticmethod
@@ -790,7 +847,7 @@ class Model(TransformerMixin, BaseEstimator):
         obj.x_model = x_model
         obj.y_model = y_model
         obj.fit_model = fit_model
-        obj.channel_mean = channel_mean
-        obj.channel_relevance = channel_relevance
+        obj.channel_mean = channel_mean.get_data()
+        obj.channel_relevance = channel_relevance.get_data()
         return obj
 
-- 
GitLab