diff --git a/pes_to_spec/bnn.py b/pes_to_spec/bnn.py index eefa81288ee7d23ed21c5acf89efb2a52ec514a9..aaf850988631569f696274aea52fd1fca3625f28 100644 --- a/pes_to_spec/bnn.py +++ b/pes_to_spec/bnn.py @@ -82,8 +82,8 @@ class BNN(nn.Module): # and the only regularization is to prevent the weights from becoming > 18 + 3 sqrt(var) ~= 50, making this a very loose regularization. # An alternative would be to set the (alpha, beta) both to very low values, whichmakes the hyper prior become closer to the non-informative Jeffrey's prior. # Using this alternative (ie: (0.1, 0.1) for the weights' hyper prior) leads to very large lambda and numerical issues with the fit. - self.alpha_lambda = 3.0 - self.beta_lambda = 6.0 + self.alpha_lambda = 0.1 + self.beta_lambda = 0.1 # Hyperprior choice on the likelihood noise level: # The likelihood noise level is controlled by sigma in the likelihood and it should be allowed to be very broad, but different @@ -92,8 +92,8 @@ class BNN(nn.Module): # Making both alpha and beta small makes the gamma distribution closer to the Jeffey's prior, which makes it non-informative # This seems to lead to a larger training time, though. # Since, after standardization, we know to expect the variance to be of order (1), we can select also alpha and beta leading to high variance in this range - self.alpha_sigma = 2.0 - self.beta_sigma = 0.15 + self.alpha_sigma = 0.1 + self.beta_sigma = 0.1 self.model = nn.Sequential( bnn.BayesLinear(prior_mu=0.0, @@ -201,7 +201,7 @@ class BNNModel(RegressorMixin, BaseEstimator): self.model = BNN(X.shape[1], y.shape[1]) # prepare data loader - B = 5 + B = 100 loader = DataLoader(ds, batch_size=B, num_workers=5, @@ -223,7 +223,7 @@ class BNNModel(RegressorMixin, BaseEstimator): # train self.model.train() - epochs = 200 + epochs = 1000 for epoch in range(epochs): meter = {k: AverageMeter(k, ':6.3f') for k in ('loss', '-log(lkl)', '-log(prior)', '-log(hyper)', 'sigma', 'w.prec.')} diff --git a/pes_to_spec/test/offline_analysis.py b/pes_to_spec/test/offline_analysis.py index 5ede9532259921dd4033e2f67dfc5492073a7ac7..2da28ebe6d2c4cca4206b69e5ad7d67266a2d4de 100755 --- a/pes_to_spec/test/offline_analysis.py +++ b/pes_to_spec/test/offline_analysis.py @@ -144,7 +144,7 @@ def main(): parser.add_argument('-o', '--offset', type=int, metavar='INT', default=0, help='Train ID offset') parser.add_argument('-c', '--xgm_cut', type=float, metavar='INTENSITY', default=500, help='XGM intensity threshold in uJ.') parser.add_argument('-e', '--bnn', action="store_true", default=False, help='Use BNN?') - parser.add_argument('-w', '--weight', action="store_true", default=False, help='Whether to reweight data as a function of the pulse energy to make it invariant to that.') + parser.add_argument('-w', '--weight', action="store_true", default=True, help='Whether to reweight data as a function of the pulse energy to make it invariant to that.') args = parser.parse_args()