diff --git a/src/toolbox_scs/detectors/hrixs.py b/src/toolbox_scs/detectors/hrixs.py index b8ec6e59662522c12d72b4b27ce63ebf3eda5b4b..94daaa2df153e6488ed7fc35f86921764e132798 100644 --- a/src/toolbox_scs/detectors/hrixs.py +++ b/src/toolbox_scs/detectors/hrixs.py @@ -577,13 +577,28 @@ class hRIXS: return None return agg(da, dim=dim) - def aggregate(self, ds, dim="trainId"): + def aggregate(self, ds, var=None, dim="trainId"): """aggregate (i.e. mostly sum) all data within one dataset take all images in a dataset and aggregate them and their metadata. For images, spectra and normalizations that means adding them, for others (e.g. delays) adding would not make sense, so we treat them - properly. + properly. The aggregation functions of each variable are defined + in the aggregators attribute of the class. + If var is specified, group the dataset by var prior to aggregation. + A new variable "counts" gives the number of frames aggregated in + each group. + + Parameters + ---------- + ds: xarray Dataset + the dataset containing RIXS data + var: string + One of the variables in the dataset. If var is specified, the + dataset is grouped by var prior to aggregation. This is useful + for sorting e.g. a dataset that contains multiple delays. + dim: string + the dimension over which to aggregate the data Example ------- @@ -592,11 +607,16 @@ class hRIXS: agg = h.aggregate(data) # sum all spectra agg.spectrum.plot() # plot the resulting spectrum - groups = data.groupby('hRIXS_index') # group data by a variable - agg = groups.map(h.aggregate) # sum corresponding spectra - agg.spectrum[0, :].plot() # plot the spectrum for first value + agg2 = h.aggregate(data, 'hRIXS_delay') # group data by delay + agg2.spectrum[0, :].plot() # plot the spectrum for first value """ - ds['counts'] = xr.ones_like(ds[dim]) + ds["counts"] = xr.ones_like(ds[dim]) + if var is not None: + groups = ds.groupby(var) + return groups.map(self.aggregate_ds, dim=dim) + return self.aggregate_ds(ds, dim) + + def aggregate_ds(self, ds, dim='trainId'): ret = ds.map(self.aggregator, dim=dim) ret = ret.drop_vars([n for n in ret if n not in self.aggregators]) return ret