From 9106934fb8ba83affc26a2428a60ab7366ac1eff Mon Sep 17 00:00:00 2001
From: Danilo Ferreira de Lima <danilo.enoque.ferreira.de.lima@xfel.de>
Date: Tue, 18 Apr 2023 17:37:55 +0200
Subject: [PATCH] Clean up.

---
 pes_to_spec/bnn.py   | 47 +++++++++++++++++++-------------------------
 pes_to_spec/model.py |  2 +-
 2 files changed, 21 insertions(+), 28 deletions(-)

diff --git a/pes_to_spec/bnn.py b/pes_to_spec/bnn.py
index 234fed3..eefa812 100644
--- a/pes_to_spec/bnn.py
+++ b/pes_to_spec/bnn.py
@@ -7,7 +7,7 @@ from scipy.special import gamma
 import torch
 import torch.nn as nn
 import torchbnn as bnn
-from torch.utils.data import Dataset, DataLoader
+from torch.utils.data import TensorDataset, DataLoader
 
 class AverageMeter(object):
     """Computes and stores the average and current value"""
@@ -41,7 +41,7 @@ class ProgressMeter(object):
     def display(self, batch):
         entries = [self.prefix + self.batch_fmtstr.format(batch)]
         entries += [str(meter) for meter in self.meters]
-        print('\t'.join(entries))
+        print('  '.join(entries))
 
     def _get_batch_fmtstr(self, num_batches):
         num_digits = len(str(num_batches // 1))
@@ -151,23 +151,11 @@ class BNN(nn.Module):
         #return 0
         return torch.exp(-0.5*self.log_isigma2[0])
 
-    def l(self) -> torch.Tensor:
+    def w_precision(self) -> torch.Tensor:
         """
-            Get the weights std. dev.
+            Get the weights precision.
         """
-        return torch.exp(-0.5*self.log_ilambda2[0])
-
-class BNNDataset(Dataset):
-    def __init__(self, x: np.ndarray, y: np.ndarray, w: np.ndarray):
-        self.x = x
-        self.y = y
-        self.w = w
-        assert len(x) == len(y) and len(x) == len(w)
-    def __len__(self) -> int:
-        """How many samples do I have?"""
-        return len(self.x)
-    def __getitem__(self, idx):
-        return {"x": self.x[idx, :], "y": self.y[idx, :], "w": self.w[idx, :]}
+        return torch.exp(self.log_ilambda2[0])
 
 class BNNModel(RegressorMixin, BaseEstimator):
     """
@@ -204,14 +192,23 @@ class BNNModel(RegressorMixin, BaseEstimator):
             weights = np.ones(len(X), dtype=np.float32)
         if len(weights.shape) == 1:
             weights = weights[:, np.newaxis]
-        ds = BNNDataset(X, y, weights)
+
+        ds = TensorDataset(torch.from_numpy(X),
+                           torch.from_numpy(y),
+                           torch.from_numpy(weights))
 
         # create model
         self.model = BNN(X.shape[1], y.shape[1])
 
         # prepare data loader
         B = 5
-        loader = DataLoader(ds, batch_size=B)
+        loader = DataLoader(ds,
+                            batch_size=B,
+                            num_workers=5,
+                            shuffle=True,
+                            #pin_memory=True,
+                            drop_last=True,
+                            )
         optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-3)
         number_of_batches = len(ds)/float(B)
         weight_prior = 1.0/float(number_of_batches)
@@ -226,18 +223,16 @@ class BNNModel(RegressorMixin, BaseEstimator):
 
         # train
         self.model.train()
-        epochs = 100
+        epochs = 200
         for epoch in range(epochs):
             meter = {k: AverageMeter(k, ':6.3f')
-                    for k in ('loss', '-log(lkl)', '-log(prior)', '-log(hyper)', 'sigma', 'lambda')}
+                    for k in ('loss', '-log(lkl)', '-log(prior)', '-log(hyper)', 'sigma', 'w.prec.')}
             progress = ProgressMeter(
                             len(loader),
                             meter.values(),
                             prefix="Epoch: [{}]".format(epoch))
             for i, batch in enumerate(loader):
-                x_b = batch["x"]
-                y_b = batch["y"]
-                w_b = batch["w"]
+                x_b, y_b, w_b = batch
                 y_b_pred = self.model(x_b)
 
                 nll = self.model.neg_log_likelihood(y_b_pred, y_b, w_b)
@@ -254,10 +249,8 @@ class BNNModel(RegressorMixin, BaseEstimator):
                 meter['-log(prior)'].update(nlprior.detach().cpu().item(), B)
                 meter['-log(hyper)'].update(nlhyper.detach().cpu().item(), B)
                 meter['sigma'].update(self.model.aleatoric_uncertainty().detach().cpu().item(), B)
-                meter['lambda'].update(self.model.l().detach().cpu().item(), B)
+                meter['w.prec.'].update(self.model.w_precision().detach().cpu().item(), B)
 
-                if i % 100 == 0:
-                    progress.display(i)
             progress.display(len(loader))
         self.model.eval()
 
diff --git a/pes_to_spec/model.py b/pes_to_spec/model.py
index 39b6c4b..045d38e 100644
--- a/pes_to_spec/model.py
+++ b/pes_to_spec/model.py
@@ -523,7 +523,7 @@ class Model(TransformerMixin, BaseEstimator):
     def __init__(self,
                  channels:List[str]=[f"channel_{j}_{k}"
                                      for j, k in product(range(1, 5), ["A", "B", "C", "D"])],
-                 n_pca_lr: int=600,
+                 n_pca_lr: int=200,
                  n_pca_hr: int=20,
                  high_res_sigma: float=0.2,
                  tof_start: Optional[int]=None,
-- 
GitLab