Skip to content
Snippets Groups Projects
Commit 8ef4ee33 authored by Danilo Ferreira de Lima's avatar Danilo Ferreira de Lima
Browse files

Many bug fixes when writing the PCA model.

parent 02eded28
No related branches found
No related tags found
No related merge requests found
...@@ -34,14 +34,15 @@ def save_pca(pca_obj: Union[IncrementalPCA, PCA], pca_group: h5py.Group): ...@@ -34,14 +34,15 @@ def save_pca(pca_obj: Union[IncrementalPCA, PCA], pca_group: h5py.Group):
"singular_values_", "singular_values_",
"mean_"] "mean_"]
attrs = ["n_components_", attrs = ["n_components_",
"n_features_", #"n_features_",
"n_samples_", #"n_samples_",
"noise_variance_", "noise_variance_",
"n_features_in_"] #"n_features_in_"
]
for p in props: for p in props:
pca_group.create_dataset(p, getattr(pca_obj, p)) pca_group.create_dataset(p, data=getattr(pca_obj, p))
for a in attrs: for a in attrs:
pca_group.attrs[p] = getattr(pca_obj, a) pca_group.attrs[a] = getattr(pca_obj, a)
def load_pca(pca_obj: Union[IncrementalPCA, PCA], pca_group: h5py.Group) -> Union[IncrementalPCA, PCA]: def load_pca(pca_obj: Union[IncrementalPCA, PCA], pca_group: h5py.Group) -> Union[IncrementalPCA, PCA]:
""" """
...@@ -60,14 +61,15 @@ def load_pca(pca_obj: Union[IncrementalPCA, PCA], pca_group: h5py.Group) -> Unio ...@@ -60,14 +61,15 @@ def load_pca(pca_obj: Union[IncrementalPCA, PCA], pca_group: h5py.Group) -> Unio
"singular_values_", "singular_values_",
"mean_"] "mean_"]
attrs = ["n_components_", attrs = ["n_components_",
"n_features_", #"n_features_",
"n_samples_", #"n_samples_",
"noise_variance_", "noise_variance_",
"n_features_in_"] #"n_features_in_"
]
for p in props: for p in props:
setattr(pca_obj, p, pca_group[p]) setattr(pca_obj, p, pca_group[p][()])
for a in attrs: for a in attrs:
setattr(pca_obj, a, pca_group[a]) setattr(pca_obj, a, pca_group.attrs[a])
return pca_obj return pca_obj
class PromptNotFoundError(Exception): class PromptNotFoundError(Exception):
...@@ -323,7 +325,7 @@ class Model(object): ...@@ -323,7 +325,7 @@ class Model(object):
""" """
with h5py.File(filename, 'r') as hf: with h5py.File(filename, 'r') as hf:
d = {k: hf[k][()] for k in hf.keys()} d = {k: hf[k][()] for k in hf.keys() if not isinstance(hf[k], h5py.Group)}
d.update({k: hf.attrs[k] for k in hf.attrs}) d.update({k: hf.attrs[k] for k in hf.attrs})
self.fit_model.from_dict(d) self.fit_model.from_dict(d)
for key in self.parameters().keys(): for key in self.parameters().keys():
...@@ -333,8 +335,8 @@ class Model(object): ...@@ -333,8 +335,8 @@ class Model(object):
# files # files
lr_pca = hf["/lr_pca/"] lr_pca = hf["/lr_pca/"]
hr_pca = hf["/hr_pca/"] hr_pca = hf["/hr_pca/"]
self.lr_pca = IncrementalPCA(self.n_pca_lr) self.lr_pca = IncrementalPCA(self.n_pca_lr, whiten=True)
self.hr_pca = PCA(self.n_pca_hr) self.hr_pca = PCA(self.n_pca_hr, whiten=True)
self.lr_pca = load_pca(self.lr_pca, lr_pca) self.lr_pca = load_pca(self.lr_pca, lr_pca)
self.hr_pca = load_pca(self.hr_pca, hr_pca) self.hr_pca = load_pca(self.hr_pca, hr_pca)
#self.lr_pca = joblib.load(lr_pca_filename) #self.lr_pca = joblib.load(lr_pca_filename)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment