Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
C
calng
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Model registry
Operate
Environments
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
calibration
calng
Commits
02a4ccd9
Commit
02a4ccd9
authored
2 years ago
by
David Hammer
Browse files
Options
Downloads
Patches
Plain Diff
Allow stacking on other axes, fix thread pool
parent
b2b5aca1
No related branches found
Branches containing commit
No related tags found
2 merge requests
!10
DetectorAssembler: assemble with extra shape (multiple frames)
,
!9
Stacking shmem matcher
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
src/calng/ShmemTrainMatcher.py
+70
-33
70 additions, 33 deletions
src/calng/ShmemTrainMatcher.py
src/calng/utils.py
+25
-0
25 additions, 0 deletions
src/calng/utils.py
with
95 additions
and
33 deletions
src/calng/ShmemTrainMatcher.py
+
70
−
33
View file @
02a4ccd9
...
@@ -5,6 +5,7 @@ import re
...
@@ -5,6 +5,7 @@ import re
import
numpy
as
np
import
numpy
as
np
from
karabo.bound
import
(
from
karabo.bound
import
(
BOOL_ELEMENT
,
BOOL_ELEMENT
,
INT32_ELEMENT
,
KARABO_CLASSINFO
,
KARABO_CLASSINFO
,
STRING_ELEMENT
,
STRING_ELEMENT
,
TABLE_ELEMENT
,
TABLE_ELEMENT
,
...
@@ -14,7 +15,7 @@ from karabo.bound import (
...
@@ -14,7 +15,7 @@ from karabo.bound import (
)
)
from
TrainMatcher
import
TrainMatcher
from
TrainMatcher
import
TrainMatcher
from
.
import
shmem_utils
from
.
import
shmem_utils
,
utils
from
._version
import
version
as
deviceVersion
from
._version
import
version
as
deviceVersion
...
@@ -35,7 +36,7 @@ def merge_schema():
...
@@ -35,7 +36,7 @@ def merge_schema():
.
commit
(),
.
commit
(),
STRING_ELEMENT
(
schema
)
STRING_ELEMENT
(
schema
)
.
key
(
"
source
_p
attern
"
)
.
key
(
"
source
P
attern
"
)
.
displayedName
(
"
Source pattern
"
)
.
displayedName
(
"
Source pattern
"
)
.
assignmentOptional
()
.
assignmentOptional
()
.
defaultValue
(
""
)
.
defaultValue
(
""
)
...
@@ -43,7 +44,7 @@ def merge_schema():
...
@@ -43,7 +44,7 @@ def merge_schema():
.
commit
(),
.
commit
(),
STRING_ELEMENT
(
schema
)
STRING_ELEMENT
(
schema
)
.
key
(
"
key
_p
attern
"
)
.
key
(
"
key
P
attern
"
)
.
displayedName
(
"
Key pattern
"
)
.
displayedName
(
"
Key pattern
"
)
.
assignmentOptional
()
.
assignmentOptional
()
.
defaultValue
(
""
)
.
defaultValue
(
""
)
...
@@ -66,6 +67,14 @@ def merge_schema():
...
@@ -66,6 +67,14 @@ def merge_schema():
.
defaultValue
(
MergeGroupType
.
MULTISOURCE
.
value
)
.
defaultValue
(
MergeGroupType
.
MULTISOURCE
.
value
)
.
reconfigurable
()
.
reconfigurable
()
.
commit
(),
.
commit
(),
INT32_ELEMENT
(
schema
)
.
key
(
"
stackingAxis
"
)
.
displayedName
(
"
Axis
"
)
.
assignmentOptional
()
.
defaultValue
(
0
)
.
reconfigurable
()
.
commit
(),
)
)
return
schema
return
schema
...
@@ -78,7 +87,7 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
...
@@ -78,7 +87,7 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
(
(
TABLE_ELEMENT
(
expected
)
TABLE_ELEMENT
(
expected
)
.
key
(
"
merge
"
)
.
key
(
"
merge
"
)
.
displayedName
(
"
Array
merg
ing
"
)
.
displayedName
(
"
Array
stack
ing
"
)
.
allowedStates
(
State
.
PASSIVE
)
.
allowedStates
(
State
.
PASSIVE
)
.
description
(
.
description
(
"
List source or key patterns to merge their data arrays, e.g. to
"
"
List source or key patterns to merge their data arrays, e.g. to
"
...
@@ -95,8 +104,11 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
...
@@ -95,8 +104,11 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
BOOL_ELEMENT
(
expected
)
BOOL_ELEMENT
(
expected
)
.
key
(
"
useThreadPool
"
)
.
key
(
"
useThreadPool
"
)
.
displayedName
(
"
Use thread pool
"
)
.
allowedStates
(
State
.
PASSIVE
)
.
assignmentOptional
()
.
assignmentOptional
()
.
defaultValue
(
False
)
.
defaultValue
(
False
)
.
reconfigurable
()
.
commit
(),
.
commit
(),
)
)
...
@@ -117,6 +129,12 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
...
@@ -117,6 +129,12 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
super
().
preReconfigure
(
conf
)
super
().
preReconfigure
(
conf
)
if
conf
.
has
(
"
merge
"
)
or
conf
.
has
(
"
sources
"
):
if
conf
.
has
(
"
merge
"
)
or
conf
.
has
(
"
sources
"
):
self
.
_prepare_merge_groups
(
conf
[
"
merge
"
])
self
.
_prepare_merge_groups
(
conf
[
"
merge
"
])
if
conf
.
has
(
"
useThreadPool
"
):
if
self
.
_thread_pool
is
not
None
:
self
.
_thread_pool
.
shutdown
()
self
.
_thread_pool
=
None
if
conf
[
"
useThreadPool
"
]:
self
.
_thread_pool
=
concurrent
.
futures
.
ThreadPoolExecutor
()
def
_prepare_merge_groups
(
self
,
merge
):
def
_prepare_merge_groups
(
self
,
merge
):
source_group_patterns
=
[]
source_group_patterns
=
[]
...
@@ -129,16 +147,17 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
...
@@ -129,16 +147,17 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
if
group_type
is
MergeGroupType
.
MULTISOURCE
:
if
group_type
is
MergeGroupType
.
MULTISOURCE
:
source_group_patterns
.
append
(
source_group_patterns
.
append
(
(
(
re
.
compile
(
row
[
"
source
_p
attern
"
]),
re
.
compile
(
row
[
"
source
P
attern
"
]),
row
[
"
key
_p
attern
"
],
row
[
"
key
P
attern
"
],
row
[
"
replacement
"
],
row
[
"
replacement
"
],
row
[
"
stackingAxis
"
],
)
)
)
)
else
:
else
:
key_group_patterns
.
append
(
key_group_patterns
.
append
(
(
(
re
.
compile
(
row
[
"
source
_p
attern
"
]),
re
.
compile
(
row
[
"
source
P
attern
"
]),
re
.
compile
(
row
[
"
key
_p
attern
"
]),
re
.
compile
(
row
[
"
key
P
attern
"
]),
row
[
"
replacement
"
],
row
[
"
replacement
"
],
)
)
)
)
...
@@ -150,7 +169,7 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
...
@@ -150,7 +169,7 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
# handle source stacking groups
# handle source stacking groups
self
.
_source_stacking_indices
.
clear
()
self
.
_source_stacking_indices
.
clear
()
self
.
_source_stacking_sources
.
clear
()
self
.
_source_stacking_sources
.
clear
()
for
source_re
,
key
,
new_source
in
source_group_patterns
:
for
source_re
,
key
,
new_source
,
stack_axis
in
source_group_patterns
:
merge_sources
=
[
merge_sources
=
[
source
for
source
in
source_names
if
source_re
.
match
(
source
)
source
for
source
in
source_names
if
source_re
.
match
(
source
)
]
]
...
@@ -161,9 +180,9 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
...
@@ -161,9 +180,9 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
continue
continue
for
(
i
,
source
)
in
enumerate
(
merge_sources
):
for
(
i
,
source
)
in
enumerate
(
merge_sources
):
self
.
_source_stacking_sources
.
setdefault
(
source
,
[]).
append
(
self
.
_source_stacking_sources
.
setdefault
(
source
,
[]).
append
(
(
key
,
new_source
)
(
key
,
new_source
,
stack_axis
)
)
)
self
.
_source_stacking_indices
[(
source
,
key
)]
=
i
self
.
_source_stacking_indices
[(
source
,
new_source
,
key
)]
=
i
# handle key stacking groups
# handle key stacking groups
self
.
_key_stacking_sources
.
clear
()
self
.
_key_stacking_sources
.
clear
()
...
@@ -176,43 +195,61 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
...
@@ -176,43 +195,61 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
(
key_re
,
new_key
)
(
key_re
,
new_key
)
)
)
def
_update_stacking_buffer
(
self
,
new_source
,
key
,
individual_shape
,
axis
,
dtype
):
# TODO: handle ValueError for max of empty sequence
stack_num
=
(
max
(
index
for
(
_
,
new_source_
,
key_
,
),
index
in
self
.
_source_stacking_indices
.
items
()
if
new_source_
==
new_source
and
key_
==
key
)
+
1
)
self
.
_stacking_buffers
[(
new_source
,
key
)]
=
np
.
empty
(
shape
=
utils
.
stacking_buffer_shape
(
individual_shape
,
stack_num
,
axis
=
axis
),
dtype
=
dtype
,
)
def
_handle_source
(
self
,
source
,
data_hash
,
timestamp
,
new_sources_map
):
def
_handle_source
(
self
,
source
,
data_hash
,
timestamp
,
new_sources_map
):
# dereference calng shmem handles
# dereference calng shmem handles
self
.
_shmem_handler
.
dereference_shmem_handles
(
data_hash
)
self
.
_shmem_handler
.
dereference_shmem_handles
(
data_hash
)
# stack across sources (many sources, same key)
# stack across sources (many sources, same key)
# could probably save ~100 ns by "if ... in" instead of get
# could probably save ~100 ns by "if ... in" instead of get
for
(
stack_key
,
new_source
)
in
self
.
_source_stacking_sources
.
get
(
source
,
()):
for
(
stack_key
,
new_source
,
stack_axis
)
in
self
.
_source_stacking_sources
.
get
(
source
,
()
):
this_data
=
data_hash
.
get
(
stack_key
)
this_data
=
data_hash
.
get
(
stack_key
)
try
:
try
:
this_buffer
=
self
.
_stacking_buffers
[(
new_source
,
stack_key
)]
this_buffer
=
self
.
_stacking_buffers
[(
new_source
,
stack_key
)]
stack_index
=
self
.
_source_stacking_indices
[(
source
,
stack_key
)]
stack_index
=
self
.
_source_stacking_indices
[
this_buffer
[
stack_index
]
=
this_data
(
source
,
new_source
,
stack_key
)
]
utils
.
set_on_axis
(
this_buffer
,
this_data
,
stack_index
,
stack_axis
)
except
(
ValueError
,
IndexError
,
KeyError
):
except
(
ValueError
,
IndexError
,
KeyError
):
# ValueError: wrong shape
# ValueError: wrong shape
(react to this_data.shape)
# KeyError: buffer doesn't exist yet
# KeyError: buffer doesn't exist yet
# IndexError: new source? (buffer not long enough)
# either way, create appropriate buffer now
# either way, create appropriate buffer now
# TODO: complain if shape varies between sources
# TODO: complain if shape varies between sources within train
self
.
_stacking_buffers
[(
new_source
,
stack_key
)]
=
np
.
empty
(
self
.
_update_stacking_buffer
(
shape
=
(
new_source
,
max
(
stack_key
,
index_
this_data
.
shape
,
for
(
axis
=
stack_axis
,
source_
,
key_
,
),
index_
in
self
.
_source_stacking_indices
.
items
()
if
source_
==
source
and
key_
==
stack_key
)
+
1
,
)
+
this_data
.
shape
,
dtype
=
this_data
.
dtype
,
dtype
=
this_data
.
dtype
,
)
)
# and then try again
# and then try again
this_buffer
=
self
.
_stacking_buffers
[(
new_source
,
stack_key
)]
this_buffer
=
self
.
_stacking_buffers
[(
new_source
,
stack_key
)]
stack_index
=
self
.
_source_stacking_indices
[(
source
,
stack_key
)]
stack_index
=
self
.
_source_stacking_indices
[
this_buffer
[
stack_index
]
=
this_data
(
source
,
new_source
,
stack_key
)
# TODO: zero out unfilled buffer entries
]
utils
.
set_on_axis
(
this_buffer
,
this_data
,
stack_index
,
stack_axis
)
# TODO: zero out unfilled buffer entries
data_hash
.
erase
(
stack_key
)
data_hash
.
erase
(
stack_key
)
if
new_source
not
in
new_sources_map
:
if
new_source
not
in
new_sources_map
:
...
@@ -246,7 +283,7 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
...
@@ -246,7 +283,7 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
concurrent
.
futures
.
wait
(
concurrent
.
futures
.
wait
(
[
[
self
.
_thread_pool
.
submit
(
self
.
_thread_pool
.
submit
(
self
.
_handle_source
,
data
,
timestamp
,
new_sources_map
self
.
_handle_source
,
source
,
data
,
timestamp
,
new_sources_map
)
)
for
source
,
(
data
,
timestamp
)
in
sources
.
items
()
for
source
,
(
data
,
timestamp
)
in
sources
.
items
()
]
]
...
...
This diff is collapsed.
Click to expand it.
src/calng/utils.py
+
25
−
0
View file @
02a4ccd9
...
@@ -111,6 +111,31 @@ def transpose_order(axes_in, axes_out):
...
@@ -111,6 +111,31 @@ def transpose_order(axes_in, axes_out):
return
tuple
(
axis_order
[
axis
]
for
axis
in
axes_out
)
return
tuple
(
axis_order
[
axis
]
for
axis
in
axes_out
)
def
stacking_buffer_shape
(
array_shape
,
stack_num
,
axis
=
0
):
"""
Figures out the shape you would need for np.stack
"""
if
axis
>
len
(
array_shape
)
or
axis
<
-
len
(
array_shape
)
-
1
:
# complain when np.stack would
raise
np
.
AxisError
(
f
"
axis
{
axis
}
is out of bounds
"
f
"
for array of dimension
{
len
(
array_shape
)
+
1
}
"
)
if
axis
<
0
:
axis
+=
len
(
array_shape
)
+
1
return
array_shape
[:
axis
]
+
(
stack_num
,)
+
array_shape
[
axis
:]
def
set_on_axis
(
array
,
vals
,
index
,
axis
):
"""
set_on_axis(A, x, 1, 2) corresponds to A[:, :, 1] = x
"""
if
axis
>=
len
(
array
):
raise
IndexError
(
f
"
too many indices for array: array is
{
len
(
array
.
shape
)
}
-dimensional,
"
f
"
but
{
axis
+
1
}
were indexed
"
)
# TODO: maybe support negative axis with wraparound
indices
=
np
.
index_exp
[:]
*
axis
+
np
.
index_exp
[
index
]
array
[
indices
]
=
vals
_np_typechar_to_c_typestring
=
{
_np_typechar_to_c_typestring
=
{
"
?
"
:
"
bool
"
,
"
?
"
:
"
bool
"
,
"
B
"
:
"
unsigned char
"
,
"
B
"
:
"
unsigned char
"
,
...
...
This diff is collapsed.
Click to expand it.
David Hammer
@hammerd
mentioned in commit
9495dda8
·
2 years ago
mentioned in commit
9495dda8
mentioned in commit 9495dda895756fcd9a2ebb5f8adbaa050e067f33
Toggle commit list
David Hammer
@hammerd
mentioned in commit
99de26dd
·
2 years ago
mentioned in commit
99de26dd
mentioned in commit 99de26dda08d02c1c2b5efc1f1f462cfc4f0ce63
Toggle commit list
David Hammer
@hammerd
mentioned in commit
3566342d
·
2 years ago
mentioned in commit
3566342d
mentioned in commit 3566342d031db1bb435b6430f100afdbaaa6c4bf
Toggle commit list
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment