Skip to content
Snippets Groups Projects
Commit 60c5e70c authored by Loïc Le Guyader's avatar Loïc Le Guyader
Browse files

cleanup and minimize code

parent c9f25f29
No related branches found
No related tags found
1 merge request!280WIP: First RIXS with JUNGFRAU detector implementation
from functools import lru_cache
import xarray as xr import xarray as xr
import numpy as np import numpy as np
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from scipy.optimize import leastsq from scipy.optimize import leastsq
from scipy.optimize import curve_fit
from scipy.signal import fftconvolve
import toolbox_scs as tb import toolbox_scs as tb
...@@ -14,111 +11,13 @@ __all__ = [ ...@@ -14,111 +11,13 @@ __all__ = [
'JF_hRIXS', 'JF_hRIXS',
] ]
# -----------------------------------------------------------------------------
# Curvature
def correct_curvature(image, factor=None, axis=1):
    """Straighten curved lines in a 2D image by re-binning shifted pixels.

    Every pixel centre is displaced along the x direction by
    ``factor[0] * y + factor[1] * y ** 2`` and its intensity is histogrammed
    back onto the original pixel grid. Intensity displaced outside the grid
    is not counted.

    Parameters
    ----------
    image: 2D array
        the image to correct
    factor: pair of float, optional
        linear and quadratic curvature coefficients. If None, no correction
        is applied and `image` is returned unchanged.
    axis: int
        if 1, the image is transposed before the correction is applied.

    Returns
    -------
    2D array with the same shape as `image`.
    """
    if factor is None:
        # Nothing to correct: hand back the input instead of None so that
        # callers can use the result unconditionally.
        return image

    if axis == 1:
        image = image.T

    ydim, xdim = image.shape
    x = np.arange(xdim + 1)
    y = np.arange(ydim + 1)
    # Pixel centres sit half a pixel after each lower bin edge.
    xx, yy = np.meshgrid(x[:-1] + 0.5, y[:-1] + 0.5)
    xxn = xx - factor[0] * yy - factor[1] * yy ** 2

    # histogramdd returns an (x, y)-indexed array, i.e. the transpose of
    # the (possibly transposed) working image.
    ret = np.histogramdd((xxn.flatten(), yy.flatten()),
                         bins=[x, y],
                         weights=image.flatten())[0]

    return ret if axis == 1 else ret.T
def get_spectrum(image, factor=None, axis=0,
                 pixel_range=None, energy_range=None, ):
    """Project an image onto one axis and return the resulting profile.

    Parameters
    ----------
    image: 2D array
        the image to sum
    factor: polynomial coefficients, optional
        if given, the pixel centres and the profile are passed through
        `calibrate` together with `energy_range`.
    axis: int
        the axis along which the image is summed
    pixel_range: pair, optional
        (first, last) pixel of the profile to keep; a falsy entry keeps
        the corresponding image boundary.
    energy_range: pair, optional
        forwarded to `calibrate` as `range_`

    Returns
    -------
    (centers, edge): the pixel centres (or calibrated values) and the
    summed intensity.
    """
    lower, upper = 0, image.shape[axis - 1]
    if pixel_range is not None:
        # Clip the requested window to the image extent.
        lower = max(pixel_range[0] or lower, lower)
        upper = min(pixel_range[1] or upper, upper)

    profile = image.sum(axis=axis)
    edge = profile[lower:upper]

    edges = np.arange(lower, upper + 1)
    centers = 0.5 * (edges[:-1] + edges[1:])

    if factor is None:
        return centers, edge

    centers, edge = calibrate(centers, edge,
                              factor=factor,
                              range_=energy_range)
    return centers, edge
# -----------------------------------------------------------------------------
# Energy calibration
def energy_calibration(channels, energies):
    """Fit a straight line through (channel, energy) pairs.

    Returns the degree-1 `np.polyfit` coefficients, highest power first,
    i.e. (slope, intercept).
    """
    linear_coefficients = np.polyfit(channels, energies, deg=1)
    return linear_coefficients
def calibrate(x, y=None, factor=None, range_=None):
    """Apply a polynomial calibration to `x` and optionally crop to a range.

    Parameters
    ----------
    x: array
        the axis values (e.g. pixel centres)
    y: array, optional
        data belonging to `x`; cropped together with `x`
    factor: polynomial coefficients, optional
        if given, `x` is replaced by ``np.polyval(factor, x)``
    range_: pair, optional
        (first, last) calibrated values delimiting the region to keep.
        Only used when `y` is given.

    Returns
    -------
    (x, y): the calibrated (and possibly cropped) axis and data.
    """
    if factor is not None:
        x = np.polyval(factor, x)

    if y is not None and range_ is not None:
        start = np.argmin(np.abs(x - range_[0]))
        stop = np.argmin(np.abs(x - range_[1]))
        # The calibrated axis may run in either direction; order the two
        # indices so the slice is never empty just because the axis is
        # ascending instead of descending.
        first, last = min(start, stop), max(start, stop)
        x, y = x[first:last], y[first:last]

    return x, y
# -----------------------------------------------------------------------------
# Gaussian-related functions
FWHM_COEFF = 2 * np.sqrt(2 * np.log(2)) FWHM_COEFF = 2 * np.sqrt(2 * np.log(2))
def gaussian_fit(x_data, y_data, offset=0):
    """Fit `gauss1d` to the data, seeded from its weighted moments.

    The starting centre and width are the `y_data`-weighted mean and
    standard deviation of `x_data`; the starting height and baseline are
    the maximum and minimum of `y_data`.

    Parameters
    ----------
    x_data, y_data: arrays
        the abscissa and the values to fit
    offset: float
        added to the starting centre before fitting

    Returns
    -------
    The fitted parameters in `gauss1d` order (height, x0, sigma, offset),
    or None if the fit could not be performed.
    """
    try:
        x0 = np.average(x_data, weights=y_data)
        sx = np.sqrt(np.average((x_data - x0) ** 2, weights=y_data))
    except ZeroDivisionError as e:
        # np.average refuses weights that sum to zero (e.g. an empty
        # spectrum); treat it like any other failed fit.
        print(e)
        return None

    # Gaussian fit
    baseline = y_data.min()
    p_0 = (y_data.max(), x0 + offset, sx, baseline)

    try:
        p_f, _ = curve_fit(gauss1d, x_data, y_data, p_0, maxfev=10000)
    except (RuntimeError, TypeError) as e:
        print(e)
        return None
    return p_f
def gauss1d(x, height, x0, sigma, offset):
    """Evaluate a 1D Gaussian of given height, centre and sigma on top of
    a constant offset."""
    reduced = (x - x0) / sigma
    return height * np.exp(-0.5 * reduced ** 2) + offset
def to_fwhm(sigma):
    """Convert a Gaussian sigma to a full width at half maximum.

    The sign of `sigma` is ignored, so the result is never negative.
    """
    width = sigma * FWHM_COEFF
    return abs(width)
def decentroid(res):
    """Turn a list of (cy, cx) hit positions into a 2D hit-count image.

    Each hit increments the pixel obtained by truncating its coordinates
    to integers. Hits whose row or column coordinate is not strictly
    positive are ignored.

    Parameters
    ----------
    res: sequence of (cy, cx) pairs
        the hit positions

    Returns
    -------
    2D float array just large enough to hold the largest coordinates;
    an empty (0, 0) array if there are no hits.
    """
    res = np.array(res)
    if res.size == 0:
        # res.max() below is undefined without hits.
        return np.zeros((0, 0))

    ret = np.zeros(shape=(res.max(axis=0) + 1).astype(int))

    valid = (res[:, 0] > 0) & (res[:, 1] > 0)
    rows = res[valid, 0].astype(int)
    cols = res[valid, 1].astype(int)
    # np.add.at accumulates repeated indices, unlike ret[rows, cols] += 1.
    np.add.at(ret, (rows, cols), 1)
    return ret
class JF_hRIXS: class JF_hRIXS:
"""The JUNGFRAU hRIXS analysis, especially curvature correction """The JUNGFRAU hRIXS analysis, especially curvature correction
...@@ -144,7 +43,7 @@ class JF_hRIXS: ...@@ -144,7 +43,7 @@ class JF_hRIXS:
STD_THRESHOLD: STD_THRESHOLD:
same as THRESHOLD, in standard deviations. same as THRESHOLD, in standard deviations.
DBL_THRESHOLD: DBL_THRESHOLD:
threshold controlling whether a detected hit is considered to be a threshold controlling whether a detected hit is considered to be a
double hit. double hit.
BINS: int BINS: int
the number of bins used in centroiding the number of bins used in centroiding
...@@ -161,7 +60,7 @@ class JF_hRIXS: ...@@ -161,7 +60,7 @@ class JF_hRIXS:
Example Example
------- -------
proposal = 3145 proposal = 3145
h = hRIXS(proposal) h = hRIXS(proposal)
h.Y_RANGE = slice(700, 900) h.Y_RANGE = slice(700, 900)
...@@ -172,7 +71,7 @@ class JF_hRIXS: ...@@ -172,7 +71,7 @@ class JF_hRIXS:
""" """
def __init__(self, proposalNB): def __init__(self, proposalNB):
self.PROPOSAL=proposalNB self.PROPOSAL = proposalNB
# image range # image range
self.X_RANGE = np.s_[:] self.X_RANGE = np.s_[:]
...@@ -187,7 +86,8 @@ class JF_hRIXS: ...@@ -187,7 +86,8 @@ class JF_hRIXS:
self.ENERGY_INTERCEPT = 0 self.ENERGY_INTERCEPT = 0
self.ENERGY_SLOPE = 1 self.ENERGY_SLOPE = 1
self.FIELDS = ['hRIXS_det', 'hRIXS_index', 'hRIXS_delay', 'hRIXS_norm', 'nrj'] self.FIELDS = ['hRIXS_det', 'hRIXS_index', 'hRIXS_delay',
'hRIXS_norm', 'nrj']
def set_params(self, **params): def set_params(self, **params):
for key, value in params.items(): for key, value in params.items():
...@@ -200,14 +100,16 @@ class JF_hRIXS: ...@@ -200,14 +100,16 @@ class JF_hRIXS:
'bins', 'fields') 'bins', 'fields')
return {param: getattr(self, param.upper()) for param in params} return {param: getattr(self, param.upper()) for param in params}
def from_run(self, runNB, proposal=None, extra_fields=(), drop_first=False): def from_run(self, runNB, proposal=None, extra_fields=(),
"""load a run drop_first=False):
"""Load a run.
Load the run `runNB`. A thin wrapper around `toolbox.load`. Load the run `runNB`. A thin wrapper around `toolbox.load`.
Parameters Parameters
---------- ----------
drop_first: bool drop_first: bool
if True, the first image in the run is removed from the dataset. if True, the first image in the run is removed from the
dataset.
Example Example
------- -------
...@@ -220,55 +122,14 @@ class JF_hRIXS: ...@@ -220,55 +122,14 @@ class JF_hRIXS:
""" """
if proposal is None: if proposal is None:
proposal = self.PROPOSAL proposal = self.PROPOSAL
run, data = tb.load(proposal, runNB=runNB, _, data = tb.load(proposal, runNB=runNB,
fields=self.FIELDS + list(extra_fields)) fields=self.FIELDS + list(extra_fields))
if drop_first is True: if drop_first is True:
data = data.isel(trainId=slice(1, None)) data = data.isel(trainId=slice(1, None))
return data return data
def find_curvature(self, runNB, proposal=None, plot=True, args=None, def find_curvature(self, img, args, plot=False, **kwargs):
**kwargs): """Find the curvature correction coefficients.
"""find the curvature correction coefficients
The hRIXS has some aberrations which lead to the spectroscopic lines
being curved on the detector. We approximate these aberrations with
a parabola for later correction.
Load a run and determine the curvature. The curvature is set in `self`,
and returned as a pair of floats.
Parameters
----------
runNB: int
the run number to use
proposal: int
the proposal to use, default to the current proposal
plot: bool
whether to plot the found curvature onto the data
args: pair of float, optional
a starting value to prime the fitting routine
Example
-------
h.find_curvature(155) # use run 155 to fit the curvature
"""
data = self.from_run(runNB, proposal)
image = data['hRIXS_det'].sum(dim='trainId') \
.values[self.X_RANGE, self.Y_RANGE].T
if args is None:
spec = (image - image[:10, :].mean()).mean(axis=1)
mean = np.average(np.arange(len(spec)), weights=spec)
args = (-2e-7, 0.02, mean - 0.02 * image.shape[1] / 2, 3,
spec.max(), image.mean())
args = _find_curvature(image, args, plot=plot, **kwargs)
self.CURVE_B, self.CURVE_A, *_ = args
return self.CURVE_A, self.CURVE_B
def find_curvature(img, args, plot=False, **kwargs):
"""find the curvature correction coefficients
The hRIXS has some aberrations which lead to the spectroscopic lines The hRIXS has some aberrations which lead to the spectroscopic lines
being curved on the detector. We approximate these aberrations with being curved on the detector. We approximate these aberrations with
...@@ -279,7 +140,6 @@ class JF_hRIXS: ...@@ -279,7 +140,6 @@ class JF_hRIXS:
Parameters Parameters
---------- ----------
img: array img: array
2D average image 2D average image
args: (a, b, c, s, h, o) initial coefficients args: (a, b, c, s, h, o) initial coefficients
...@@ -287,51 +147,57 @@ class JF_hRIXS: ...@@ -287,51 +147,57 @@ class JF_hRIXS:
h the height and o an offset h the height and o an offset
plot: bool plot: bool
whether to plot the found curvature onto the data whether to plot the found curvature onto the data
Example Example
------- -------
h.find_curvature(155) # use run 155 to fit the curvature h.find_curvature(155) # use run 155 to fit the curvature
""" """
def parabola(x, a, b, c, s=0, h=0, o=0): def parabola(x, a, b, c, s=0, h=0, o=0):
return (a*x + b)*x + c return (a*x + b)*x + c
def gauss(y, x, a, b, c, s, h, o=0): def gauss(y, x, a, b, c, s, h, o=0):
return h * np.exp(-((y - parabola(x, a, b, c)) / (2 * s))**2) + o return h * np.exp(-((y - parabola(x, a, b, c)) / (2 * s))**2) + o
x = np.arange(img.shape[1])[None, :] x = np.arange(img.shape[1])[None, :]
y = np.arange(img.shape[0])[:, None] y = np.arange(img.shape[0])[:, None]
if plot: if plot:
plt.figure(figsize=(10,10)) plt.figure(figsize=(10, 10))
plt.imshow(img, cmap='gray', aspect='auto', interpolation='nearest', **kwargs) plt.imshow(img, cmap='gray', aspect='auto',
interpolation='nearest', **kwargs)
plt.plot(x[0, :], parabola(x[0, :], *args)) plt.plot(x[0, :], parabola(x[0, :], *args))
args, _ = leastsq(lambda args: (gauss(y, x, *args) - img).ravel(), args) args, _ = leastsq(lambda args: (gauss(y, x, *args) - img).ravel(),
args)
if plot: if plot:
plt.plot(x[0, :], parabola(x[0, :], *args)) plt.plot(x[0, :], parabola(x[0, :], *args))
return args return args
def parabola(self, x):
    """Evaluate the curvature polynomial CURVE_B * x**2 + CURVE_A * x.

    There is no constant term, so the result is zero at x == 0.
    """
    quadratic, linear = self.CURVE_B, self.CURVE_A
    return (quadratic * x + linear) * x
def spectrum(self, fname): def spectrum(self, fname):
"""Bin photon hit data into spectrum. """Bin photon hit data into spectrum.
Parameters Parameters
---------- ----------
fname: string fname: string
file name of the data to load. file name of the data to load.
""" """
data_interp = xr.load_dataset(fname) data_interp = xr.load_dataset(fname)
def hist_curv(x, y): def hist_curv(x, y):
H, _ = np.histogram( H, _ = np.histogram(
x - self.parabola(y), bins=self.BINS, x - self.parabola(y), bins=self.BINS,
range=(0, self.Y_RANGE.stop - self.Y_RANGE.start)) range=(0, self.Y_RANGE.stop - self.Y_RANGE.start))
return H return H
energy = (np.linspace(self.Y_RANGE.start, energy = (np.linspace(self.Y_RANGE.start,
self.Y_RANGE.stop, self.Y_RANGE.stop,
self.BINS) * self.ENERGY_SLOPE + self.ENERGY_INTERCEPT) self.BINS) * self.ENERGY_SLOPE
+ self.ENERGY_INTERCEPT)
spectrum = xr.apply_ufunc(hist_curv, spectrum = xr.apply_ufunc(hist_curv,
data_interp['y'], data_interp['y'],
...@@ -347,131 +213,3 @@ class JF_hRIXS: ...@@ -347,131 +213,3 @@ class JF_hRIXS:
spectrum['energy'] = energy spectrum['energy'] = energy
return spectrum return spectrum
def parabola(self, x):
    """Evaluate the curvature polynomial CURVE_B * x**2 + CURVE_A * x.

    There is no constant term, so the result is zero at x == 0.
    """
    quadratic, linear = self.CURVE_B, self.CURVE_A
    return (quadratic * x + linear) * x
def integrate(self, data):
    """Calculate a spectrum by integration.

    This takes the `xarray` `data`, adds an `energy` coordinate and a new
    dataarray named `spectrum` to it, and returns it. `data` is modified
    in place; no copy is made. `spectrum` contains the energy spectrum
    calculated for each hRIXS image.

    First the energy that corresponds to each pixel is calculated.
    Then all pixels within an energy range are summed, where the intensity
    of one pixel is distributed among the two energy ranges the pixel
    spans, proportionally to the overlap between the pixel and bin energy
    ranges.

    The resulting data is normalized to one pixel, so the average
    intensity that arrived on one pixel.

    A margin of 10 bins is discarded at each end of `Y_RANGE`.

    Example
    -------

    h.integrate(data)  # create spectrum by summing pixels
    data.spectrum[0, :].plot()  # plot the spectrum of the first image
    """
    bins = self.Y_RANGE.stop - self.Y_RANGE.start
    margin = 10
    nbins = bins - 2 * margin
    ret = np.zeros((len(data["hRIXS_det"]), nbins))
    if self.USE_DARK:
        dark_image = self.dark_image.values[self.X_RANGE, self.Y_RANGE]
    images = data["hRIXS_det"].values[:, self.X_RANGE, self.Y_RANGE]

    # Curvature-corrected position of every pixel, split into the integer
    # bin index and the fractional remainder.
    x, y = np.ogrid[:images.shape[1], :images.shape[2]]
    quo, rem = divmod(y - self.parabola(x), 1)
    # NOTE(review): bin `quo` receives weight `rem` and bin `quo + 1`
    # receives `1 - rem`, so a pixel with rem == 0 lands entirely in
    # `quo + 1` -- confirm this is the intended convention.
    quo = np.array([quo, quo + 1])
    rem = np.array([rem, 1 - rem])

    # Contributions falling into the margins get zero weight and are
    # parked in the first valid bin.
    wrong = (quo < margin) | (quo >= bins - margin)
    quo[wrong] = margin
    rem[wrong] = 0
    quo = (quo - margin).astype(int).ravel()

    for image, r in zip(images, ret):
        if self.USE_DARK:
            image = image - dark_image
        # minlength keeps the result as long as the output row even when
        # no pixel maps onto the highest bins.
        r[:] = np.bincount(quo, weights=(rem * image).ravel(),
                           minlength=nbins)
    ret /= np.bincount(quo, weights=rem.ravel(), minlength=nbins)

    data.coords["energy"] = (
        np.arange(self.Y_RANGE.start + margin, self.Y_RANGE.stop - margin)
        * self.ENERGY_SLOPE + self.ENERGY_INTERCEPT)
    data['spectrum'] = (("trainId", "energy"), ret)
    return data
# Reduction applied to each known variable when aggregating over a
# dimension: delays are averaged, everything else is summed.
aggregators = {
    'hRIXS_det': lambda x, dim: x.sum(dim=dim),
    'Delay': lambda x, dim: x.mean(dim=dim),
    'hRIXS_delay': lambda x, dim: x.mean(dim=dim),
    'hRIXS_norm': lambda x, dim: x.sum(dim=dim),
    'spectrum': lambda x, dim: x.sum(dim=dim),
    'dbl_spectrum': lambda x, dim: x.sum(dim=dim),
    'total_hits': lambda x, dim: x.sum(dim=dim),
    'dbl_hits': lambda x, dim: x.sum(dim=dim),
    'counts': lambda x, dim: x.sum(dim=dim),
}
def aggregator(self, da, dim):
    """Aggregate one data array with the function registered for its name.

    Returns None when `self.aggregators` has no entry for `da.name`.
    """
    reducer = self.aggregators.get(da.name)
    return reducer(da, dim=dim) if reducer is not None else None
def aggregate(self, ds, var=None, dim="trainId"):
    """Aggregate (i.e. mostly sum) all data within one dataset.

    Take all images in a dataset and aggregate them and their metadata.
    For images, spectra and normalizations that means adding them, for
    others (e.g. delays) adding would not make sense, so we treat them
    properly. The aggregation functions of each variable are defined
    in the aggregators attribute of the class.
    If var is specified, group the dataset by var prior to aggregation.
    A new variable "counts" gives the number of frames aggregated in
    each group. Note that "counts" is written into `ds` itself before
    aggregating.

    Parameters
    ----------
    ds: xarray Dataset
        the dataset containing RIXS data
    var: string
        One of the variables in the dataset. If var is specified, the
        dataset is grouped by var prior to aggregation. This is useful
        for sorting e.g. a dataset that contains multiple delays.
    dim: string
        the dimension over which to aggregate the data

    Example
    -------

    h.centroid(data)  # create spectra from finding photons
    agg = h.aggregate(data)  # sum all spectra
    agg.spectrum.plot()  # plot the resulting spectrum

    agg2 = h.aggregate(data, 'hRIXS_delay')  # group data by delay
    agg2.spectrum[0, :].plot()  # plot the spectrum for first value
    """
    # One count per entry along `dim`; summing it yields the group size.
    ds["counts"] = xr.ones_like(ds[dim])

    if var is None:
        return self.aggregate_ds(ds, dim)

    return ds.groupby(var).map(self.aggregate_ds, dim=dim)
def aggregate_ds(self, ds, dim='trainId'):
    """Aggregate every variable of `ds` over `dim`.

    `self.aggregator` is mapped over the dataset, then all variables that
    have no entry in `self.aggregators` are dropped from the result.
    """
    aggregated = ds.map(self.aggregator, dim=dim)
    unknown = [name for name in aggregated if name not in self.aggregators]
    return aggregated.drop_vars(unknown)
def normalize(self, data, which="hRIXS_norm"):
    """Add a 'normalized' variable to the dataset, defined as the
    ratio between 'spectrum' and `which`.

    The result of `data.assign` is returned.

    Parameters
    ----------
    data: xarray Dataset
        the dataset containing hRIXS data
    which: string, default="hRIXS_norm"
        one of the variables of the dataset, usually "hRIXS_norm"
        or "counts"
    """
    numerator = data["spectrum"]
    denominator = data[which]
    return data.assign(normalized=numerator / denominator)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment