Skip to content
Snippets Groups Projects
Commit 8dc714e1 authored by Martin Teichmann's avatar Martin Teichmann
Browse files

improve the centroiding code

This inlines the centroiding code into the HRIXS class, and also improves the readability.

See merge request !233
parents 7a4075f0 bceda8ec
No related branches found
No related tags found
1 merge request!233improve the centroiding code
...@@ -171,64 +171,6 @@ THRESHOLD = 510 # pixel counts above which a hit candidate is assumed ...@@ -171,64 +171,6 @@ THRESHOLD = 510 # pixel counts above which a hit candidate is assumed
CURVE_A = 2.19042931e-02 # curvature parameters as determined elsewhere CURVE_A = 2.19042931e-02 # curvature parameters as determined elsewhere
CURVE_B = -3.02191568e-07 CURVE_B = -3.02191568e-07
def _esrf_centroid(image, threshold=THRESHOLD, curvature=(CURVE_A, CURVE_B)):
gs = 2
base = image.mean()
cp = np.argwhere(image[gs // 2: -gs // 2, gs // 2: -gs // 2] > threshold) + np.array([gs // 2, gs // 2])
if len(cp) > 100000:
raise RuntimeError('Threshold too low or acquisition time too long')
res = []
for cy, cx in cp:
spot = image[cy - gs // 2: cy + gs // 2 + 1, cx - gs // 2: cx + gs // 2 + 1] - base
spot[spot < 0] = 0
if (spot > image[cy, cx]).sum() == 0:
mx = np.average(np.arange(cx - gs // 2, cx + gs // 2 + 1), weights=spot.sum(axis=0))
my = np.average(np.arange(cy - gs // 2, cy + gs // 2 + 1), weights=spot.sum(axis=1))
my -= (curvature[0] + curvature[1] * mx) * mx
res.append((my, mx))
return res
def _new_centroid(image, threshold=THRESHOLD, std_threshold=3.5, curvature=(CURVE_A, CURVE_B)):
"""find the position of photons with sub-pixel precision
A photon is supposed to have hit the detector if the intensity within a
2-by-2 square exceeds a threshold. In this case the position of the photon
is calculated as the center-of-mass in a 4-by-4 square.
Return the list of x,y coordinate pairs, corrected by the curvature.
"""
base = image.mean()
corners = image[1:, 1:] + image[:-1, 1:] + image[1:, :-1] + image[:-1, :-1]
if threshold is None:
threshold = corners.mean() + std_threshold * corners.std()
middle = corners[1:-1, 1:-1]
candidates = (
(middle > threshold)
* (middle >= corners[:-2, 1:-1]) * (middle > corners[2:, 1:-1])
* (middle >= corners[1:-1, :-2]) * (middle > corners[1:-1, 2:])
* (middle >= corners[:-2, :-2]) * (middle > corners[2:, :-2])
* (middle >= corners[:-2, 2:]) * (middle > corners[2:, 2:]))
cp = np.argwhere(candidates)
if len(cp) > 10000:
raise RuntimeError(
"too many peaks, threshold too low or acquisition time too high")
res = []
for cy, cx in cp:
spot = image[cy: cy + 4, cx: cx + 4] - base
mx = np.average(np.arange(cx, cx + 4), weights=spot.sum(axis=0))
my = np.average(np.arange(cy, cy + 4), weights=spot.sum(axis=1))
my -= (curvature[0] + curvature[1] * mx) * mx
res.append((my, mx))
return res
centroid = _new_centroid
def decentroid(res): def decentroid(res):
res = np.array(res) res = np.array(res)
ret = np.zeros(shape=(res.max(axis=0) + 1).astype(int)) ret = np.zeros(shape=(res.max(axis=0) + 1).astype(int))
...@@ -410,6 +352,42 @@ class hRIXS: ...@@ -410,6 +352,42 @@ class hRIXS:
self.CURVE_B, self.CURVE_A, *_ = args self.CURVE_B, self.CURVE_A, *_ = args
return self.CURVE_A, self.CURVE_B return self.CURVE_A, self.CURVE_B
def centroid_one(self, image):
"""find the position of photons with sub-pixel precision
A photon is supposed to have hit the detector if the intensity within a
2-by-2 square exceeds a threshold. In this case the position of the photon
is calculated as the center-of-mass in a 4-by-4 square.
Return the list of x, y coordinate pairs, corrected by the curvature.
"""
base = image.mean()
corners = image[1:, 1:] + image[:-1, 1:] \
+ image[1:, :-1] + image[:-1, :-1]
if self.THRESHOLD is None:
threshold = corners.mean() + self.STD_THRESHOLD * corners.std()
else:
threshold = self.THRESHOLD
middle = corners[1:-1, 1:-1]
candidates = (
(middle > threshold)
* (middle >= corners[:-2, 1:-1]) * (middle > corners[2:, 1:-1])
* (middle >= corners[1:-1, :-2]) * (middle > corners[1:-1, 2:])
* (middle >= corners[:-2, :-2]) * (middle > corners[2:, :-2])
* (middle >= corners[:-2, 2:]) * (middle > corners[2:, 2:]))
cp = np.argwhere(candidates)
if len(cp) > 10000:
raise RuntimeError(
"too many peaks, threshold low or acquisition time too high")
res = []
for cy, cx in cp:
spot = image[cy: cy + 4, cx: cx + 4] - base
mx = np.average(np.arange(cx, cx + 4), weights=spot.sum(axis=0))
my = np.average(np.arange(cy, cy + 4), weights=spot.sum(axis=1))
res.append((mx, my))
return res
def centroid(self, data, bins=None): def centroid(self, data, bins=None):
"""calculate a spectrum by finding the centroid of individual photons """calculate a spectrum by finding the centroid of individual photons
...@@ -420,30 +398,27 @@ class hRIXS: ...@@ -420,30 +398,27 @@ class hRIXS:
Example Example
------- -------
data = h.centroid(data) # find photons in all images of the run h.centroid(data) # find photons in all images of the run
data.spectrum[0, :].plot() # plot the spectrum of the first image data.spectrum[0, :].plot() # plot the spectrum of the first image
""" """
if bins is None: if bins is None:
bins = self.BINS bins = self.BINS
ret = np.zeros((len(data["hRIXS_det"]), bins)) ret = np.zeros((len(data["hRIXS_det"]), bins))
for image, r in zip(data["hRIXS_det"], ret): for image, r in zip(data["hRIXS_det"], ret):
c = centroid( c = self.centroid_one(
image.values[self.X_RANGE, self.Y_RANGE].T, image.values[self.X_RANGE, self.Y_RANGE])
threshold=self.THRESHOLD,
std_threshold=self.STD_THRESHOLD,
curvature=(self.CURVE_A, self.CURVE_B))
if not len(c): if not len(c):
continue continue
rc = np.array(c) rc = np.array(c)
hy, hx = np.histogram( r[:], _ = np.histogram(
rc[:, 0], bins=bins, rc[:, 0] - self.parabola(rc[:, 1]),
range=(0, self.Y_RANGE.stop - self.Y_RANGE.start)) bins=bins, range=(0, self.Y_RANGE.stop - self.Y_RANGE.start))
r[:] = hy
data = data.assign_coords( data.coords["energy"] = (
energy=np.linspace(self.Y_RANGE.start, self.Y_RANGE.stop, bins) np.linspace(self.Y_RANGE.start, self.Y_RANGE.stop, bins)
* self.ENERGY_SLOPE + self.ENERGY_INTERCEPT) * self.ENERGY_SLOPE + self.ENERGY_INTERCEPT)
return data.assign(spectrum=(("trainId", "energy"), ret)) data['spectrum'] = (("trainId", "energy"), ret)
return data
def parabola(self, x): def parabola(self, x):
return (self.CURVE_B * x + self.CURVE_A) * x return (self.CURVE_B * x + self.CURVE_A) * x
...@@ -522,7 +497,7 @@ class hRIXS: ...@@ -522,7 +497,7 @@ class hRIXS:
Example Example
------- -------
data = h.centroid(data) # create spectra from finding photons h.centroid(data) # create spectra from finding photons
agg = h.aggregate(data) # sum all spectra agg = h.aggregate(data) # sum all spectra
agg.spectrum.plot() # plot the resulting spectrum agg.spectrum.plot() # plot the resulting spectrum
......
...@@ -22,5 +22,44 @@ class TestHRIXS(unittest.TestCase): ...@@ -22,5 +22,44 @@ class TestHRIXS(unittest.TestCase):
28517.704705882363) 28517.704705882363)
self.assertEqual(data['spectrum'][1, 50].coords['energy'], 90) self.assertEqual(data['spectrum'][1, 50].coords['energy'], 90)
def test_centroid(self):
data = xa.Dataset()
img = np.array([
[[0, 0, 0, 0, 0, 0, 0,],
[0, 0, 0, 0, 0, 0, 0,],
[0, 0, 1, 1, 0, 0, 0,],
[0, 0, 1, 1, 0, 0, 0,],
[0, 0, 0, 0, 0, 0, 0,],
[0, 0, 0, 0, 0, 0, 0,],
[0, 0, 0, 0, 0, 0, 0,],],
[[0, 0, 0, 0, 0, 0, 0,],
[0, 0, 0, 0, 0, 0, 0,],
[0, 0, 1, 1, 2, 0, 0,],
[0, 0, 1, 7, 2, 0, 0,],
[0, 0, 1, 1, 2, 0, 0,],
[0, 0, 0, 0, 0, 0, 0,],
[0, 0, 0, 0, 0, 0, 0,],],
])
data['hRIXS_det'] = (('trainId', 'x', 'y'), img)
h = hRIXS()
h.Y_RANGE = slice(0, 7)
h.THRESHOLD = 0.5
h.BINS = 10
data = h.centroid(data)
assert_array_equal(data['spectrum'], [
[0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
])
h.CURVE_A = 0.1
h.CURVE_B = 0.01
r = h.centroid(data)
self.assertIs(r, data)
assert_array_equal(data['spectrum'], [
[0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
])
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment