Skip to content
Snippets Groups Projects
Commit 1e23d454 authored by Danilo Ferreira de Lima's avatar Danilo Ferreira de Lima
Browse files

Update configuration to speed it up.

parent 8b150174
No related branches found
No related tags found
1 merge request!12BNN optimization.
This commit is part of merge request !12. Comments created here will be created in the context of that merge request.
...@@ -200,10 +200,10 @@ class BNNModel(RegressorMixin, BaseEstimator): ...@@ -200,10 +200,10 @@ class BNNModel(RegressorMixin, BaseEstimator):
self.model = BNN(X.shape[1], y.shape[1]) self.model = BNN(X.shape[1], y.shape[1])
# prepare data loader # prepare data loader
B = 100 B = 50
loader = DataLoader(ds, loader = DataLoader(ds,
batch_size=B, batch_size=B,
num_workers=32, num_workers=20,
shuffle=True, shuffle=True,
#pin_memory=True, #pin_memory=True,
drop_last=True, drop_last=True,
...@@ -222,7 +222,7 @@ class BNNModel(RegressorMixin, BaseEstimator): ...@@ -222,7 +222,7 @@ class BNNModel(RegressorMixin, BaseEstimator):
# train # train
self.model.train() self.model.train()
epochs = 1000 epochs = 500
for epoch in range(epochs): for epoch in range(epochs):
meter = {k: AverageMeter(k, ':6.3f') meter = {k: AverageMeter(k, ':6.3f')
for k in ('loss', '-log(lkl)', '-log(prior)', '-log(hyper)', 'sigma', 'w.prec.')} for k in ('loss', '-log(lkl)', '-log(prior)', '-log(hyper)', 'sigma', 'w.prec.')}
......
...@@ -144,7 +144,7 @@ def main(): ...@@ -144,7 +144,7 @@ def main():
parser.add_argument('-o', '--offset', type=int, metavar='INT', default=0, help='Train ID offset') parser.add_argument('-o', '--offset', type=int, metavar='INT', default=0, help='Train ID offset')
parser.add_argument('-c', '--xgm_cut', type=float, metavar='INTENSITY', default=500, help='XGM intensity threshold in uJ.') parser.add_argument('-c', '--xgm_cut', type=float, metavar='INTENSITY', default=500, help='XGM intensity threshold in uJ.')
parser.add_argument('-e', '--bnn', action="store_true", default=False, help='Use BNN?') parser.add_argument('-e', '--bnn', action="store_true", default=False, help='Use BNN?')
parser.add_argument('-w', '--weight', action="store_true", default=True, help='Whether to reweight data as a function of the pulse energy to make it invariant to that.') parser.add_argument('-w', '--weight', action="store_true", default=False, help='Whether to reweight data as a function of the pulse energy to make it invariant to that.')
args = parser.parse_args() args = parser.parse_args()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment