Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
P
pycalibration
Manage
Activity
Members
Labels
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Deploy
Model registry
Analyze
Contributor analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
calibration
pycalibration
Commits
54f8fc83
Commit
54f8fc83
authored
1 year ago
by
Thomas Kluyver
Browse files
Options
Downloads
Patches
Plain Diff
Fix select_modules()
parent
9847e518
No related branches found
No related tags found
1 merge request
!885
Revised CalCat API
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
src/cal_tools/calcat_interface2.py
+48
-40
48 additions, 40 deletions
src/cal_tools/calcat_interface2.py
tests/test_calcat_interface2.py
+2
-2
2 additions, 2 deletions
tests/test_calcat_interface2.py
with
50 additions
and
42 deletions
src/cal_tools/calcat_interface2.py
+
48
−
40
View file @
54f8fc83
...
@@ -226,10 +226,11 @@ class ModulesConstantVersions:
...
@@ -226,10 +226,11 @@ class ModulesConstantVersions:
module_details
:
List
[
Dict
]
module_details
:
List
[
Dict
]
def
select_modules
(
def
select_modules
(
self
,
module_nums
=
None
,
*
,
aggregators
=
None
,
qm_names
=
None
self
,
module_nums
=
None
,
*
,
aggregator
_name
s
=
None
,
qm_names
=
None
)
->
"
ModulesConstantVersions
"
:
)
->
"
ModulesConstantVersions
"
:
aggs
=
aggregator_names
# Shorter name -> fewer multi-line statements
n_specified
=
sum
(
n_specified
=
sum
(
[
module_nums
is
not
None
,
agg
regator
s
is
not
None
,
qm_names
is
not
None
]
[
module_nums
is
not
None
,
aggs
is
not
None
,
qm_names
is
not
None
]
)
)
if
n_specified
<
1
:
if
n_specified
<
1
:
raise
TypeError
(
"
select_modules() requires an argument
"
)
raise
TypeError
(
"
select_modules() requires an argument
"
)
...
@@ -240,22 +241,23 @@ class ModulesConstantVersions:
...
@@ -240,22 +241,23 @@ class ModulesConstantVersions:
if
module_nums
is
not
None
:
if
module_nums
is
not
None
:
by_mod_no
=
{
m
[
"
module_number
"
]:
m
for
m
in
self
.
module_details
}
by_mod_no
=
{
m
[
"
module_number
"
]:
m
for
m
in
self
.
module_details
}
agg
regator
s
=
[
by_mod_no
[
n
][
"
karabo_da
"
]
for
n
in
module_nums
]
aggs
=
[
by_mod_no
[
n
][
"
karabo_da
"
]
for
n
in
module_nums
]
elif
qm_names
is
not
None
:
elif
qm_names
is
not
None
:
by_qm
=
{
m
[
"
virtual_device_name
"
]:
m
for
m
in
self
.
module_details
}
by_qm
=
{
m
[
"
virtual_device_name
"
]:
m
for
m
in
self
.
module_details
}
agg
regator
s
=
[
by_qm
[
s
][
"
karabo_da
"
]
for
s
in
qm_names
]
aggs
=
[
by_qm
[
s
][
"
karabo_da
"
]
for
s
in
qm_names
]
elif
agg
regator
s
is
not
None
:
elif
aggs
is
not
None
:
miss
=
set
(
agg
regator
s
)
-
{
m
[
"
karabo_da
"
]
for
m
in
self
.
module_details
}
miss
=
set
(
aggs
)
-
{
m
[
"
karabo_da
"
]
for
m
in
self
.
module_details
}
if
miss
:
if
miss
:
raise
KeyError
(
"
Aggregators not found:
"
+
"
,
"
.
join
(
sorted
(
miss
)))
raise
KeyError
(
"
Aggregators not found:
"
+
"
,
"
.
join
(
sorted
(
miss
)))
d
=
{
aggr
:
scv
for
(
aggr
,
scv
)
in
self
.
constants
.
items
()
if
aggr
in
aggregators
}
d
=
{
aggr
:
scv
for
(
aggr
,
scv
)
in
self
.
constants
.
items
()
if
aggr
in
aggs
}
return
ModulesConstantVersions
(
d
,
self
.
module_details
)
mods
=
[
m
for
m
in
self
.
module_details
if
m
[
"
karabo_da
"
]
in
d
]
return
ModulesConstantVersions
(
d
,
mods
)
# These properties label only the modules we have constants for, which may
# These properties label only the modules we have constants for, which may
# be a subset of what's in module_details
# be a subset of what's in module_details
@property
@property
def
aggregators
(
self
):
def
aggregator
_name
s
(
self
):
return
sorted
(
self
.
constants
)
return
sorted
(
self
.
constants
)
@property
@property
...
@@ -275,10 +277,10 @@ class ModulesConstantVersions:
...
@@ -275,10 +277,10 @@ class ModulesConstantVersions:
]
]
def
ndarray
(
self
,
caldb_root
=
None
):
def
ndarray
(
self
,
caldb_root
=
None
):
eg_dset
=
self
.
constants
[
self
.
aggregators
[
0
]].
dataset_obj
(
caldb_root
)
eg_dset
=
self
.
constants
[
self
.
aggregator
_name
s
[
0
]].
dataset_obj
(
caldb_root
)
shape
=
(
len
(
self
.
constants
),)
+
eg_dset
.
shape
shape
=
(
len
(
self
.
constants
),)
+
eg_dset
.
shape
arr
=
np
.
zeros
(
shape
,
eg_dset
.
dtype
)
arr
=
np
.
zeros
(
shape
,
eg_dset
.
dtype
)
for
i
,
agg
in
enumerate
(
self
.
aggregators
):
for
i
,
agg
in
enumerate
(
self
.
aggregator
_name
s
):
dset
=
self
.
constants
[
agg
].
dataset_obj
(
caldb_root
)
dset
=
self
.
constants
[
agg
].
dataset_obj
(
caldb_root
)
dset
.
read_direct
(
arr
[
i
])
dset
.
read_direct
(
arr
[
i
])
return
arr
return
arr
...
@@ -287,7 +289,7 @@ class ModulesConstantVersions:
...
@@ -287,7 +289,7 @@ class ModulesConstantVersions:
import
xarray
import
xarray
if
module_naming
==
"
da
"
:
if
module_naming
==
"
da
"
:
modules
=
self
.
aggregators
modules
=
self
.
aggregator
_name
s
elif
module_naming
==
"
modno
"
:
elif
module_naming
==
"
modno
"
:
modules
=
self
.
module_nums
modules
=
self
.
module_nums
elif
module_naming
==
"
qm
"
:
elif
module_naming
==
"
qm
"
:
...
@@ -300,7 +302,7 @@ class ModulesConstantVersions:
...
@@ -300,7 +302,7 @@ class ModulesConstantVersions:
# Dimension labels
# Dimension labels
dims
=
[
"
module
"
]
+
[
"
dim_%d
"
%
i
for
i
in
range
(
ndarr
.
ndim
-
1
)]
dims
=
[
"
module
"
]
+
[
"
dim_%d
"
%
i
for
i
in
range
(
ndarr
.
ndim
-
1
)]
coords
=
{
"
module
"
:
modules
}
coords
=
{
"
module
"
:
modules
}
name
=
self
.
constants
[
self
.
aggregators
[
0
]].
constant_name
name
=
self
.
constants
[
self
.
aggregator
_name
s
[
0
]].
constant_name
return
xarray
.
DataArray
(
ndarr
,
dims
=
dims
,
coords
=
coords
,
name
=
name
)
return
xarray
.
DataArray
(
ndarr
,
dims
=
dims
,
coords
=
coords
,
name
=
name
)
...
@@ -340,10 +342,7 @@ class CalibrationData(Mapping):
...
@@ -340,10 +342,7 @@ class CalibrationData(Mapping):
"""
Collected constants for a given detector
"""
"""
Collected constants for a given detector
"""
def
__init__
(
self
,
constant_groups
,
module_details
):
def
__init__
(
self
,
constant_groups
,
module_details
):
self
.
constant_groups
=
{
self
.
constant_groups
=
constant_groups
const_type
:
ModulesConstantVersions
(
d
,
module_details
)
for
const_type
,
d
in
constant_groups
.
items
()
}
self
.
module_details
=
module_details
self
.
module_details
=
module_details
@staticmethod
@staticmethod
...
@@ -398,7 +397,7 @@ class CalibrationData(Mapping):
...
@@ -398,7 +397,7 @@ class CalibrationData(Mapping):
if
mod
.
get
(
"
module_number
"
,
-
1
)
<
0
:
if
mod
.
get
(
"
module_number
"
,
-
1
)
<
0
:
mod
[
"
module_number
"
]
=
int
(
re
.
findall
(
r
"
\d+
"
,
mod
[
"
karabo_da
"
])[
-
1
])
mod
[
"
module_number
"
]
=
int
(
re
.
findall
(
r
"
\d+
"
,
mod
[
"
karabo_da
"
])[
-
1
])
d
=
{}
constant_groups
=
{}
for
params
,
cal_types
in
cal_types_by_params_used
.
items
():
for
params
,
cal_types
in
cal_types_by_params_used
.
items
():
condition_dict
=
condition
.
make_dict
(
params
)
condition_dict
=
condition
.
make_dict
(
params
)
...
@@ -424,11 +423,14 @@ class CalibrationData(Mapping):
...
@@ -424,11 +423,14 @@ class CalibrationData(Mapping):
aggr
=
ccv
[
"
physical_detector_unit
"
][
"
karabo_da
"
]
aggr
=
ccv
[
"
physical_detector_unit
"
][
"
karabo_da
"
]
cal_type
=
cal_id_map
[
ccv
[
"
calibration_constant
"
][
"
calibration_id
"
]]
cal_type
=
cal_id_map
[
ccv
[
"
calibration_constant
"
][
"
calibration_id
"
]]
d
.
setdefault
(
cal_type
,
{})[
aggr
]
=
SingleConstantVersion
.
from_response
(
const_group
=
constant_groups
.
setdefault
(
cal_type
,
{})
ccv
const_group
[
aggr
]
=
SingleConstantVersion
.
from_response
(
ccv
)
)
return
cls
(
d
,
module_details
)
mcvs
=
{
const_type
:
ModulesConstantVersions
(
d
,
module_details
)
for
const_type
,
d
in
constant_groups
.
items
()
}
return
cls
(
mcvs
,
module_details
)
@classmethod
@classmethod
def
from_report
(
def
from_report
(
...
@@ -447,16 +449,22 @@ class CalibrationData(Mapping):
...
@@ -447,16 +449,22 @@ class CalibrationData(Mapping):
res
=
client
.
get
(
"
calibration_constant_versions
"
,
params
)
res
=
client
.
get
(
"
calibration_constant_versions
"
,
params
)
d
=
{}
constant_groups
=
{}
pdus
=
[]
pdus
=
[]
for
ccv
in
res
:
for
ccv
in
res
:
pdus
.
append
(
ccv
[
"
physical_detector_unit
"
])
pdus
.
append
(
ccv
[
"
physical_detector_unit
"
])
cal_type
=
calibration_name
(
ccv
[
"
calibration_constant
"
][
"
calibration_id
"
])
cal_type
=
calibration_name
(
ccv
[
"
calibration_constant
"
][
"
calibration_id
"
])
aggr
=
ccv
[
"
physical_detector_unit
"
][
"
karabo_da
"
]
aggr
=
ccv
[
"
physical_detector_unit
"
][
"
karabo_da
"
]
d
.
setdefault
(
cal_type
,
{})[
aggr
]
=
SingleConstantVersion
.
from_response
(
ccv
)
const_group
=
constant_groups
.
setdefault
(
cal_type
,
{})
const_group
[
aggr
]
=
SingleConstantVersion
.
from_response
(
ccv
)
return
cls
(
d
,
sorted
(
pdus
,
key
=
lambda
d
:
d
[
"
karabo_da
"
]))
module_details
=
sorted
(
pdus
,
key
=
lambda
d
:
d
[
"
karabo_da
"
])
mcvs
=
{
const_type
:
ModulesConstantVersions
(
d
,
module_details
)
for
const_type
,
d
in
constant_groups
.
items
()
}
return
cls
(
mcvs
,
module_details
)
def
__getitem__
(
self
,
key
)
->
ModulesConstantVersions
:
def
__getitem__
(
self
,
key
)
->
ModulesConstantVersions
:
return
self
.
constant_groups
[
key
]
return
self
.
constant_groups
[
key
]
...
@@ -494,7 +502,7 @@ class CalibrationData(Mapping):
...
@@ -494,7 +502,7 @@ class CalibrationData(Mapping):
def
require_calibrations
(
self
,
calibrations
):
def
require_calibrations
(
self
,
calibrations
):
"""
Drop any modules missing the specified constant types
"""
"""
Drop any modules missing the specified constant types
"""
mods
=
set
(
self
.
aggregators
)
mods
=
set
(
self
.
aggregator
_name
s
)
for
cal_type
in
calibrations
:
for
cal_type
in
calibrations
:
mods
.
intersection_update
(
self
[
cal_type
].
constants
)
mods
.
intersection_update
(
self
[
cal_type
].
constants
)
return
self
.
select_modules
(
mods
)
return
self
.
select_modules
(
mods
)
...
@@ -514,19 +522,19 @@ class CalibrationData(Mapping):
...
@@ -514,19 +522,19 @@ class CalibrationData(Mapping):
return
type
(
self
)(
d
,
self
.
aggregators
)
return
type
(
self
)(
d
,
self
.
aggregators
)
def
select_modules
(
def
select_modules
(
self
,
module_nums
=
None
,
*
,
aggregators
=
None
,
qm_names
=
None
self
,
module_nums
=
None
,
*
,
aggregator
_name
s
=
None
,
qm_names
=
None
)
->
"
CalibrationData
"
:
)
->
"
CalibrationData
"
:
return
type
(
self
)(
mcvs
=
{
{
cal_type
:
mcv
.
select_modules
(
cal_type
:
mcv
.
select_modules
(
module_nums
=
module_nums
,
module_nums
=
module_num
s
,
aggregator_names
=
aggregator_name
s
,
aggregators
=
aggregator
s
,
qm_names
=
qm_name
s
,
qm_names
=
qm_names
,
)
).
constants
for
(
cal_type
,
mcv
)
in
self
.
constant_groups
.
items
()
for
(
cal_type
,
mcv
)
in
self
.
constant_groups
.
items
()
}
},
aggs
=
set
().
union
(
*
[
c
.
aggregator_names
for
c
in
mcvs
.
values
()])
sorted
(
aggregators
),
module_details
=
[
m
for
m
in
self
.
module_details
if
m
[
"
karabo_da
"
]
in
aggs
]
)
return
type
(
self
)(
mcvs
,
module_details
)
def
merge
(
self
,
*
others
:
"
CalibrationData
"
)
->
"
CalibrationData
"
:
def
merge
(
self
,
*
others
:
"
CalibrationData
"
)
->
"
CalibrationData
"
:
d
=
{}
d
=
{}
...
@@ -536,9 +544,9 @@ class CalibrationData(Mapping):
...
@@ -536,9 +544,9 @@ class CalibrationData(Mapping):
if
cal_type
in
other
:
if
cal_type
in
other
:
d
[
cal_type
].
update
(
other
[
cal_type
].
constants
)
d
[
cal_type
].
update
(
other
[
cal_type
].
constants
)
aggregators
=
set
(
self
.
aggregators
)
aggregators
=
set
(
self
.
aggregator
_name
s
)
for
other
in
others
:
for
other
in
others
:
aggregators
.
update
(
other
.
aggregators
)
aggregators
.
update
(
other
.
aggregator
_name
s
)
return
type
(
self
)(
d
,
sorted
(
aggregators
))
return
type
(
self
)(
d
,
sorted
(
aggregators
))
...
...
This diff is collapsed.
Click to expand it.
tests/test_calcat_interface2.py
+
2
−
2
View file @
54f8fc83
...
@@ -80,12 +80,12 @@ def test_DSSC_modules_missing():
...
@@ -80,12 +80,12 @@ def test_DSSC_modules_missing():
aggs_q3
=
[
f
"
DSSC
{
m
:
02
}
"
for
m
in
modnos_q3
]
aggs_q3
=
[
f
"
DSSC
{
m
:
02
}
"
for
m
in
modnos_q3
]
qm_q3
=
[
f
"
Q3M
{
i
}
"
for
i
in
range
(
1
,
5
)]
qm_q3
=
[
f
"
Q3M
{
i
}
"
for
i
in
range
(
1
,
5
)]
assert
offset
.
select_modules
(
modnos_q3
).
module_nums
==
modnos_q3
assert
offset
.
select_modules
(
modnos_q3
).
module_nums
==
modnos_q3
assert
offset
.
select_modules
(
aggregators
=
aggs_q3
).
module_nums
==
modnos_q3
assert
offset
.
select_modules
(
aggregator
_name
s
=
aggs_q3
).
module_nums
==
modnos_q3
assert
offset
.
select_modules
(
qm_names
=
qm_q3
).
module_nums
==
modnos_q3
assert
offset
.
select_modules
(
qm_names
=
qm_q3
).
module_nums
==
modnos_q3
# test CalibrationData.select_modules()
# test CalibrationData.select_modules()
assert
dssc_cd
.
select_modules
(
modnos_q3
).
module_nums
==
modnos_q3
assert
dssc_cd
.
select_modules
(
modnos_q3
).
module_nums
==
modnos_q3
assert
dssc_cd
.
select_modules
(
aggregators
=
aggs_q3
).
module_nums
==
modnos_q3
assert
dssc_cd
.
select_modules
(
aggregator
_name
s
=
aggs_q3
).
module_nums
==
modnos_q3
assert
dssc_cd
.
select_modules
(
qm_names
=
qm_q3
).
module_nums
==
modnos_q3
assert
dssc_cd
.
select_modules
(
qm_names
=
qm_q3
).
module_nums
==
modnos_q3
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment