Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
P
pycalibration
Manage
Activity
Members
Labels
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Deploy
Model registry
Analyze
Contributor analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
calibration
pycalibration
Commits
54f8fc83
Commit
54f8fc83
authored
1 year ago
by
Thomas Kluyver
Browse files
Options
Downloads
Patches
Plain Diff
Fix select_modules()
parent
9847e518
Loading
Loading
1 merge request
!885
Revised CalCat API
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
src/cal_tools/calcat_interface2.py
+48
-40
48 additions, 40 deletions
src/cal_tools/calcat_interface2.py
tests/test_calcat_interface2.py
+2
-2
2 additions, 2 deletions
tests/test_calcat_interface2.py
with
50 additions
and
42 deletions
src/cal_tools/calcat_interface2.py
+
48
−
40
View file @
54f8fc83
...
...
@@ -226,10 +226,11 @@ class ModulesConstantVersions:
module_details
:
List
[
Dict
]
def
select_modules
(
self
,
module_nums
=
None
,
*
,
aggregators
=
None
,
qm_names
=
None
self
,
module_nums
=
None
,
*
,
aggregator
_name
s
=
None
,
qm_names
=
None
)
->
"
ModulesConstantVersions
"
:
aggs
=
aggregator_names
# Shorter name -> fewer multi-line statements
n_specified
=
sum
(
[
module_nums
is
not
None
,
agg
regator
s
is
not
None
,
qm_names
is
not
None
]
[
module_nums
is
not
None
,
aggs
is
not
None
,
qm_names
is
not
None
]
)
if
n_specified
<
1
:
raise
TypeError
(
"
select_modules() requires an argument
"
)
...
...
@@ -240,22 +241,23 @@ class ModulesConstantVersions:
if
module_nums
is
not
None
:
by_mod_no
=
{
m
[
"
module_number
"
]:
m
for
m
in
self
.
module_details
}
agg
regator
s
=
[
by_mod_no
[
n
][
"
karabo_da
"
]
for
n
in
module_nums
]
aggs
=
[
by_mod_no
[
n
][
"
karabo_da
"
]
for
n
in
module_nums
]
elif
qm_names
is
not
None
:
by_qm
=
{
m
[
"
virtual_device_name
"
]:
m
for
m
in
self
.
module_details
}
agg
regator
s
=
[
by_qm
[
s
][
"
karabo_da
"
]
for
s
in
qm_names
]
elif
agg
regator
s
is
not
None
:
miss
=
set
(
agg
regator
s
)
-
{
m
[
"
karabo_da
"
]
for
m
in
self
.
module_details
}
aggs
=
[
by_qm
[
s
][
"
karabo_da
"
]
for
s
in
qm_names
]
elif
aggs
is
not
None
:
miss
=
set
(
aggs
)
-
{
m
[
"
karabo_da
"
]
for
m
in
self
.
module_details
}
if
miss
:
raise
KeyError
(
"
Aggregators not found:
"
+
"
,
"
.
join
(
sorted
(
miss
)))
d
=
{
aggr
:
scv
for
(
aggr
,
scv
)
in
self
.
constants
.
items
()
if
aggr
in
aggregators
}
return
ModulesConstantVersions
(
d
,
self
.
module_details
)
d
=
{
aggr
:
scv
for
(
aggr
,
scv
)
in
self
.
constants
.
items
()
if
aggr
in
aggs
}
mods
=
[
m
for
m
in
self
.
module_details
if
m
[
"
karabo_da
"
]
in
d
]
return
ModulesConstantVersions
(
d
,
mods
)
# These properties label only the modules we have constants for, which may
# be a subset of what's in module_details
@property
def
aggregators
(
self
):
def
aggregator
_name
s
(
self
):
return
sorted
(
self
.
constants
)
@property
...
...
@@ -275,10 +277,10 @@ class ModulesConstantVersions:
]
def
ndarray
(
self
,
caldb_root
=
None
):
eg_dset
=
self
.
constants
[
self
.
aggregators
[
0
]].
dataset_obj
(
caldb_root
)
eg_dset
=
self
.
constants
[
self
.
aggregator
_name
s
[
0
]].
dataset_obj
(
caldb_root
)
shape
=
(
len
(
self
.
constants
),)
+
eg_dset
.
shape
arr
=
np
.
zeros
(
shape
,
eg_dset
.
dtype
)
for
i
,
agg
in
enumerate
(
self
.
aggregators
):
for
i
,
agg
in
enumerate
(
self
.
aggregator
_name
s
):
dset
=
self
.
constants
[
agg
].
dataset_obj
(
caldb_root
)
dset
.
read_direct
(
arr
[
i
])
return
arr
...
...
@@ -287,7 +289,7 @@ class ModulesConstantVersions:
import
xarray
if
module_naming
==
"
da
"
:
modules
=
self
.
aggregators
modules
=
self
.
aggregator
_name
s
elif
module_naming
==
"
modno
"
:
modules
=
self
.
module_nums
elif
module_naming
==
"
qm
"
:
...
...
@@ -300,7 +302,7 @@ class ModulesConstantVersions:
# Dimension labels
dims
=
[
"
module
"
]
+
[
"
dim_%d
"
%
i
for
i
in
range
(
ndarr
.
ndim
-
1
)]
coords
=
{
"
module
"
:
modules
}
name
=
self
.
constants
[
self
.
aggregators
[
0
]].
constant_name
name
=
self
.
constants
[
self
.
aggregator
_name
s
[
0
]].
constant_name
return
xarray
.
DataArray
(
ndarr
,
dims
=
dims
,
coords
=
coords
,
name
=
name
)
...
...
@@ -340,10 +342,7 @@ class CalibrationData(Mapping):
"""
Collected constants for a given detector
"""
def
__init__
(
self
,
constant_groups
,
module_details
):
self
.
constant_groups
=
{
const_type
:
ModulesConstantVersions
(
d
,
module_details
)
for
const_type
,
d
in
constant_groups
.
items
()
}
self
.
constant_groups
=
constant_groups
self
.
module_details
=
module_details
@staticmethod
...
...
@@ -398,7 +397,7 @@ class CalibrationData(Mapping):
if
mod
.
get
(
"
module_number
"
,
-
1
)
<
0
:
mod
[
"
module_number
"
]
=
int
(
re
.
findall
(
r
"
\d+
"
,
mod
[
"
karabo_da
"
])[
-
1
])
d
=
{}
constant_groups
=
{}
for
params
,
cal_types
in
cal_types_by_params_used
.
items
():
condition_dict
=
condition
.
make_dict
(
params
)
...
...
@@ -424,11 +423,14 @@ class CalibrationData(Mapping):
aggr
=
ccv
[
"
physical_detector_unit
"
][
"
karabo_da
"
]
cal_type
=
cal_id_map
[
ccv
[
"
calibration_constant
"
][
"
calibration_id
"
]]
d
.
setdefault
(
cal_type
,
{})[
aggr
]
=
SingleConstantVersion
.
from_response
(
ccv
)
const_group
=
constant_groups
.
setdefault
(
cal_type
,
{})
const_group
[
aggr
]
=
SingleConstantVersion
.
from_response
(
ccv
)
return
cls
(
d
,
module_details
)
mcvs
=
{
const_type
:
ModulesConstantVersions
(
d
,
module_details
)
for
const_type
,
d
in
constant_groups
.
items
()
}
return
cls
(
mcvs
,
module_details
)
@classmethod
def
from_report
(
...
...
@@ -447,16 +449,22 @@ class CalibrationData(Mapping):
res
=
client
.
get
(
"
calibration_constant_versions
"
,
params
)
d
=
{}
constant_groups
=
{}
pdus
=
[]
for
ccv
in
res
:
pdus
.
append
(
ccv
[
"
physical_detector_unit
"
])
cal_type
=
calibration_name
(
ccv
[
"
calibration_constant
"
][
"
calibration_id
"
])
aggr
=
ccv
[
"
physical_detector_unit
"
][
"
karabo_da
"
]
d
.
setdefault
(
cal_type
,
{})[
aggr
]
=
SingleConstantVersion
.
from_response
(
ccv
)
const_group
=
constant_groups
.
setdefault
(
cal_type
,
{})
const_group
[
aggr
]
=
SingleConstantVersion
.
from_response
(
ccv
)
return
cls
(
d
,
sorted
(
pdus
,
key
=
lambda
d
:
d
[
"
karabo_da
"
]))
module_details
=
sorted
(
pdus
,
key
=
lambda
d
:
d
[
"
karabo_da
"
])
mcvs
=
{
const_type
:
ModulesConstantVersions
(
d
,
module_details
)
for
const_type
,
d
in
constant_groups
.
items
()
}
return
cls
(
mcvs
,
module_details
)
def
__getitem__
(
self
,
key
)
->
ModulesConstantVersions
:
return
self
.
constant_groups
[
key
]
...
...
@@ -494,7 +502,7 @@ class CalibrationData(Mapping):
def
require_calibrations
(
self
,
calibrations
):
"""
Drop any modules missing the specified constant types
"""
mods
=
set
(
self
.
aggregators
)
mods
=
set
(
self
.
aggregator
_name
s
)
for
cal_type
in
calibrations
:
mods
.
intersection_update
(
self
[
cal_type
].
constants
)
return
self
.
select_modules
(
mods
)
...
...
@@ -514,19 +522,19 @@ class CalibrationData(Mapping):
return
type
(
self
)(
d
,
self
.
aggregators
)
def
select_modules
(
self
,
module_nums
=
None
,
*
,
aggregators
=
None
,
qm_names
=
None
self
,
module_nums
=
None
,
*
,
aggregator
_name
s
=
None
,
qm_names
=
None
)
->
"
CalibrationData
"
:
return
type
(
self
)(
{
cal_type
:
mcv
.
select_modules
(
module_nums
=
module_num
s
,
aggregators
=
aggregator
s
,
qm_names
=
qm_names
,
).
constants
for
(
cal_type
,
mcv
)
in
self
.
constant_groups
.
items
()
},
sorted
(
aggregators
),
)
mcvs
=
{
cal_type
:
mcv
.
select_modules
(
module_nums
=
module_nums
,
aggregator_names
=
aggregator_name
s
,
qm_names
=
qm_name
s
,
)
for
(
cal_type
,
mcv
)
in
self
.
constant_groups
.
items
()
}
aggs
=
set
().
union
(
*
[
c
.
aggregator_names
for
c
in
mcvs
.
values
()])
module_details
=
[
m
for
m
in
self
.
module_details
if
m
[
"
karabo_da
"
]
in
aggs
]
return
type
(
self
)(
mcvs
,
module_details
)
def
merge
(
self
,
*
others
:
"
CalibrationData
"
)
->
"
CalibrationData
"
:
d
=
{}
...
...
@@ -536,9 +544,9 @@ class CalibrationData(Mapping):
if
cal_type
in
other
:
d
[
cal_type
].
update
(
other
[
cal_type
].
constants
)
aggregators
=
set
(
self
.
aggregators
)
aggregators
=
set
(
self
.
aggregator
_name
s
)
for
other
in
others
:
aggregators
.
update
(
other
.
aggregators
)
aggregators
.
update
(
other
.
aggregator
_name
s
)
return
type
(
self
)(
d
,
sorted
(
aggregators
))
...
...
This diff is collapsed.
Click to expand it.
tests/test_calcat_interface2.py
+
2
−
2
View file @
54f8fc83
...
...
@@ -80,12 +80,12 @@ def test_DSSC_modules_missing():
aggs_q3
=
[
f
"
DSSC
{
m
:
02
}
"
for
m
in
modnos_q3
]
qm_q3
=
[
f
"
Q3M
{
i
}
"
for
i
in
range
(
1
,
5
)]
assert
offset
.
select_modules
(
modnos_q3
).
module_nums
==
modnos_q3
assert
offset
.
select_modules
(
aggregators
=
aggs_q3
).
module_nums
==
modnos_q3
assert
offset
.
select_modules
(
aggregator
_name
s
=
aggs_q3
).
module_nums
==
modnos_q3
assert
offset
.
select_modules
(
qm_names
=
qm_q3
).
module_nums
==
modnos_q3
# test CalibrationData.select_modules()
assert
dssc_cd
.
select_modules
(
modnos_q3
).
module_nums
==
modnos_q3
assert
dssc_cd
.
select_modules
(
aggregators
=
aggs_q3
).
module_nums
==
modnos_q3
assert
dssc_cd
.
select_modules
(
aggregator
_name
s
=
aggs_q3
).
module_nums
==
modnos_q3
assert
dssc_cd
.
select_modules
(
qm_names
=
qm_q3
).
module_nums
==
modnos_q3
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment