Skip to content
Snippets Groups Projects
Commit 89ded2c2 authored by Danilo Enoque Ferreira de Lima's avatar Danilo Enoque Ferreira de Lima
Browse files

Merge branch 'handle_pedestal' into 'main'

Handle pedestal in PES and use GPU if available in BNN

See merge request !19
parents a2b8e842 32451c16
No related branches found
No related tags found
1 merge request!19Handle pedestal in PES and use GPU if available in BNN
Pipeline #120515 canceled
......@@ -15,6 +15,7 @@ default:
#- pip install torch --index-url https://download.pytorch.org/whl/cpu
- pip install joblib scikit-learn
- pip install matplotlib lmfit seaborn extra_data
- pip install dask
stages:
- test
......
......@@ -526,6 +526,8 @@ class BNNModel(RegressorMixin, BaseEstimator):
# train
self.model.train()
if torch.cuda.is_available():
self.model = self.model.to('cuda')
for epoch in range(self.n_epochs):
meter = {k: AverageMeter(k, ':6.3f')
for k in ('loss', '-log(lkl)', '-log(prior)', '-log(hyper)', 'sigma', 'w.prec.')}
......@@ -535,6 +537,10 @@ class BNNModel(RegressorMixin, BaseEstimator):
prefix="Epoch: [{}]".format(epoch))
for i, batch in enumerate(loader):
x_b, y_b, w_b = batch
if torch.cuda.is_available():
x_b = x_b.to('cuda')
y_b = y_b.to('cuda')
w_b = w_b.to('cuda')
y_b_pred = self.model(x_b)
nll = self.model.neg_log_likelihood(y_b_pred, y_b, w_b)
......@@ -558,6 +564,8 @@ class BNNModel(RegressorMixin, BaseEstimator):
self.model.prune()
self.model.eval()
if torch.cuda.is_available():
self.model = self.model.to('cpu')
return self
......
......@@ -2,7 +2,9 @@ from __future__ import annotations
from typing import Any, Dict, List, Optional, Union, Tuple, Literal
import sys
import joblib
import dask.array as da
import numpy as np
import scipy
......@@ -282,6 +284,7 @@ class SelectRelevantLowResolution(TransformerMixin, BaseEstimator):
self.poly = poly
self.mean = dict()
self.std = dict()
self.pedestal = {ch: 0.0 for ch in self.channels}
def transform(self, X: Dict[str, np.ndarray],
keep_dictionary_structure: bool=False,
......@@ -309,9 +312,13 @@ class SelectRelevantLowResolution(TransformerMixin, BaseEstimator):
if self.delta_tof is not None:
first = max(0, self.tof_start - self.delta_tof)
last = min(X[self.channels[0]].shape[1], self.tof_start + self.delta_tof)
y = {channel: [item[:, (first + delta):(last + delta)] for delta in pulse_spacing[channel]]
y = {channel: [item[:, (first + delta):(last + delta)] - self.pedestal[channel]
for delta in pulse_spacing[channel]]
for channel, item in X.items()
if channel in self.channels}
# convert to Numpy if it is a Dask array
y = {ch: [item.compute() if isinstance(item, da.Array) else item for item in v]
for ch, v in y.items()}
# pad it with zeros, if we reach the edge of the array
for channel in y.keys():
y[channel] = [np.pad(y[channel][j], ((0, 0), (0, 2*self.delta_tof - y[channel][j].shape[1])))
......@@ -336,7 +343,10 @@ class SelectRelevantLowResolution(TransformerMixin, BaseEstimator):
Returns: The index.
"""
# reduce on channel and on train ID
sum_low_res = - np.mean(sum(list(X.values())), axis=0)
sum_low_res = np.mean(sum([-(v - self.pedestal[ch]) for ch, v in X.items()]), axis=0)
# convert to Numpy if it is a Dask array
if isinstance(sum_low_res, da.Array):
sum_low_res = sum_low_res.compute()
axis = np.arange(0.0, sum_low_res.shape[0], 1.0)
#widths = np.arange(10, 50, step=5)
#peak_idx = find_peaks_cwt(sum_low_res, widths)
......@@ -367,6 +377,12 @@ class SelectRelevantLowResolution(TransformerMixin, BaseEstimator):
Returns: The object itself.
"""
# estimate pedestal
self.pedestal = {ch: np.mean(v[:10])
for ch, v in X.items()}
# convert to Numpy if it is a Dask array
self.pedestal = {ch: v.compute() if isinstance(v, da.Array) else v
for ch, v in self.pedestal.items()}
self.tof_start = self.estimate_prompt_peak(X)
X_tr = self.transform(X, keep_dictionary_structure=True)
self.mean = {ch: np.mean(X_tr[ch], axis=0, keepdims=True)
......@@ -863,6 +879,11 @@ class Model(TransformerMixin, BaseEstimator):
print("Finding region-of-interest")
self.x_select.fit(low_res_data)
low_res_select = self.x_select.transform(low_res_data, pulse_energy=pulse_energy)
low_res_selected_dict = self.x_select.transform(low_res_data, keep_dictionary_structure=True)
# after this we do not need low_res_data anymore
del low_res_data
# keep the number of pulses
B, P, _ = low_res_select.shape
low_res_select = low_res_select.reshape((B*P, -1))
......@@ -873,6 +894,8 @@ class Model(TransformerMixin, BaseEstimator):
cov = EllipticEnvelope(random_state=0).fit(lr_sums)
good_low_res = cov.predict(lr_sums)
filter_lr = (good_low_res > 0)
del cov
low_res_filter = low_res_select[filter_hr & filter_lr, :]
high_res_filter = high_res_data[filter_hr & filter_lr, :]
weights_filter = weights
......@@ -889,6 +912,7 @@ class Model(TransformerMixin, BaseEstimator):
if len(n_components) > 0:
n_components = n_components[0]
n_components = max(600, n_components)
del pca_test
print(f"Using {n_components} comp. for PES PCA.")
self.x_model.set_params(pca__n_components=n_components)
......@@ -903,6 +927,7 @@ class Model(TransformerMixin, BaseEstimator):
if len(n_components_hr) > 0:
n_components_hr = n_components_hr[0]
n_components_hr = max(20, n_components_hr)
del pca_test
print(f"Using {n_components_hr} comp. for grating spec. PCA.")
self.y_model.set_params(pca__n_components=n_components_hr)
......@@ -1002,12 +1027,10 @@ class Model(TransformerMixin, BaseEstimator):
return high_res.reshape((B, P, -1))
# for consistency check per channel
selection_model = self.x_select
low_res_selected = selection_model.transform(low_res_data, keep_dictionary_structure=True)
for channel in self.get_channels():
B, P, _ = low_res_selected[channel].shape
B, P, _ = low_res_selected_dict[channel].shape
print(f"Calculate PCA on {channel}")
low_pca = self.channel_pca[channel].fit_transform(low_res_selected[channel].reshape(B*P, -1))
low_pca = self.channel_pca[channel].fit_transform(low_res_selected_dict[channel].reshape(B*P, -1))
self.ood[channel].fit(low_pca)
print("End of fit.")
......
......@@ -30,6 +30,7 @@ dependencies = [
"lmfit",
"scikit-learn>=1.2.0",
"torch",
"dask",
]
[project.optional-dependencies]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment