diff --git a/pes_to_spec/bnn.py b/pes_to_spec/bnn.py index 7c0b3411199f5337f3bf60f3d1ac40af3d8f2ac2..5f8832ee8448f975f21851ef514c934b581aae62 100644 --- a/pes_to_spec/bnn.py +++ b/pes_to_spec/bnn.py @@ -456,7 +456,7 @@ class BNNModel(RegressorMixin, BaseEstimator): Args: """ - def __init__(self, state_dict=None, rvm: bool=False): + def __init__(self, state_dict=None, rvm: bool=False, n_epochs: int=250): if state_dict is not None: Nx = state_dict["model.0.weight_mu"].shape[1] Ny = state_dict["model.2.weight_mu"].shape[0] @@ -465,6 +465,7 @@ class BNNModel(RegressorMixin, BaseEstimator): else: self.model = BNN(rvm=rvm) self.rvm = rvm + self.n_epochs = n_epochs self.model.eval() def state_dict(self) -> Dict[str, Any]: @@ -517,8 +518,7 @@ class BNNModel(RegressorMixin, BaseEstimator): # train self.model.train() - epochs = 250 - for epoch in range(epochs): + for epoch in range(self.n_epochs): meter = {k: AverageMeter(k, ':6.3f') for k in ('loss', '-log(lkl)', '-log(prior)', '-log(hyper)', 'sigma', 'w.prec.')} progress = ProgressMeter( diff --git a/pes_to_spec/model.py b/pes_to_spec/model.py index 1af3a3bf849cd06f2b47d64e5a8a99bc4d254abb..4d93ec82f883e304963e96c4c698b2cd100c2a1d 100644 --- a/pes_to_spec/model.py +++ b/pes_to_spec/model.py @@ -578,6 +578,7 @@ class Model(TransformerMixin, BaseEstimator): validation and systematic uncertainty estimate. model_type: Which model to use. "bnn" for a BNN, "bnn_rvm" for a BNN with RVM, "ridge" for Ridge and "ard" for ARD. n_peaks: Minimum numbr of peaks in the grating spectrometer. + n_bnn_epochs: Number of BNN epochs for training. """ def __init__(self, @@ -591,6 +592,7 @@ class Model(TransformerMixin, BaseEstimator): validation_size: float=0.05, model_type: Literal["bnn", "bnn_rvm", "ridge", "ard"]="ard", n_peaks: int=0, + n_bnn_epochs: int=500, ): self.high_res_sigma = high_res_sigma # models @@ -609,9 +611,9 @@ class Model(TransformerMixin, BaseEstimator): self.ood = {ch: UncorrelatedDeviation(sigma=5) for ch in channels+['full']} if model_type == "bnn": - self.fit_model = BNNModel() + self.fit_model = BNNModel(n_epochs=n_bnn_epochs) elif model_type == "bnn_rvm": - self.fit_model = BNNModel(rvm=True) + self.fit_model = BNNModel(n_epochs=n_bnn_epochs, rvm=True) elif model_type == "ridge": self.fit_model = MultiOutputRidgeWithStd(BayesianRidge(n_iter=300, tol=1e-8, verbose=True), n_jobs=8) elif model_type == "ard":