From 3db94efae39c2028b172cdf39cb5102866f9b44b Mon Sep 17 00:00:00 2001
From: Danilo Ferreira de Lima <danilo.enoque.ferreira.de.lima@xfel.de>
Date: Fri, 3 Mar 2023 13:37:26 +0100
Subject: [PATCH] Corrected chi^2 definition.

---
 pes_to_spec/__init__.py |  2 +-
 pes_to_spec/model.py    | 34 ++++++++++++++++++++--------------
 2 files changed, 21 insertions(+), 15 deletions(-)

diff --git a/pes_to_spec/__init__.py b/pes_to_spec/__init__.py
index 0494dc2..07d6cce 100644
--- a/pes_to_spec/__init__.py
+++ b/pes_to_spec/__init__.py
@@ -2,4 +2,4 @@
 Estimate high-resolution photon spectrometer data from low-resolution non-invasive measurements.
 """
 
-VERSION = "0.2.0"
+VERSION = "0.2.1"
diff --git a/pes_to_spec/model.py b/pes_to_spec/model.py
index 38b9e36..8b47e61 100644
--- a/pes_to_spec/model.py
+++ b/pes_to_spec/model.py
@@ -651,11 +651,17 @@ class MultiOutputWithStd(MetaEstimatorMixin, BaseEstimator):
 
 class UncorrelatedDeviation(OutlierMixin, BaseEstimator):
     """
-    Detect outliers from uncorrelated whitened mean-centered inputs.
+    Detect outliers from uncorrelated inputs.
+    Each sample is reduced to a single chi^2 statistic summed over its features.
+    The median and standard deviation of each feature are estimated using quantiles.
+
+    Args:
+      sigma: Number of standard deviations of chi^2/ndof above its expected value beyond which a sample is an outlier.
+
     """
-    def __init__(self, contamination: float=0.003):
+    def __init__(self, sigma: float=5.0):
         super().__init__()
-        self.contamination = contamination
+        self.sigma = sigma
 
     def fit(self, X, y=None) -> OutlierMixin:
         """
@@ -667,14 +673,18 @@ class UncorrelatedDeviation(OutlierMixin, BaseEstimator):
 
         Returns: Itself.
         """
-        self.bounds_ = np.quantile(X, (self.contamination/2.0, 0.5, 1.0 - self.contamination/2.0), axis=0)
+        bounds_ = np.quantile(X, (0.003/2.0, 0.5, 1.0 - 0.003/2.0), axis=0)
+        self.ndof_ = float(X.shape[1] - 1.0)
+        self.med_ = bounds_[1, np.newaxis, ...]
+        self.sigma_ = (bounds_[2, np.newaxis, ...] - bounds_[0, np.newaxis, ...])/3.0
         return self
 
     def decision_function(self, X: np.ndarray) -> np.ndarray:
         """
         Return the decision function.
+        This is chi^2/ndof - 1 - sigma*sqrt(2/ndof); positive values indicate outliers.
         """
-        return self.score_samples(X) - 1.0
+        return (self.score_samples(X) - 1.0 - self.sigma*np.sqrt(2.0/self.ndof_))
 
     def score_samples(self, X: np.ndarray) -> np.ndarray:
         """
@@ -683,11 +693,9 @@ class UncorrelatedDeviation(OutlierMixin, BaseEstimator):
         Args:
           X: The new input data.
 
-        Returns: The Mahalanobis distance.
+        Returns: The chi^2 test statistic.
         """
-        med = self.bounds_[1, np.newaxis, ...]
-        sigma = (self.bounds_[2, np.newaxis, ...] - self.bounds_[0, np.newaxis, ...])/0.5
-        return np.fabs((X - med)/sigma)
+        return np.sum(((X - self.med_)/self.sigma_)**2, axis=1)/float(self.ndof_)
 
     def predict(self, X):
         """
@@ -702,10 +710,8 @@ class UncorrelatedDeviation(OutlierMixin, BaseEstimator):
             Returns -1 for anomalies/outliers and +1 for inliers.
         """
         values = self.decision_function(X)
-        is_lower = np.any(X < self.bounds_[0, np.newaxis, ...], axis=1)
-        is_upper = np.any(X > self.bounds_[2, np.newaxis, ...], axis=1)
-        is_inlier = np.full(values.shape[0], -1, dtype=int)
-        is_inlier[is_lower | is_upper] = 1
+        is_inlier = np.full(values.shape[0], 1, dtype=int)
+        is_inlier[values > 0] = -1
         return is_inlier
 
     def score(self, X, y, sample_weight=None):
@@ -784,7 +790,7 @@ class Model(TransformerMixin, BaseEstimator):
                                 ])
         #self.ood = {ch: IsolationForest(n_jobs=-1)
         #            for ch in channels+['full']}
-        self.ood = {ch: UncorrelatedDeviation(contamination=0.003)
+        self.ood = {ch: UncorrelatedDeviation(sigma=5)
                     for ch in channels+['full']}
         #self.ood = {ch: EllipticEnvelope(contamination=0.003)
         #            for ch in channels+['full']}
-- 
GitLab