Skip to content
Snippets Groups Projects
Commit cda5dbe8 authored by Danilo Ferreira de Lima's avatar Danilo Ferreira de Lima
Browse files

Clean up.

parent ef42358a
No related branches found
No related tags found
1 merge request!4Switched to ARDRegression to keep the code more maintainable
......@@ -23,8 +23,6 @@ from sklearn.model_selection import train_test_split
from sklearn.base import clone, MetaEstimatorMixin
from joblib import Parallel, delayed
import matplotlib.pyplot as plt
from typing import Any, Dict, List, Optional, Union, Tuple
def matching_ids(a: np.ndarray, b: np.ndarray, c: np.ndarray) -> np.ndarray:
......@@ -71,7 +69,6 @@ class HighResolutionSmoother(TransformerMixin, BaseEstimator):
Returns: The object itself.
"""
print("Storing high resolution energy")
self.energy = np.copy(fit_params["energy"])
if len(self.energy.shape) == 2:
self.energy = self.energy[0,:]
......@@ -86,7 +83,6 @@ class HighResolutionSmoother(TransformerMixin, BaseEstimator):
Returns: Smoothened out spectrum.
"""
print("Smoothing high-resolution spectrum")
if self.high_res_sigma <= 0:
return X
# use a default energy axis is none is given
......@@ -198,7 +194,6 @@ class SelectRelevantLowResolution(TransformerMixin, BaseEstimator):
Returns: Concatenated and pre-processed low-resolution data of shape (train_id, features).
"""
print("Selecting area close to the peak")
if self.tof_start is None:
raise NotImplementedError("The low-resolution data cannot be transformed before the prompt has been identified. Call the fit function first.")
items = [X[k] for k in self.channels]
......@@ -244,7 +239,6 @@ class SelectRelevantLowResolution(TransformerMixin, BaseEstimator):
Returns: The object itself.
"""
print("Estimating peak position")
self.tof_start = self.estimate_prompt_peak(X)
return self
......@@ -259,6 +253,7 @@ class SelectRelevantLowResolution(TransformerMixin, BaseEstimator):
"""
sum_low_res = - np.mean(sum(list(X.values())), axis=0)
peak_idx = self.estimate_prompt_peak(X)
import matplotlib.pyplot as plt
fig = plt.figure(figsize=(8, 16))
ax = plt.gca()
ax.plot(np.arange(peak_idx-100, peak_idx+300),
......@@ -458,14 +453,12 @@ class MultiOutputWithStd(MetaEstimatorMixin, BaseEstimator):
"multi-output regression but has only one."
)
print(f"Fitting multiple regressors with n_jobs={self.n_jobs}")
self.estimators_ = Parallel(n_jobs=self.n_jobs)(
delayed(_fit_estimator)(
self.estimator, X, y[:, i]
)
for i in range(y.shape[1])
)
print("End of fit")
return self
......@@ -480,7 +473,6 @@ class MultiOutputWithStd(MetaEstimatorMixin, BaseEstimator):
Multi-output targets predicted across multiple predictors.
Note: Separate models are generated for each predictor.
"""
print("Inferring ...")
y = Parallel(n_jobs=self.n_jobs)(
delayed(e.predict)(X, return_std) for e in self.estimators_
#delayed(e.predict)(X) for e in self.estimators_
......@@ -525,11 +517,11 @@ class Model(TransformerMixin, BaseEstimator):
x_model_steps = list()
x_model_steps += [('select', SelectRelevantLowResolution(channels, tof_start, delta_tof))]
if n_nonlinear_kernel > 0:
x_model_steps += [('fex', Pipeline([('prepca', IncrementalPCA(n_pca_lr, whiten=True, batch_size=n_pca_lr*2)),
x_model_steps += [('fex', Pipeline([('prepca', PCA(n_pca_lr, whiten=True)),
('nystroem', Nystroem(n_components=n_nonlinear_kernel, kernel='rbf', gamma=None, n_jobs=8)),
]))]
x_model_steps += [
('pca', IncrementalPCA(n_pca_lr, whiten=True, batch_size=n_pca_lr*2)),
('pca', PCA(n_pca_lr, whiten=True)),
('unc', UncertaintyHolder()),
]
self.x_model = Pipeline(x_model_steps)
......@@ -593,12 +585,9 @@ class Model(TransformerMixin, BaseEstimator):
Returns: Smoothened high resolution spectrum.
"""
print("Fitting x ...")
x_t = self.x_model.fit_transform(low_res_data)
print("Fitting y ...")
y_t = self.y_model.fit_transform(high_res_data, smoothen__energy=high_res_photon_energy)
#self.fit_model.set_params(fex__gamma=1.0/float(x_t.shape[0]))
print("Fitting linear model ...")
self.fit_model.fit(x_t, y_t)
# calculate the effect of the PCA
......
......@@ -27,7 +27,7 @@ dynamic = ["version", "readme"]
dependencies = [
"numpy>=1.21",
"scipy>=1.6",
"scikit-learn",
"scikit-learn==1.0.2",
"autograd",
"h5py"
]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment