diff --git a/pes_to_spec/model.py b/pes_to_spec/model.py
index 214ce31125dad632724c6ef8e71ae1d48247b78e..10df31f1c32ca25c43316ef22478a09171c50d9f 100644
--- a/pes_to_spec/model.py
+++ b/pes_to_spec/model.py
@@ -8,8 +8,8 @@ from sklearn.decomposition import IncrementalPCA, PCA
 from sklearn.base import TransformerMixin, BaseEstimator
 from sklearn.base import OutlierMixin
 from sklearn.pipeline import Pipeline
-from sklearn.kernel_approximation import Nystroem
 from sklearn.linear_model import BayesianRidge
+from sklearn.linear_model import ARDRegression
 from sklearn.metrics import accuracy_score
 from scipy.stats import gaussian_kde
 from itertools import product
@@ -20,7 +20,7 @@ from copy import deepcopy
 
 from pes_to_spec.bnn import BNNModel
 
-from typing import Any, Dict, List, Optional, Union, Tuple
+from typing import Any, Dict, List, Optional, Union, Tuple, Literal
 
 def matching_ids(a: np.ndarray, b: np.ndarray, c: np.ndarray) -> np.ndarray:
     """Returns list of train IDs common to sets a, b and c."""
@@ -354,9 +354,12 @@ class SelectRelevantLowResolution(TransformerMixin, BaseEstimator):
         plt.savefig(filename)
         plt.close(fig)
 
-def _fit_estimator(estimator, X: np.ndarray, y: np.ndarray, w: np.ndarray):
+def _fit_estimator(estimator, X: np.ndarray, y: np.ndarray, w: Optional[np.ndarray]=None):
     estimator = clone(estimator)
-    estimator.fit(X, y, w)
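+    # sample weights are optional: some estimators (e.g. ARDRegression) do not accept them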
+    if w is None:
+        estimator.fit(X, y)
+    else:
+        estimator.fit(X, y, w)
     return estimator
 
 class MultiOutputWithStd(MetaEstimatorMixin, BaseEstimator):
@@ -382,9 +385,6 @@ class MultiOutputWithStd(MetaEstimatorMixin, BaseEstimator):
                 "y must have at least two dimensions for "
                 "multi-output regression but has only one."
             )
-        if weights is None:
-            weights = np.ones(y.shape[0])
-
         self.estimators_ = Parallel(n_jobs=self.n_jobs)(
             delayed(_fit_estimator)(
                 self.estimator, X, y[:, i], weights
@@ -517,7 +517,7 @@ class Model(TransformerMixin, BaseEstimator):
                  Set to None to perform no selection.
       validation_size: Fraction (number between 0 and 1) of the data to take for
                        validation and systematic uncertainty estimate.
-      bnn: Use BNN?
+      model_type: Which regression model to use: "bnn" for a Bayesian neural network, "ridge" for Bayesian ridge regression, or "ard" for ARD regression.
 
     """
     def __init__(self,
@@ -529,11 +529,11 @@ class Model(TransformerMixin, BaseEstimator):
                  tof_start: Optional[int]=None,
                  delta_tof: Optional[int]=300,
                  validation_size: float=0.05,
-                 bnn: bool=True,
+                 model_type: Literal["bnn", "ridge", "ard"]="ard",
                 ):
         self.high_res_sigma = high_res_sigma
         # models
-        self.x_select = SelectRelevantLowResolution(channels, tof_start, delta_tof, poly=not bnn)
+        self.x_select = SelectRelevantLowResolution(channels, tof_start, delta_tof, poly=(model_type != "bnn"))
         x_model_steps = list()
         x_model_steps += [
                           ('pca', PCA(n_pca_lr, whiten=True)),
@@ -547,11 +547,13 @@ class Model(TransformerMixin, BaseEstimator):
                                 ])
         self.ood = {ch: UncorrelatedDeviation(sigma=5)
                     for ch in channels+['full']}
-        if bnn:
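+        # choose the regression backend: a Bayesian neural network, or a scikit-learn linear model fit per output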
+        if model_type == "bnn":
             self.fit_model = BNNModel()
-        else:
-            self.fit_model = MultiOutputWithStd(BayesianRidge(n_iter=300, tol=1e-8, verbose=True), n_jobs=8)
-        self.bnn = bnn
+        elif model_type == "ridge":
+            self.fit_model = MultiOutputWithStd(BayesianRidge(n_iter=300, verbose=True), n_jobs=8)
+        elif model_type == "ard":
+            self.fit_model = MultiOutputWithStd(ARDRegression(n_iter=300, verbose=True), n_jobs=8)
+        else:
+            raise ValueError(f"Unknown model_type: {model_type}")
+        self.model_type = model_type
 
         self.kde_xgm = None
         self.mu_xgm = np.nan
@@ -640,8 +642,6 @@ class Model(TransformerMixin, BaseEstimator):
 
         Returns: Smoothened high resolution spectrum.
         """
-        if weights is None:
-            weights = np.ones(high_res_data.shape[0])
         print("Fitting PCA on low-resolution data.")
         self.x_select.fit(low_res_data)
         low_res_select = self.x_select.transform(low_res_data, pulse_energy=pulse_energy)
@@ -659,7 +659,9 @@ class Model(TransformerMixin, BaseEstimator):
         self.ood['full'].fit(x_t)
         inliers = self.ood['full'].predict(x_t) > 0.0
         print("Fitting model.")
-        self.fit_model.fit(x_t[inliers], y_t[inliers], weights[inliers])
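+        # restrict the optional sample weights to the same inlier subset used for fitting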
+        if weights is not None:
+            weights = weights[inliers]
+        self.fit_model.fit(x_t[inliers], y_t[inliers], weights)
 
         # calculate the effect of the PCA
         print("Calculate PCA unc. on high-resolution data.")
@@ -913,7 +915,7 @@ class Model(TransformerMixin, BaseEstimator):
         joblib.dump([self.x_select,
                      self.x_model,
                      self.y_model,
-                     self.fit_model.state_dict() if self.bnn else self.fit_model,
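+                     # BNN weights are stored as a state dict; scikit-learn models are pickled directly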
+                     self.fit_model.state_dict() if self.model_type == "bnn" else self.fit_model,
                      self.channel_pca,
                      #self.channel_fit_model
                      DataHolder(dict(
@@ -926,7 +928,7 @@ class Model(TransformerMixin, BaseEstimator):
                                      resolution=self.resolution,
                                      transfer_function=self.transfer_function,
                                      impulse_response=self.impulse_response,
-                                     bnn=self.bnn,
+                                     model_type=self.model_type,
                                     )
                                ),
                      self.ood,
@@ -963,12 +965,12 @@ class Model(TransformerMixin, BaseEstimator):
         obj.resolution = extra["resolution"]
         obj.transfer_function = extra["transfer_function"]
         obj.impulse_response = extra["impulse_response"]
-        obj.bnn = extra["bnn"]
+        obj.model_type = extra["model_type"]
 
         obj.x_select = x_select
         obj.x_model = x_model
         obj.y_model = y_model
-        if obj.bnn:
+        if obj.model_type == "bnn":
             obj.fit_model = BNNModel(state_dict=fit_model)
         else:
             obj.fit_model = fit_model
diff --git a/pes_to_spec/test/offline_analysis.py b/pes_to_spec/test/offline_analysis.py
index 5ede9532259921dd4033e2f67dfc5492073a7ac7..20d67494114742e6ef083007287cc88e4365b713 100755
--- a/pes_to_spec/test/offline_analysis.py
+++ b/pes_to_spec/test/offline_analysis.py
@@ -142,8 +142,8 @@ def main():
     parser.add_argument('-P', '--pes', type=str, metavar='NAME', default="SA3_XTD10_PES/ADC/1:network", help='PES name')
     parser.add_argument('-X', '--xgm', type=str, metavar='NAME', default="SA3_XTD10_XGM/XGM/DOOCS:output", help='XGM name')
     parser.add_argument('-o', '--offset', type=int, metavar='INT', default=0, help='Train ID offset')
-    parser.add_argument('-c', '--xgm_cut', type=float, metavar='INTENSITY', default=500, help='XGM intensity threshold in uJ.')
-    parser.add_argument('-e', '--bnn', action="store_true", default=False, help='Use BNN?')
+    parser.add_argument('-c', '--xgm_cut', type=float, metavar='INTENSITY', default=0, help='XGM intensity threshold in uJ.')
+    parser.add_argument('-T', '--model-type', type=str, metavar='TYPE', default="ard", choices=["bnn", "ridge", "ard"], help='Which model type to use.')
     parser.add_argument('-w', '--weight', action="store_true", default=False, help='Whether to reweight data as a function of the pulse energy to make it invariant to that.')
 
     args = parser.parse_args()
@@ -207,6 +207,8 @@ def main():
         test_tids = matching_ids(spec_tidt, pes_tidt, xgm_tidt)
     else:
         test_tids = tids
+    print(f"Number of train IDs: {len(train_tids)}")
+    print(f"Number of test IDs: {len(test_tids)}")
 
     # read the PES data for each channel
     channels = [f"channel_{i}_{l}"
@@ -234,7 +236,7 @@ def main():
     t = list()
     t_names = list()
 
-    model = Model(bnn=args.bnn)
+    model = Model(model_type=args.model_type)
 
     train_idx = np.isin(tids, train_tids) & (xgm_flux[:,0] > args.xgm_cut)
     # we just need this for training and we need to avoid copying it, which blows up the memory usage