Commit e43fba93 authored by Danilo Ferreira de Lima

Adjusted setup to get better optimized BNN results.

parent d360c3d4
@@ -82,8 +82,8 @@ class BNN(nn.Module):
         # and the only regularization is to prevent the weights from becoming > 18 + 3 sqrt(var) ~= 50, making this a very loose regularization.
         # An alternative would be to set (alpha, beta) both to very low values, which makes the hyper-prior become closer to the non-informative Jeffreys prior.
         # Using this alternative (i.e. (0.1, 0.1) for the weights' hyper-prior) leads to very large lambda and numerical issues with the fit.
-        self.alpha_lambda = 3.0
-        self.beta_lambda = 6.0
+        self.alpha_lambda = 0.1
+        self.beta_lambda = 0.1
         # Hyper-prior choice on the likelihood noise level:
         # The likelihood noise level is controlled by sigma in the likelihood and it should be allowed to be very broad, but different
@@ -92,8 +92,8 @@ class BNN(nn.Module):
         # Making both alpha and beta small makes the gamma distribution closer to the Jeffreys prior, which makes it non-informative.
         # This seems to lead to a longer training time, though.
         # Since, after standardization, we expect the variance to be of order 1, we can also select alpha and beta leading to high variance in this range.
-        self.alpha_sigma = 2.0
-        self.beta_sigma = 0.15
+        self.alpha_sigma = 0.1
+        self.beta_sigma = 0.1
         self.model = nn.Sequential(
             bnn.BayesLinear(prior_mu=0.0,
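
For intuition about what the old and new values imply, here is a small sketch (not part of the commit) comparing the two Gamma hyper-priors, assuming alpha is the shape and beta the rate:

# Sketch, not from the commit: compare the old and new Gamma hyper-priors.
# Assumes the shape/rate parameterization, i.e. mean = alpha/beta and
# variance = alpha/beta**2.
from torch.distributions import Gamma

priors = {
    'lambda (old)': Gamma(3.0, 6.0),   # weight precision: mean 0.5, var ~0.08 (tight)
    'lambda (new)': Gamma(0.1, 0.1),   # mean 1.0, var 10.0 (broad)
    'sigma (old)':  Gamma(2.0, 0.15),  # noise level: mean ~13.3, var ~88.9
    'sigma (new)':  Gamma(0.1, 0.1),   # weakly informative
}
for name, d in priors.items():
    print(f'{name}: mean={d.mean.item():.2f}, var={d.variance.item():.2f}')
# As alpha, beta -> 0, the Gamma density approaches the non-informative
# Jeffreys prior mentioned in the comments above.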
@@ -201,7 +201,7 @@ class BNNModel(RegressorMixin, BaseEstimator):
         self.model = BNN(X.shape[1], y.shape[1])
         # prepare data loader
-        B = 5
+        B = 100
         loader = DataLoader(ds,
                             batch_size=B,
                             num_workers=5,
@@ -223,7 +223,7 @@ class BNNModel(RegressorMixin, BaseEstimator):
         # train
         self.model.train()
-        epochs = 200
+        epochs = 1000
         for epoch in range(epochs):
             meter = {k: AverageMeter(k, ':6.3f')
                      for k in ('loss', '-log(lkl)', '-log(prior)', '-log(hyper)', 'sigma', 'w.prec.')}
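
The meter keys suggest a MAP-style objective that sums a negative log-likelihood, a negative log-prior over the weights, and a negative log-hyper-prior over lambda and sigma. A minimal sketch of such a training step follows; the neg_log_* method names are hypothetical, since the commit does not show the loss code, and `model`, `loader` and `epochs` stand in for the objects built above:

# Sketch with hypothetical method names (neg_log_* are not shown in the commit).
import torch

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
for epoch in range(epochs):
    for xb, yb in loader:
        optimizer.zero_grad()
        nll = model.neg_log_likelihood(xb, yb)  # '-log(lkl)': data-fit term
        nlp = model.neg_log_prior()             # '-log(prior)': weight prior
        nlh = model.neg_log_hyperprior()        # '-log(hyper)': Gamma hyper-priors
        loss = nll + nlp + nlh                  # logged as 'loss'
        loss.backward()
        optimizer.step()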
...
@@ -144,7 +144,7 @@ def main():
     parser.add_argument('-o', '--offset', type=int, metavar='INT', default=0, help='Train ID offset')
     parser.add_argument('-c', '--xgm_cut', type=float, metavar='INTENSITY', default=500, help='XGM intensity threshold in uJ.')
     parser.add_argument('-e', '--bnn', action="store_true", default=False, help='Use BNN?')
-    parser.add_argument('-w', '--weight', action="store_true", default=False, help='Whether to reweight data as a function of the pulse energy to make it invariant to that.')
+    parser.add_argument('-w', '--weight', action="store_true", default=True, help='Whether to reweight data as a function of the pulse energy to make it invariant to that.')
     args = parser.parse_args()
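
One side effect worth noting: with action="store_true", a default of True means the -w flag can no longer be turned off from the command line. A sketch (not from the commit) of one way to keep it switchable, using a paired flag:

# Sketch, not from the commit: pair the flag with an explicit off switch.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('-w', '--weight', dest='weight', action='store_true', default=True,
                    help='Reweight data as a function of the pulse energy.')
parser.add_argument('--no-weight', dest='weight', action='store_false',
                    help='Disable the pulse-energy reweighting.')
print(parser.parse_args([]).weight)               # True (default)
print(parser.parse_args(['--no-weight']).weight)  # False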
...