From 9172b544a32380ea48a06b00f420d17dd5f576e1 Mon Sep 17 00:00:00 2001
From: Laurent Mercadier <laurent.mercadier@xfel.eu>
Date: Fri, 8 Sep 2023 21:01:40 +0200
Subject: [PATCH] Add counts variable when aggregating data

---
 src/toolbox_scs/detectors/hrixs.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/src/toolbox_scs/detectors/hrixs.py b/src/toolbox_scs/detectors/hrixs.py
index 698da4a..87ecd84 100644
--- a/src/toolbox_scs/detectors/hrixs.py
+++ b/src/toolbox_scs/detectors/hrixs.py
@@ -1,4 +1,5 @@
 from functools import lru_cache
+import xarray as xr
 
 import numpy as np
 import matplotlib.pyplot as plt
@@ -566,7 +567,8 @@ class hRIXS:
         spectrum=lambda x, dim: x.sum(dim=dim),
         dbl_spectrum=lambda x, dim: x.sum(dim=dim),
         total_hits=lambda x, dim: x.sum(dim=dim),
-        dbl_hits=lambda x, dim: x.sum(dim=dim)
+        dbl_hits=lambda x, dim: x.sum(dim=dim),
+        counts=lambda x, dim: x.sum(dim=dim)
     )
 
     def aggregator(self, da, dim):
@@ -594,6 +596,7 @@ class hRIXS:
             agg = groups.map(h.aggregate)  # sum corresponding spectra
             agg.spectrum[0, :].plot()  # plot the spectrum for first value
         """
+        ds['counts'] = xr.ones_like(ds["trainId"])
         ret = ds.map(self.aggregator, dim=dim)
         ret = ret.drop_vars([n for n in ret if n not in self.aggregators])
         return ret
-- 
GitLab