From f6379fc6a91b8764e38b3fba0b4a45e9f143ef62 Mon Sep 17 00:00:00 2001
From: Danilo Ferreira de Lima <danilo.enoque.ferreira.de.lima@xfel.de>
Date: Mon, 9 Oct 2023 10:25:41 +0200
Subject: [PATCH] Made PyTorch optional, so it runs even without it.

---
 .gitlab-ci.yml           | 31 +++++++++++++++++++++++++++++++
 pes_to_spec/bnn.py       | 18 +++++++++++++-----
 pes_to_spec/exception.py | 13 +++++++++++++
 pes_to_spec/model.py     | 17 +++++++++++++++--
 4 files changed, 72 insertions(+), 7 deletions(-)
 create mode 100644 .gitlab-ci.yml
 create mode 100644 pes_to_spec/exception.py

diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
new file mode 100644
index 0000000..dd1851a
--- /dev/null
+++ b/.gitlab-ci.yml
@@ -0,0 +1,31 @@
+stages:
+  - environment
+  - test
+  - deploy
+
+cache:
+  paths:
+    - .cache/pip
+    - .venv
+
+setup-environment:
+  stage: environment
+  image: python3.9-slim-buster
+  script:
+    - python3 -m venv .venv
+    - source .venv/bin/activate
+    - python3 -m pip install --upgrade pip
+    - python3 -m pip install --force-reinstall --index-url https://pypi.anaconda.org/intel/simple --no-dependencies numpy scipy==1.7.3
+    - python3 -m pip install numpy scipy==1.7.3
+    #- python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu
+    - python3 -m pip install joblib scikit-learn
+    - python3 -m pip install matplotlib lmfit seaborn extra_data
+
+test_ard:
+  stage: test
+  image: python3.9-slim-buster
+  script:
+    - python3 -m venv .venv
+    - source .venv/bin/activate
+    - ./pes_to_spec/test/offline_analysis.py -p 900331 -r 69 -t 70 -d results_ard --model-type ard
+
diff --git a/pes_to_spec/bnn.py b/pes_to_spec/bnn.py
index 5f8832e..242a507 100644
--- a/pes_to_spec/bnn.py
+++ b/pes_to_spec/bnn.py
@@ -1,14 +1,22 @@
+"""
+BNN implementation.
+"""
+
 from sklearn.base import BaseEstimator, RegressorMixin
 from typing import Any, Dict, Optional, Union, Tuple
 
 import numpy as np
-import math
 from scipy.special import gamma
 
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from torch.utils.data import TensorDataset, DataLoader
+from pes_to_spec.exception import MethodNotAvailableException
+
+try:
+    import torch
+    import torch.nn as nn
+    import torch.nn.functional as F
+    from torch.utils.data import TensorDataset, DataLoader
+except ImportError:
+    raise MethodNotAvailableException("PyTorch not available. BNN deactivated.")
 
 class BayesLinearEmpiricalPrior(nn.Module):
     """
diff --git a/pes_to_spec/exception.py b/pes_to_spec/exception.py
new file mode 100644
index 0000000..75a169d
--- /dev/null
+++ b/pes_to_spec/exception.py
@@ -0,0 +1,13 @@
+"""
+Module containing package-specific exceptions.
+"""
+
+class MethodNotAvailableException(Exception):
+    """
+    Flags that one of the methods is not available.
+    """
+    def __init__(self, msg: str="Method not available."):
+        self.msg = msg
+    def __str__(self) -> str:
+        return self.msg
+
diff --git a/pes_to_spec/model.py b/pes_to_spec/model.py
index 7f2fe33..99734a3 100644
--- a/pes_to_spec/model.py
+++ b/pes_to_spec/model.py
@@ -1,5 +1,7 @@
 from __future__ import annotations
 
+from typing import Any, Dict, List, Optional, Union, Tuple, Literal
+
 import joblib
 
 import numpy as np
@@ -19,9 +21,15 @@ from sklearn.base import clone, MetaEstimatorMixin
 from joblib import Parallel, delayed
 from copy import deepcopy
 
-from pes_to_spec.bnn import BNNModel
+from pes_to_spec.exception import MethodNotAvailableException
+
+is_bnn_available = False
+try:
+    from pes_to_spec.bnn import BNNModel
+    is_bnn_available = True
+except MethodNotAvailableException:
+    print("Warning: BNN model disabled. It requires PyTorch. Check if it is installed.")
 
-from typing import Any, Dict, List, Optional, Union, Tuple, Literal
 
 def matching_ids(a: np.ndarray, b: np.ndarray, c: np.ndarray) -> np.ndarray:
     """Returns list of train IDs common to sets a, b and c."""
@@ -629,6 +637,8 @@ class Model(TransformerMixin, BaseEstimator):
                  n_peaks: int=0,
                  n_bnn_epochs: int=500,
                 ):
+        if model_type in ["bnn", "bnn_rvm"] and not is_bnn_available:
+            raise MethodNotAvailableException("The BNN model requires a PyTorch installation. Please do `pip install torch` or `conda install pytorch` to be able to use the BNN model.")
         self.high_res_sigma = high_res_sigma
         # models
         self.x_select = SelectRelevantLowResolution(channels, tof_start, delta_tof, poly=False) #(model_type not in ["bnn", "bnn_rvm"]))
@@ -1090,6 +1100,9 @@ class Model(TransformerMixin, BaseEstimator):
         obj.x_select = x_select
         obj.x_model = x_model
         obj.y_model = y_model
+        if obj.model_type in ["bnn", "bnn_rvm"] and not is_bnn_available:
+            raise MethodNotAvailableException("Attempted to load a BNN model, but it requires a PyTorch installation. "
+                                              "Please do `pip install torch` or `conda install pytorch` to be able to load this model.")
         if obj.model_type == "bnn":
             obj.fit_model = BNNModel(state_dict=fit_model, rvm=False)
         elif obj.model_type == "bnn_rvm":
-- 
GitLab