Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
C
calng
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Model registry
Operate
Environments
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
calibration
calng
Commits
c3e1c48f
Commit
c3e1c48f
authored
1 year ago
by
David Hammer
Browse files
Options
Downloads
Patches
Plain Diff
WIP: restructure / simplify stacking execution
parent
ca46ba88
No related branches found
No related tags found
2 merge requests
!74
Refactor DetectorAssembler
,
!73
Refactor stacking for reuse and overlappability
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
src/calng/ShmemTrainMatcher.py
+20
-20
20 additions, 20 deletions
src/calng/ShmemTrainMatcher.py
src/calng/stacking_utils.py
+139
-123
139 additions, 123 deletions
src/calng/stacking_utils.py
with
159 additions
and
143 deletions
src/calng/ShmemTrainMatcher.py
+
20
−
20
View file @
c3e1c48f
...
@@ -69,7 +69,9 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
...
@@ -69,7 +69,9 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
def
initialization
(
self
):
def
initialization
(
self
):
super
().
initialization
()
super
().
initialization
()
self
.
_shmem_handler
=
shmem_utils
.
ShmemCircularBufferReceiver
()
self
.
_shmem_handler
=
shmem_utils
.
ShmemCircularBufferReceiver
()
self
.
_stacking_friend
=
StackingFriend
(
self
.
get
(
"
merge
"
),
self
.
get
(
"
sources
"
))
self
.
_stacking_friend
=
StackingFriend
(
self
,
self
.
get
(
"
merge
"
),
self
.
get
(
"
sources
"
)
)
self
.
_frameselection_friend
=
FrameselectionFriend
(
self
.
get
(
"
frameSelector
"
))
self
.
_frameselection_friend
=
FrameselectionFriend
(
self
.
get
(
"
frameSelector
"
))
self
.
_thread_pool
=
concurrent
.
futures
.
ThreadPoolExecutor
(
self
.
_thread_pool
=
concurrent
.
futures
.
ThreadPoolExecutor
(
max_workers
=
self
.
get
(
"
processingThreads
"
)
max_workers
=
self
.
get
(
"
processingThreads
"
)
...
@@ -102,22 +104,21 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
...
@@ -102,22 +104,21 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
def
on_matched_data
(
self
,
train_id
,
sources
):
def
on_matched_data
(
self
,
train_id
,
sources
):
frame_selection_mask
=
self
.
_frameselection_friend
.
get_mask
(
sources
)
frame_selection_mask
=
self
.
_frameselection_friend
.
get_mask
(
sources
)
# note: should not do stacking and frame selection for now!
# note: should not do stacking and frame selection for now!
self
.
_stacking_friend
.
prepare_stacking_for_train
(
sources
)
with
self
.
_stacking_friend
.
stacking_context
as
stacker
:
concurrent
.
futures
.
wait
(
concurrent
.
futures
.
wait
(
[
[
self
.
_thread_pool
.
submit
(
self
.
_thread_pool
.
submit
(
self
.
_handle_source
,
self
.
_handle_source
,
source
,
source
,
data
,
data
,
timestamp
,
timestamp
,
stacker
,
new_sources_map
,
frame_selection_mask
,
frame_selection_mask
,
)
)
for
source
,
(
data
,
timestamp
)
in
sources
.
items
()
for
source
,
(
data
,
timestamp
)
in
sources
.
items
()
]
]
)
)
sources
.
update
(
stacker
.
new_source_map
)
sources
.
update
(
new_sources_map
)
# karabo output
# karabo output
if
self
.
output
is
not
None
:
if
self
.
output
is
not
None
:
...
@@ -141,10 +142,9 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
...
@@ -141,10 +142,9 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
source
,
source
,
data_hash
,
data_hash
,
timestamp
,
timestamp
,
new_sources_map
,
stacker
,
frame_selection_mask
,
frame_selection_mask
,
ignore_stacking
,
):
):
self
.
_shmem_handler
.
dereference_shmem_handles
(
data_hash
)
self
.
_shmem_handler
.
dereference_shmem_handles
(
data_hash
)
self
.
_frameselection_friend
.
apply_mask
(
source
,
data_hash
,
frame_selection_mask
)
self
.
_frameselection_friend
.
apply_mask
(
source
,
data_hash
,
frame_selection_mask
)
s
elf
.
_stacking_friend
.
handle_source
(...
)
s
tacker
.
process
(
source
,
data_hash
)
This diff is collapsed.
Click to expand it.
src/calng/stacking_utils.py
+
139
−
123
View file @
c3e1c48f
import
collections
import
enum
import
enum
import
re
import
re
...
@@ -79,6 +80,22 @@ def merge_schema():
...
@@ -79,6 +80,22 @@ def merge_schema():
.
defaultValue
(
0
)
.
defaultValue
(
0
)
.
reconfigurable
()
.
reconfigurable
()
.
commit
(),
.
commit
(),
STRING_ELEMENT
(
schema
)
.
key
(
"
missingValue
"
)
.
displayedName
(
"
Missing value default
"
)
.
description
(
"
If some sources are missing within one group in multi-source stacking*,
"
"
the corresponding parts of the resulting stacked array will be set to
"
"
this value. Note that if no sources are present, the array is not created
"
"
at all. This field is a string to allow special values like float nan /
"
"
inf; it is your responsibility to make sure that data types match (i.e.
"
"
if this is
'
nan
'
, the stacked data better be floats or doubles). *Missing
"
"
value handling is not yet implementedo for multi-key stacking.
"
)
.
assignmentOptional
()
.
defaultValue
(
"
0
"
)
.
reconfigurable
()
.
commit
(),
)
)
return
schema
return
schema
...
@@ -121,147 +138,132 @@ class StackingFriend:
...
@@ -121,147 +138,132 @@ class StackingFriend:
.
commit
(),
.
commit
(),
)
)
def
__init__
(
self
,
merge_config
,
source_config
):
def
__init__
(
self
,
device
,
source_config
,
merge_config
):
self
.
_stacking_buffers
=
{}
self
.
_source_stacking_indices
=
{}
self
.
_source_stacking_indices
=
{}
self
.
_source_stacking_sources
=
{}
self
.
_source_stacking_sources
=
collections
.
defaultdict
(
list
)
self
.
_source_stacking_group_sizes
=
{}
self
.
_source_stacking_group_sizes
=
{}
self
.
_key_stacking_sources
=
{}
# (new source name, key) -> {original sources used}
self
.
_merge_config
=
Hash
()
self
.
_new_sources_inputs
=
collections
.
defaultdict
(
set
)
self
.
_source_config
=
Hash
()
self
.
_key_stacking_sources
=
collections
.
defaultdict
(
list
)
self
.
_merge_config
=
None
self
.
_source_config
=
None
self
.
_device
=
device
self
.
reconfigure
(
merge_config
,
source_config
)
self
.
reconfigure
(
merge_config
,
source_config
)
def
reconfigure
(
self
,
merge_config
,
source_config
):
def
reconfigure
(
self
,
merge_config
,
source_config
):
print
(
"
merge_config
"
,
type
(
merge_config
))
print
(
"
source_config
"
,
type
(
source_config
))
if
merge_config
is
not
None
:
if
merge_config
is
not
None
:
self
.
_merge_config
.
merge
(
merge_config
)
self
.
_merge_config
=
merge_config
if
source_config
is
not
None
:
if
source_config
is
not
None
:
self
.
_source_config
.
merge
(
source_config
)
self
.
_source_config
=
source_config
# not filtering by row["select"] to allow unselected sources to create gaps
source_names
=
[
row
[
"
source
"
].
partition
(
"
@
"
)[
0
]
for
row
in
self
.
_source_config
]
self
.
_stacking_buffers
.
clear
()
self
.
_source_stacking_indices
.
clear
()
self
.
_source_stacking_indices
.
clear
()
self
.
_source_stacking_sources
.
clear
()
self
.
_source_stacking_sources
.
clear
()
self
.
_source_stacking_group_sizes
.
clear
()
self
.
_source_stacking_group_sizes
.
clear
()
self
.
_key_stacking_sources
.
clear
()
self
.
_key_stacking_sources
.
clear
()
# split by type, prepare regexes
self
.
_new_sources_inputs
.
clear
()
for
row
in
self
.
_merge_config
:
if
not
row
[
"
select
"
]:
# not filtering by row["select"] to allow unselected sources to create gaps
continue
source_names
=
[
row
[
"
source
"
].
partition
(
"
@
"
)[
0
]
for
row
in
self
.
_source_config
]
group_type
=
GroupType
(
row
[
"
groupType
"
])
source_stacking_groups
=
[
row
for
row
in
self
.
_merge_config
if
row
[
"
select
"
]
and
row
[
"
groupType
"
]
==
GroupType
.
MULTISOURCE
.
name
]
key_stacking_groups
=
[
row
for
row
in
self
.
_merge_config
if
row
[
"
select
"
]
and
row
[
"
groupType
"
]
==
GroupType
.
MULTIKEY
.
name
]
for
row
in
source_stacking_groups
:
source_re
=
re
.
compile
(
row
[
"
sourcePattern
"
])
source_re
=
re
.
compile
(
row
[
"
sourcePattern
"
])
merge_method
=
MergeMethod
(
row
[
"
mergeMethod
"
])
merge_method
=
MergeMethod
(
row
[
"
mergeMethod
"
])
axis
=
row
[
"
axis
"
]
key
=
row
[
"
keyPattern
"
]
if
group_type
is
GroupType
.
MULTISOURCE
:
new_source
=
row
[
"
replacement
"
]
key
=
row
[
"
keyPattern
"
]
merge_sources
=
[
new_source
=
row
[
"
replacement
"
]
source
for
source
in
source_names
if
source_re
.
match
(
source
)
merge_sources
=
[
]
source
for
source
in
source_names
if
source_re
.
match
(
source
)
if
len
(
merge_sources
)
==
0
:
]
self
.
_device
.
log
.
WARN
(
if
len
(
merge_sources
)
==
0
:
f
"
Group pattern
{
source_re
}
did not match any known sources
"
self
.
log
.
WARN
(
f
"
Group pattern
{
source_re
}
did not match any known sources
"
)
continue
self
.
_source_stacking_group_sizes
[(
new_source
,
key
)]
=
len
(
merge_sources
)
)
for
i
,
source
in
enumerate
(
merge_sources
):
continue
self
.
_source_stacking_sources
.
setdefault
(
source
,
[]).
append
(
self
.
_source_stacking_group_sizes
[(
new_source
,
key
)]
=
len
(
merge_sources
)
(
key
,
new_source
,
merge_method
,
axis
)
for
i
,
source
in
enumerate
(
merge_sources
):
)
self
.
_source_stacking_sources
[
source
].
append
(
self
.
_source_stacking_indices
[(
source
,
new_source
,
key
)]
=
(
(
key
,
new_source
,
merge_method
,
row
[
"
axis
"
])
i
)
if
merge_method
is
MergeMethod
.
STACK
self
.
_source_stacking_indices
[(
source
,
new_source
,
key
)]
=
(
else
np
.
index_exp
[
i
slice
(
if
merge_method
is
MergeMethod
.
STACK
i
,
else
np
.
index_exp
[
# interleaving
None
,
slice
(
self
.
_source_stacking_group_sizes
[(
new_source
,
key
)],
i
,
)
None
,
]
len
(
merge_sources
),
)
)
else
:
]
key_re
=
re
.
compile
(
row
[
"
keyPattern
"
])
new_key
=
row
[
"
replacement
"
]
self
.
_key_stacking_sources
.
setdefault
(
source
,
[]).
append
(
(
key_re
,
new_key
,
merge_method
,
axis
)
)
)
def
prepare_stacking_for_train
(
for
row
in
key_stacking_groups
:
self
,
sources
,
frame_selection_mask
,
new_sources_map
key_re
=
re
.
compile
(
row
[
"
keyPattern
"
])
):
new_key
=
row
[
"
replacement
"
]
if
frame_selection_mask
is
not
None
:
self
.
_key_stacking_sources
[
source
].
append
(
orig_size
=
len
(
frame_selection_mask
)
(
key_re
,
new_key
,
MergeMethod
(
row
[
"
mergeMethod
"
]),
row
[
"
axis
"
]
)
result_size
=
np
.
sum
(
frame_selection_mask
)
)
def
process
(
self
,
sources
,
thread_pool
=
None
):
stacking_data_shapes
=
{}
stacking_data_shapes
=
{}
self
.
_ignore_stacking
=
{}
stacking_buffers
=
{}
self
.
_fill_missed_data
=
{}
new_source_map
=
collections
.
defaultdict
(
Hash
)
for
source
,
keys
in
self
.
_source_stacking_sources
.
items
():
missing_value_defaults
=
{}
if
source
not
in
sources
:
for
key
,
new_source
,
_
,
_
in
keys
:
# prepare for source stacking where sources are present
missed_sources
=
self
.
_fill_missed_data
.
setdefault
(
source_set
=
set
(
sources
.
keys
())
(
new_source
,
key
),
[]
for
(
new_source
,
data_key
,
merge_method
,
group_size
,
axis
,
missing_value
,
),
original_sources
in
self
.
_new_sources_inputs
.
items
():
for
present_source
in
source_set
&
original_sources
:
data
=
sources
[
present_source
].
get
(
data_key
)[
0
]
if
data
is
None
:
continue
if
merge_method
is
MergeMethod
.
STACK
:
expected_shape
=
utils
.
stacking_buffer_shape
(
data
.
shape
,
group_size
,
axis
=
axis
)
)
merge_index
=
self
.
_source_stacking_indices
[
(
source
,
new_source
,
key
)
]
missed_sources
.
append
(
merge_index
)
continue
data_hash
,
timestamp
=
sources
[
source
]
filtering
=
(
frame_selection_mask
is
not
None
and
self
.
_frame_selection_source_pattern
.
match
(
source
)
)
for
key
,
new_source
,
merge_method
,
axis
in
keys
:
merge_data_shape
=
None
if
key
in
data_hash
:
merge_data
=
data_hash
[
key
]
merge_data_shape
=
merge_data
.
shape
else
:
else
:
self
.
_ignore_stacking
[(
new_source
,
key
)]
=
"
Some data is missed
"
expected_shape
=
utils
.
interleaving_buffer_shape
(
continue
data
.
shape
,
group_size
,
axis
=
axis
)
if
filtering
and
key
in
self
.
_frame_selection_data_keys
:
stacking_buffer
=
np
.
empty
(
# !!! stacking is not expected to be used with filtering
expected_shape
,
dtype
=
data
.
dtype
if
merge_data_shape
[
0
]
==
orig_size
:
merge_data_shape
=
(
result_size
,)
+
merge_data
.
shape
[
1
:]
(
expected_shape
,
_
,
_
,
expected_dtype
,
_
,
)
=
stacking_data_shapes
.
setdefault
(
(
new_source
,
key
),
(
merge_data_shape
,
merge_method
,
axis
,
merge_data
.
dtype
,
timestamp
),
)
if
(
expected_shape
!=
merge_data_shape
or
expected_dtype
!=
merge_data
.
dtype
):
self
.
_ignore_stacking
[
(
new_source
,
key
)
]
=
"
Shape or dtype is inconsistent
"
del
stacking_data_shapes
[(
new_source
,
key
)]
for
(
new_source
,
key
),
msg
in
self
.
_ignore_stacking
.
items
():
self
.
_device
.
log
.
WARN
(
f
"
Failed to stack
{
new_source
}
.
{
key
}
:
{
msg
}
"
)
for
(
new_source
,
key
),
attr
in
stacking_data_shapes
.
items
():
merge_data_shape
,
merge_method
,
axis
,
dtype
,
timestamp
=
attr
group_size
=
self
.
_source_stacking_group_sizes
[(
new_source
,
key
)]
if
merge_method
is
MergeMethod
.
STACK
:
expected_shape
=
utils
.
stacking_buffer_shape
(
merge_data_shape
,
group_size
,
axis
=
axis
)
)
stacking_buffers
[(
new_source
,
data_key
)]
=
stacking_buffer
new_source_map
[
new_source
][
data_key
]
=
stacking_buffer
try
:
missing_value_defaults
[(
new_source
,
data_key
)]
=
data
.
dtype
.
type
(
missing_value
)
except
ValueError
:
self
.
_device
.
log
.
WARN
(
f
"
Invalid missing data value for
{
new_source
}
.
{
data_key
}
, using 0
"
)
break
else
:
else
:
expected_shape
=
utils
.
interleaving_buffer_shape
(
# in this case: no present_source (if any) had data_key
merge_data_shape
,
group_size
,
axis
=
axis
self
.
_device
.
log
.
WARN
(
f
"
No sources needed for
{
new_source
}
.
{
data_key
}
were present
"
)
)
for
(
new_source
,
key
),
attr
in
stacking_data_shapes
.
items
():
merge_data_shape
,
merge_method
,
axis
,
dtype
,
timestamp
=
attr
group_size
=
parent
.
_source_stacking_group_sizes
[(
new_source
,
key
)]
merge_buffer
=
self
.
_stacking_buffers
.
get
((
new_source
,
key
))
merge_buffer
=
self
.
_stacking_buffers
.
get
((
new_source
,
key
))
if
merge_buffer
is
None
or
merge_buffer
.
shape
!=
expected_shape
:
if
merge_buffer
is
None
or
merge_buffer
.
shape
!=
expected_shape
:
merge_buffer
=
np
.
empty
(
shape
=
expected_shape
,
dtype
=
dtype
)
merge_buffer
=
np
.
empty
(
shape
=
expected_shape
,
dtype
=
dtype
)
...
@@ -277,19 +279,33 @@ class StackingFriend:
...
@@ -277,19 +279,33 @@ class StackingFriend:
slice
(
slice
(
merge_index
,
merge_index
,
None
,
None
,
self
.
_source_stacking_group_sizes
[(
new_source
,
key
)],
parent
.
_source_stacking_group_sizes
[(
new_source
,
key
)],
)
)
],
],
axis
,
axis
,
)
)
if
new_source
not
in
new_source
s
_map
:
if
new_source
not
in
self
.
new_source_map
:
new_source
s
_map
[
new_source
]
=
(
Hash
(),
timestamp
)
self
.
new_source_map
[
new_source
]
=
(
Hash
(),
timestamp
)
new_source_hash
=
new_source
s
_map
[
new_source
][
0
]
new_source_hash
=
self
.
new_source_map
[
new_source
][
0
]
if
not
new_source_hash
.
has
(
key
):
if
not
new_source_hash
.
has
(
key
):
new_source_hash
[
key
]
=
merge_buffer
new_source_hash
[
key
]
=
merge_buffer
def
handle_source
(
self
,
source
,
data_hash
):
# now actually do some work
# stack across sources (many sources, same key)
fun
=
functools
.
partial
(
self
.
_handle_source
,
...)
if
thread_pool
is
None
:
for
_
in
map
(
fun
,
...):
pass
else
:
concurrent
.
futures
.
wait
(
thread_pool
.
map
(
fun
,
...))
def
_handle_expected_source
(
self
,
merge_buffers
,
missing_values
,
actual_sources
,
expected_source
):
"""
Helper function used in processing. Note that it should be called for each
source that was supposed to be there - so it can decide whether to move data in
for stacking, fill missing data in case a source is missing, or skip in case no
buffer was created (none of the necessary sources were present).
"""
if
expected_source
not
in
actual_sources
:
if
ex
for
(
for
(
key
,
key
,
new_source
,
new_source
,
...
...
This diff is collapsed.
Click to expand it.
David Hammer
@hammerd
mentioned in commit
209bb3f8
·
1 year ago
mentioned in commit
209bb3f8
mentioned in commit 209bb3f8766132fa7ceb70e0db9fa573e7b1a56b
Toggle commit list
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment