Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
T
ToolBox
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Model registry
Operate
Environments
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
SCS
ToolBox
Commits
60c5e70c
Commit
60c5e70c
authored
1 year ago
by
Loïc Le Guyader
Browse files
Options
Downloads
Patches
Plain Diff
cleanup and minimize code
parent
c9f25f29
No related branches found
Branches containing commit
No related tags found
1 merge request
!280
WIP: First RIXS with JUNGFRAU detector implementation
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
src/toolbox_scs/detectors/jf_hrixs.py
+29
-291
29 additions, 291 deletions
src/toolbox_scs/detectors/jf_hrixs.py
with
29 additions
and
291 deletions
src/toolbox_scs/detectors/jf_hrixs.py
+
29
−
291
View file @
60c5e70c
from
functools
import
lru_cache
import
xarray
as
xr
import
xarray
as
xr
import
numpy
as
np
import
numpy
as
np
import
matplotlib.pyplot
as
plt
import
matplotlib.pyplot
as
plt
from
scipy.optimize
import
leastsq
from
scipy.optimize
import
leastsq
from
scipy.optimize
import
curve_fit
from
scipy.signal
import
fftconvolve
import
toolbox_scs
as
tb
import
toolbox_scs
as
tb
...
@@ -14,111 +11,13 @@ __all__ = [
...
@@ -14,111 +11,13 @@ __all__ = [
'
JF_hRIXS
'
,
'
JF_hRIXS
'
,
]
]
# -----------------------------------------------------------------------------
# Curvature
def
correct_curvature
(
image
,
factor
=
None
,
axis
=
1
):
if
factor
is
None
:
return
if
axis
==
1
:
image
=
image
.
T
ydim
,
xdim
=
image
.
shape
x
=
np
.
arange
(
xdim
+
1
)
y
=
np
.
arange
(
ydim
+
1
)
xx
,
yy
=
np
.
meshgrid
(
x
[:
-
1
]
+
0.5
,
y
[:
-
1
]
+
0.5
)
xxn
=
xx
-
factor
[
0
]
*
yy
-
factor
[
1
]
*
yy
**
2
ret
=
np
.
histogramdd
((
xxn
.
flatten
(),
yy
.
flatten
()),
bins
=
[
x
,
y
],
weights
=
image
.
flatten
())[
0
]
return
ret
if
axis
==
1
else
ret
.
T
def
get_spectrum
(
image
,
factor
=
None
,
axis
=
0
,
pixel_range
=
None
,
energy_range
=
None
,
):
start
,
stop
=
(
0
,
image
.
shape
[
axis
-
1
])
if
pixel_range
is
not
None
:
start
=
max
(
pixel_range
[
0
]
or
start
,
start
)
stop
=
min
(
pixel_range
[
1
]
or
stop
,
stop
)
edge
=
image
.
sum
(
axis
=
axis
)[
start
:
stop
]
bins
=
np
.
arange
(
start
,
stop
+
1
)
centers
=
(
bins
[
1
:]
+
bins
[:
-
1
])
*
0.5
if
factor
is
not
None
:
centers
,
edge
=
calibrate
(
centers
,
edge
,
factor
=
factor
,
range_
=
energy_range
)
return
centers
,
edge
# -----------------------------------------------------------------------------
# Energy calibration
def
energy_calibration
(
channels
,
energies
):
return
np
.
polyfit
(
channels
,
energies
,
deg
=
1
)
def
calibrate
(
x
,
y
=
None
,
factor
=
None
,
range_
=
None
):
if
factor
is
not
None
:
x
=
np
.
polyval
(
factor
,
x
)
if
y
is
not
None
and
range_
is
not
None
:
start
=
np
.
argmin
(
np
.
abs
((
x
-
range_
[
0
])))
stop
=
np
.
argmin
(
np
.
abs
((
x
-
range_
[
1
])))
# Calibrated energies have a different direction
x
,
y
=
x
[
stop
:
start
],
y
[
stop
:
start
]
return
x
,
y
# -----------------------------------------------------------------------------
# Gaussian-related functions
FWHM_COEFF
=
2
*
np
.
sqrt
(
2
*
np
.
log
(
2
))
FWHM_COEFF
=
2
*
np
.
sqrt
(
2
*
np
.
log
(
2
))
def
gaussian_fit
(
x_data
,
y_data
,
offset
=
0
):
"""
Centre-of-mass and width. Lifted from image_processing.imageCentreofMass()
"""
x0
=
np
.
average
(
x_data
,
weights
=
y_data
)
sx
=
np
.
sqrt
(
np
.
average
((
x_data
-
x0
)
**
2
,
weights
=
y_data
))
# Gaussian fit
baseline
=
y_data
.
min
()
p_0
=
(
y_data
.
max
(),
x0
+
offset
,
sx
,
baseline
)
try
:
p_f
,
_
=
curve_fit
(
gauss1d
,
x_data
,
y_data
,
p_0
,
maxfev
=
10000
)
return
p_f
except
(
RuntimeError
,
TypeError
)
as
e
:
print
(
e
)
return
None
def
gauss1d
(
x
,
height
,
x0
,
sigma
,
offset
):
return
height
*
np
.
exp
(
-
0.5
*
((
x
-
x0
)
/
sigma
)
**
2
)
+
offset
def
to_fwhm
(
sigma
):
def
to_fwhm
(
sigma
):
return
abs
(
sigma
*
FWHM_COEFF
)
return
abs
(
sigma
*
FWHM_COEFF
)
def
decentroid
(
res
):
res
=
np
.
array
(
res
)
ret
=
np
.
zeros
(
shape
=
(
res
.
max
(
axis
=
0
)
+
1
).
astype
(
int
))
for
cy
,
cx
in
res
:
if
cx
>
0
and
cy
>
0
:
ret
[
int
(
cy
),
int
(
cx
)]
+=
1
return
ret
class
JF_hRIXS
:
class
JF_hRIXS
:
"""
The JUNGFRAU hRIXS analysis, especially curvature correction
"""
The JUNGFRAU hRIXS analysis, especially curvature correction
...
@@ -144,7 +43,7 @@ class JF_hRIXS:
...
@@ -144,7 +43,7 @@ class JF_hRIXS:
STD_THRESHOLD:
STD_THRESHOLD:
same as THRESHOLD, in standard deviations.
same as THRESHOLD, in standard deviations.
DBL_THRESHOLD:
DBL_THRESHOLD:
threshold controling whether a detected hit is considered to be a
threshold controling whether a detected hit is considered to be a
double hit.
double hit.
BINS: int
BINS: int
the number of bins used in centroiding
the number of bins used in centroiding
...
@@ -161,7 +60,7 @@ class JF_hRIXS:
...
@@ -161,7 +60,7 @@ class JF_hRIXS:
Example
Example
-------
-------
proposal = 3145
proposal = 3145
h = hRIXS(proposal)
h = hRIXS(proposal)
h.Y_RANGE = slice(700, 900)
h.Y_RANGE = slice(700, 900)
...
@@ -172,7 +71,7 @@ class JF_hRIXS:
...
@@ -172,7 +71,7 @@ class JF_hRIXS:
"""
"""
def
__init__
(
self
,
proposalNB
):
def
__init__
(
self
,
proposalNB
):
self
.
PROPOSAL
=
proposalNB
self
.
PROPOSAL
=
proposalNB
# image range
# image range
self
.
X_RANGE
=
np
.
s_
[:]
self
.
X_RANGE
=
np
.
s_
[:]
...
@@ -187,7 +86,8 @@ class JF_hRIXS:
...
@@ -187,7 +86,8 @@ class JF_hRIXS:
self
.
ENERGY_INTERCEPT
=
0
self
.
ENERGY_INTERCEPT
=
0
self
.
ENERGY_SLOPE
=
1
self
.
ENERGY_SLOPE
=
1
self
.
FIELDS
=
[
'
hRIXS_det
'
,
'
hRIXS_index
'
,
'
hRIXS_delay
'
,
'
hRIXS_norm
'
,
'
nrj
'
]
self
.
FIELDS
=
[
'
hRIXS_det
'
,
'
hRIXS_index
'
,
'
hRIXS_delay
'
,
'
hRIXS_norm
'
,
'
nrj
'
]
def
set_params
(
self
,
**
params
):
def
set_params
(
self
,
**
params
):
for
key
,
value
in
params
.
items
():
for
key
,
value
in
params
.
items
():
...
@@ -200,14 +100,16 @@ class JF_hRIXS:
...
@@ -200,14 +100,16 @@ class JF_hRIXS:
'
bins
'
,
'
fields
'
)
'
bins
'
,
'
fields
'
)
return
{
param
:
getattr
(
self
,
param
.
upper
())
for
param
in
params
}
return
{
param
:
getattr
(
self
,
param
.
upper
())
for
param
in
params
}
def
from_run
(
self
,
runNB
,
proposal
=
None
,
extra_fields
=
(),
drop_first
=
False
):
def
from_run
(
self
,
runNB
,
proposal
=
None
,
extra_fields
=
(),
"""
load a run
drop_first
=
False
):
"""
Load a run.
Load the run `runNB`. A thin wrapper around `toolbox.load`.
Load the run `runNB`. A thin wrapper around `toolbox.load`.
Parameters
Parameters
----------
----------
drop_first: bool
drop_first: bool
if True, the first image in the run is removed from the dataset.
if True, the first image in the run is removed from the
dataset.
Example
Example
-------
-------
...
@@ -220,55 +122,14 @@ class JF_hRIXS:
...
@@ -220,55 +122,14 @@ class JF_hRIXS:
"""
"""
if
proposal
is
None
:
if
proposal
is
None
:
proposal
=
self
.
PROPOSAL
proposal
=
self
.
PROPOSAL
run
,
data
=
tb
.
load
(
proposal
,
runNB
=
runNB
,
_
,
data
=
tb
.
load
(
proposal
,
runNB
=
runNB
,
fields
=
self
.
FIELDS
+
list
(
extra_fields
))
fields
=
self
.
FIELDS
+
list
(
extra_fields
))
if
drop_first
is
True
:
if
drop_first
is
True
:
data
=
data
.
isel
(
trainId
=
slice
(
1
,
None
))
data
=
data
.
isel
(
trainId
=
slice
(
1
,
None
))
return
data
return
data
def
find_curvature
(
self
,
runNB
,
proposal
=
None
,
plot
=
True
,
args
=
None
,
def
find_curvature
(
self
,
img
,
args
,
plot
=
False
,
**
kwargs
):
**
kwargs
):
"""
Find the curvature correction coefficients.
"""
find the curvature correction coefficients
The hRIXS has some abberations which leads to the spectroscopic lines
being curved on the detector. We approximate these abberations with
a parabola for later correction.
Load a run and determine the curvature. The curvature is set in `self`,
and returned as a pair of floats.
Parameters
----------
runNB: int
the run number to use
proposal: int
the proposal to use, default to the current proposal
plot: bool
whether to plot the found curvature onto the data
args: pair of float, optional
a starting value to prime the fitting routine
Example
-------
h.find_curvature(155) # use run 155 to fit the curvature
"""
data
=
self
.
from_run
(
runNB
,
proposal
)
image
=
data
[
'
hRIXS_det
'
].
sum
(
dim
=
'
trainId
'
)
\
.
values
[
self
.
X_RANGE
,
self
.
Y_RANGE
].
T
if
args
is
None
:
spec
=
(
image
-
image
[:
10
,
:].
mean
()).
mean
(
axis
=
1
)
mean
=
np
.
average
(
np
.
arange
(
len
(
spec
)),
weights
=
spec
)
args
=
(
-
2e-7
,
0.02
,
mean
-
0.02
*
image
.
shape
[
1
]
/
2
,
3
,
spec
.
max
(),
image
.
mean
())
args
=
_find_curvature
(
image
,
args
,
plot
=
plot
,
**
kwargs
)
self
.
CURVE_B
,
self
.
CURVE_A
,
*
_
=
args
return
self
.
CURVE_A
,
self
.
CURVE_B
def
find_curvature
(
img
,
args
,
plot
=
False
,
**
kwargs
):
"""
find the curvature correction coefficients
The hRIXS has some abberations which leads to the spectroscopic lines
The hRIXS has some abberations which leads to the spectroscopic lines
being curved on the detector. We approximate these abberations with
being curved on the detector. We approximate these abberations with
...
@@ -279,7 +140,6 @@ class JF_hRIXS:
...
@@ -279,7 +140,6 @@ class JF_hRIXS:
Parameters
Parameters
----------
----------
img: array
img: array
2D average image
2D average image
args: (a, b, c, s, h, o) initial coefficients
args: (a, b, c, s, h, o) initial coefficients
...
@@ -287,51 +147,57 @@ class JF_hRIXS:
...
@@ -287,51 +147,57 @@ class JF_hRIXS:
h the height and o an offset
h the height and o an offset
plot: bool
plot: bool
whether to plot the found curvature onto the data
whether to plot the found curvature onto the data
Example
Example
-------
-------
h.find_curvature(155) # use run 155 to fit the curvature
h.find_curvature(155) # use run 155 to fit the curvature
"""
"""
def
parabola
(
x
,
a
,
b
,
c
,
s
=
0
,
h
=
0
,
o
=
0
):
def
parabola
(
x
,
a
,
b
,
c
,
s
=
0
,
h
=
0
,
o
=
0
):
return
(
a
*
x
+
b
)
*
x
+
c
return
(
a
*
x
+
b
)
*
x
+
c
def
gauss
(
y
,
x
,
a
,
b
,
c
,
s
,
h
,
o
=
0
):
def
gauss
(
y
,
x
,
a
,
b
,
c
,
s
,
h
,
o
=
0
):
return
h
*
np
.
exp
(
-
((
y
-
parabola
(
x
,
a
,
b
,
c
))
/
(
2
*
s
))
**
2
)
+
o
return
h
*
np
.
exp
(
-
((
y
-
parabola
(
x
,
a
,
b
,
c
))
/
(
2
*
s
))
**
2
)
+
o
x
=
np
.
arange
(
img
.
shape
[
1
])[
None
,
:]
x
=
np
.
arange
(
img
.
shape
[
1
])[
None
,
:]
y
=
np
.
arange
(
img
.
shape
[
0
])[:,
None
]
y
=
np
.
arange
(
img
.
shape
[
0
])[:,
None
]
if
plot
:
if
plot
:
plt
.
figure
(
figsize
=
(
10
,
10
))
plt
.
figure
(
figsize
=
(
10
,
10
))
plt
.
imshow
(
img
,
cmap
=
'
gray
'
,
aspect
=
'
auto
'
,
interpolation
=
'
nearest
'
,
**
kwargs
)
plt
.
imshow
(
img
,
cmap
=
'
gray
'
,
aspect
=
'
auto
'
,
interpolation
=
'
nearest
'
,
**
kwargs
)
plt
.
plot
(
x
[
0
,
:],
parabola
(
x
[
0
,
:],
*
args
))
plt
.
plot
(
x
[
0
,
:],
parabola
(
x
[
0
,
:],
*
args
))
args
,
_
=
leastsq
(
lambda
args
:
(
gauss
(
y
,
x
,
*
args
)
-
img
).
ravel
(),
args
)
args
,
_
=
leastsq
(
lambda
args
:
(
gauss
(
y
,
x
,
*
args
)
-
img
).
ravel
(),
args
)
if
plot
:
if
plot
:
plt
.
plot
(
x
[
0
,
:],
parabola
(
x
[
0
,
:],
*
args
))
plt
.
plot
(
x
[
0
,
:],
parabola
(
x
[
0
,
:],
*
args
))
return
args
return
args
def
parabola
(
self
,
x
):
return
(
self
.
CURVE_B
*
x
+
self
.
CURVE_A
)
*
x
def
spectrum
(
self
,
fname
):
def
spectrum
(
self
,
fname
):
"""
Bin photon hit data into spectrum.
"""
Bin photon hit data into spectrum.
Parameters
Parameters
----------
----------
fname: string
fname: string
file name of the data to load.
file name of the data to load.
"""
"""
data_interp
=
xr
.
load_dataset
(
fname
)
data_interp
=
xr
.
load_dataset
(
fname
)
def
hist_curv
(
x
,
y
):
def
hist_curv
(
x
,
y
):
H
,
_
=
np
.
histogram
(
H
,
_
=
np
.
histogram
(
x
-
self
.
parabola
(
y
),
bins
=
self
.
BINS
,
x
-
self
.
parabola
(
y
),
bins
=
self
.
BINS
,
range
=
(
0
,
self
.
Y_RANGE
.
stop
-
self
.
Y_RANGE
.
start
))
range
=
(
0
,
self
.
Y_RANGE
.
stop
-
self
.
Y_RANGE
.
start
))
return
H
return
H
energy
=
(
np
.
linspace
(
self
.
Y_RANGE
.
start
,
energy
=
(
np
.
linspace
(
self
.
Y_RANGE
.
start
,
self
.
Y_RANGE
.
stop
,
self
.
Y_RANGE
.
stop
,
self
.
BINS
)
*
self
.
ENERGY_SLOPE
+
self
.
ENERGY_INTERCEPT
)
self
.
BINS
)
*
self
.
ENERGY_SLOPE
+
self
.
ENERGY_INTERCEPT
)
spectrum
=
xr
.
apply_ufunc
(
hist_curv
,
spectrum
=
xr
.
apply_ufunc
(
hist_curv
,
data_interp
[
'
y
'
],
data_interp
[
'
y
'
],
...
@@ -347,131 +213,3 @@ class JF_hRIXS:
...
@@ -347,131 +213,3 @@ class JF_hRIXS:
spectrum
[
'
energy
'
]
=
energy
spectrum
[
'
energy
'
]
=
energy
return
spectrum
return
spectrum
def
parabola
(
self
,
x
):
return
(
self
.
CURVE_B
*
x
+
self
.
CURVE_A
)
*
x
def
integrate
(
self
,
data
):
"""
calculate a spectrum by integration
This takes the `xarray` `data` and returns a copy of it, with a new
dataarray named `spectrum` added, which contains the energy spectrum
calculated for each hRIXS image.
First the energy that corresponds to each pixel is calculated.
Then all pixels within an energy range are summed, where the intensity
of one pixel is distributed among the two energy ranges the pixel
spans, proportionally to the overlap between the pixel and bin energy
ranges.
The resulting data is normalized to one pixel, so the average
intensity that arrived on one pixel.
Example
-------
h.integrate(data) # create spectrum by summing pixels
data.spectrum[0, :].plot() # plot the spectrum of the first image
"""
bins
=
self
.
Y_RANGE
.
stop
-
self
.
Y_RANGE
.
start
margin
=
10
ret
=
np
.
zeros
((
len
(
data
[
"
hRIXS_det
"
]),
bins
-
2
*
margin
))
if
self
.
USE_DARK
:
dark_image
=
self
.
dark_image
.
values
[
self
.
X_RANGE
,
self
.
Y_RANGE
]
images
=
data
[
"
hRIXS_det
"
].
values
[:,
self
.
X_RANGE
,
self
.
Y_RANGE
]
x
,
y
=
np
.
ogrid
[:
images
.
shape
[
1
],
:
images
.
shape
[
2
]]
quo
,
rem
=
divmod
(
y
-
self
.
parabola
(
x
),
1
)
quo
=
np
.
array
([
quo
,
quo
+
1
])
rem
=
np
.
array
([
rem
,
1
-
rem
])
wrong
=
(
quo
<
margin
)
|
(
quo
>=
bins
-
margin
)
quo
[
wrong
]
=
margin
rem
[
wrong
]
=
0
quo
=
(
quo
-
margin
).
astype
(
int
).
ravel
()
for
image
,
r
in
zip
(
images
,
ret
):
if
self
.
USE_DARK
:
image
=
image
-
dark_image
r
[:]
=
np
.
bincount
(
quo
,
weights
=
(
rem
*
image
).
ravel
())
ret
/=
np
.
bincount
(
quo
,
weights
=
rem
.
ravel
())
data
.
coords
[
"
energy
"
]
=
(
np
.
arange
(
self
.
Y_RANGE
.
start
+
margin
,
self
.
Y_RANGE
.
stop
-
margin
)
*
self
.
ENERGY_SLOPE
+
self
.
ENERGY_INTERCEPT
)
data
[
'
spectrum
'
]
=
((
"
trainId
"
,
"
energy
"
),
ret
)
return
data
aggregators
=
dict
(
hRIXS_det
=
lambda
x
,
dim
:
x
.
sum
(
dim
=
dim
),
Delay
=
lambda
x
,
dim
:
x
.
mean
(
dim
=
dim
),
hRIXS_delay
=
lambda
x
,
dim
:
x
.
mean
(
dim
=
dim
),
hRIXS_norm
=
lambda
x
,
dim
:
x
.
sum
(
dim
=
dim
),
spectrum
=
lambda
x
,
dim
:
x
.
sum
(
dim
=
dim
),
dbl_spectrum
=
lambda
x
,
dim
:
x
.
sum
(
dim
=
dim
),
total_hits
=
lambda
x
,
dim
:
x
.
sum
(
dim
=
dim
),
dbl_hits
=
lambda
x
,
dim
:
x
.
sum
(
dim
=
dim
),
counts
=
lambda
x
,
dim
:
x
.
sum
(
dim
=
dim
)
)
def
aggregator
(
self
,
da
,
dim
):
agg
=
self
.
aggregators
.
get
(
da
.
name
)
if
agg
is
None
:
return
None
return
agg
(
da
,
dim
=
dim
)
def
aggregate
(
self
,
ds
,
var
=
None
,
dim
=
"
trainId
"
):
"""
aggregate (i.e. mostly sum) all data within one dataset
take all images in a dataset and aggregate them and their metadata.
For images, spectra and normalizations that means adding them, for
others (e.g. delays) adding would not make sense, so we treat them
properly. The aggregation functions of each variable are defined
in the aggregators attribute of the class.
If var is specified, group the dataset by var prior to aggregation.
A new variable
"
counts
"
gives the number of frames aggregated in
each group.
Parameters
----------
ds: xarray Dataset
the dataset containing RIXS data
var: string
One of the variables in the dataset. If var is specified, the
dataset is grouped by var prior to aggregation. This is useful
for sorting e.g. a dataset that contains multiple delays.
dim: string
the dimension over which to aggregate the data
Example
-------
h.centroid(data) # create spectra from finding photons
agg = h.aggregate(data) # sum all spectra
agg.spectrum.plot() # plot the resulting spectrum
agg2 = h.aggregate(data,
'
hRIXS_delay
'
) # group data by delay
agg2.spectrum[0, :].plot() # plot the spectrum for first value
"""
ds
[
"
counts
"
]
=
xr
.
ones_like
(
ds
[
dim
])
if
var
is
not
None
:
groups
=
ds
.
groupby
(
var
)
return
groups
.
map
(
self
.
aggregate_ds
,
dim
=
dim
)
return
self
.
aggregate_ds
(
ds
,
dim
)
def
aggregate_ds
(
self
,
ds
,
dim
=
'
trainId
'
):
ret
=
ds
.
map
(
self
.
aggregator
,
dim
=
dim
)
ret
=
ret
.
drop_vars
([
n
for
n
in
ret
if
n
not
in
self
.
aggregators
])
return
ret
def
normalize
(
self
,
data
,
which
=
"
hRIXS_norm
"
):
"""
Adds a
'
normalized
'
variable to the dataset defined as the
ration between
'
spectrum
'
and
'
which
'
Parameters
----------
data: xarray Dataset
the dataset containing hRIXS data
which: string, default=
"
hRIXS_norm
"
one of the variables of the dataset, usually
"
hRIXS_norm
"
or
"
counts
"
"""
return
data
.
assign
(
normalized
=
data
[
"
spectrum
"
]
/
data
[
which
])
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment