diff --git a/pes_to_spec/bnn.py b/pes_to_spec/bnn.py
index 7c0b3411199f5337f3bf60f3d1ac40af3d8f2ac2..5f8832ee8448f975f21851ef514c934b581aae62 100644
--- a/pes_to_spec/bnn.py
+++ b/pes_to_spec/bnn.py
@@ -456,7 +456,7 @@ class BNNModel(RegressorMixin, BaseEstimator):
 
     Args:
     """
-    def __init__(self, state_dict=None, rvm: bool=False):
+    def __init__(self, state_dict=None, rvm: bool=False, n_epochs: int=250):
         if state_dict is not None:
             Nx = state_dict["model.0.weight_mu"].shape[1]
             Ny = state_dict["model.2.weight_mu"].shape[0]
@@ -465,6 +465,7 @@ class BNNModel(RegressorMixin, BaseEstimator):
         else:
             self.model = BNN(rvm=rvm)
         self.rvm = rvm
+        self.n_epochs = n_epochs
         self.model.eval()
 
     def state_dict(self) -> Dict[str, Any]:
@@ -517,8 +518,7 @@ class BNNModel(RegressorMixin, BaseEstimator):
 
         # train
         self.model.train()
-        epochs = 250
-        for epoch in range(epochs):
+        for epoch in range(self.n_epochs):
             meter = {k: AverageMeter(k, ':6.3f')
                     for k in ('loss', '-log(lkl)', '-log(prior)', '-log(hyper)', 'sigma', 'w.prec.')}
             progress = ProgressMeter(
diff --git a/pes_to_spec/model.py b/pes_to_spec/model.py
index 1af3a3bf849cd06f2b47d64e5a8a99bc4d254abb..4d93ec82f883e304963e96c4c698b2cd100c2a1d 100644
--- a/pes_to_spec/model.py
+++ b/pes_to_spec/model.py
@@ -578,6 +578,7 @@ class Model(TransformerMixin, BaseEstimator):
                        validation and systematic uncertainty estimate.
       model_type: Which model to use. "bnn" for a BNN, "bnn_rvm" for a BNN with RVM, "ridge" for Ridge and "ard" for ARD.
       n_peaks: Minimum numbr of peaks in the grating spectrometer.
+      n_bnn_epochs: Number of BNN epochs for training.
 
     """
     def __init__(self,
@@ -591,6 +592,7 @@ class Model(TransformerMixin, BaseEstimator):
                  validation_size: float=0.05,
                  model_type: Literal["bnn", "bnn_rvm", "ridge", "ard"]="ard",
                  n_peaks: int=0,
+                 n_bnn_epochs: int=500,
                 ):
         self.high_res_sigma = high_res_sigma
         # models
@@ -609,9 +611,9 @@ class Model(TransformerMixin, BaseEstimator):
         self.ood = {ch: UncorrelatedDeviation(sigma=5)
                     for ch in channels+['full']}
         if model_type == "bnn":
-            self.fit_model = BNNModel()
+            self.fit_model = BNNModel(n_epochs=n_bnn_epochs)
         elif model_type == "bnn_rvm":
-            self.fit_model = BNNModel(rvm=True)
+            self.fit_model = BNNModel(n_epochs=n_bnn_epochs, rvm=True)
         elif model_type == "ridge":
             self.fit_model = MultiOutputRidgeWithStd(BayesianRidge(n_iter=300, tol=1e-8, verbose=True), n_jobs=8)
         elif model_type == "ard":