From 8ef4ee33df5615e302cec47db17164f8c429532c Mon Sep 17 00:00:00 2001 From: Danilo Ferreira de Lima <danilo.enoque.ferreira.de.lima@xfel.de> Date: Mon, 19 Dec 2022 20:11:00 +0100 Subject: [PATCH] Many bug fixes when writing the PCA model. --- pes_to_spec/model.py | 28 +++++++++++++++------------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/pes_to_spec/model.py b/pes_to_spec/model.py index 2f73ec2..8e5d701 100644 --- a/pes_to_spec/model.py +++ b/pes_to_spec/model.py @@ -34,14 +34,15 @@ def save_pca(pca_obj: Union[IncrementalPCA, PCA], pca_group: h5py.Group): "singular_values_", "mean_"] attrs = ["n_components_", - "n_features_", - "n_samples_", + #"n_features_", + #"n_samples_", "noise_variance_", - "n_features_in_"] + #"n_features_in_" + ] for p in props: - pca_group.create_dataset(p, getattr(pca_obj, p)) + pca_group.create_dataset(p, data=getattr(pca_obj, p)) for a in attrs: - pca_group.attrs[p] = getattr(pca_obj, a) + pca_group.attrs[a] = getattr(pca_obj, a) def load_pca(pca_obj: Union[IncrementalPCA, PCA], pca_group: h5py.Group) -> Union[IncrementalPCA, PCA]: """ @@ -60,14 +61,15 @@ def load_pca(pca_obj: Union[IncrementalPCA, PCA], pca_group: h5py.Group) -> Unio "singular_values_", "mean_"] attrs = ["n_components_", - "n_features_", - "n_samples_", + #"n_features_", + #"n_samples_", "noise_variance_", - "n_features_in_"] + #"n_features_in_" + ] for p in props: - setattr(pca_obj, p, pca_group[p]) + setattr(pca_obj, p, pca_group[p][()]) for a in attrs: - setattr(pca_obj, a, pca_group[a]) + setattr(pca_obj, a, pca_group.attrs[a]) return pca_obj class PromptNotFoundError(Exception): @@ -323,7 +325,7 @@ class Model(object): """ with h5py.File(filename, 'r') as hf: - d = {k: hf[k][()] for k in hf.keys()} + d = {k: hf[k][()] for k in hf.keys() if not isinstance(hf[k], h5py.Group)} d.update({k: hf.attrs[k] for k in hf.attrs}) self.fit_model.from_dict(d) for key in self.parameters().keys(): @@ -333,8 +335,8 @@ class Model(object): # files lr_pca = hf["/lr_pca/"] hr_pca = hf["/hr_pca/"] - self.lr_pca = IncrementalPCA(self.n_pca_lr) - self.hr_pca = PCA(self.n_pca_hr) + self.lr_pca = IncrementalPCA(self.n_pca_lr, whiten=True) + self.hr_pca = PCA(self.n_pca_hr, whiten=True) self.lr_pca = load_pca(self.lr_pca, lr_pca) self.hr_pca = load_pca(self.hr_pca, hr_pca) #self.lr_pca = joblib.load(lr_pca_filename) -- GitLab