calibration / calng · Commits · 47016a6d

Commit 47016a6d · authored 1 year ago by David Hammer

WIP: move stacking implementation to friend class

Parent: f9a0d9b1
Part of 2 merge requests: !74 (Refactor DetectorAssembler), !73 (Refactor stacking for reuse and overlappability)

Showing 2 changed files with 349 additions and 318 deletions:
src/calng/ShmemTrainMatcher.py (+21, −318)
src/calng/stacking_utils.py (+328, −0)
src/calng/ShmemTrainMatcher.py  +21 −318
 import concurrent.futures
-import enum
 import re
 
 import numpy as np
 from karabo.bound import (
     BOOL_ELEMENT,
-    INT32_ELEMENT,
     KARABO_CLASSINFO,
     NODE_ELEMENT,
     OVERWRITE_ELEMENT,
     STRING_ELEMENT,
-    TABLE_ELEMENT,
     VECTOR_STRING_ELEMENT,
     ChannelMetaData,
-    Hash,
-    Schema,
     State,
 )
 from TrainMatcher import TrainMatcher
 
-from . import shmem_utils, utils
+from . import shmem_utils
+from .stacking_utils import StackingFriend
 from ._version import version as deviceVersion
 
 
-class GroupType(enum.Enum):
-    MULTISOURCE = "sources"  # same key stacked from multiple sources in new source
-    MULTIKEY = "keys"  # multiple keys within each matched source is stacked in new key
-
-
-class MergeMethod(enum.Enum):
-    STACK = "stack"
-    INTERLEAVE = "interleave"
-
-
-def merge_schema():
-    schema = Schema()
-    (
-        BOOL_ELEMENT(schema)
-        .key("select")
-        .displayedName("Select")
-        .assignmentOptional()
-        .defaultValue(False)
-        .reconfigurable()
-        .commit(),
-
-        STRING_ELEMENT(schema)
-        .key("sourcePattern")
-        .displayedName("Source pattern")
-        .assignmentOptional()
-        .defaultValue("")
-        .reconfigurable()
-        .commit(),
-
-        STRING_ELEMENT(schema)
-        .key("keyPattern")
-        .displayedName("Key pattern")
-        .assignmentOptional()
-        .defaultValue("")
-        .reconfigurable()
-        .commit(),
-
-        STRING_ELEMENT(schema)
-        .key("replacement")
-        .displayedName("Replacement")
-        .assignmentOptional()
-        .defaultValue("")
-        .reconfigurable()
-        .commit(),
-
-        STRING_ELEMENT(schema)
-        .key("groupType")
-        .displayedName("Group type")
-        .options(",".join(option.value for option in GroupType))
-        .assignmentOptional()
-        .defaultValue(GroupType.MULTISOURCE.value)
-        .reconfigurable()
-        .commit(),
-
-        STRING_ELEMENT(schema)
-        .key("mergeMethod")
-        .displayedName("Merge method")
-        .options(",".join(option.value for option in MergeMethod))
-        .assignmentOptional()
-        .defaultValue(MergeMethod.STACK.value)
-        .reconfigurable()
-        .commit(),
-
-        INT32_ELEMENT(schema)
-        .key("axis")
-        .displayedName("Axis")
-        .assignmentOptional()
-        .defaultValue(0)
-        .reconfigurable()
-        .commit(),
-    )
-    return schema
-
-
 @KARABO_CLASSINFO("ShmemTrainMatcher", deviceVersion)
 class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
     @staticmethod
     def expectedParameters(expected):
         (
-            TABLE_ELEMENT(expected)
-            .key("merge")
-            .displayedName("Array stacking")
-            .allowedStates(State.PASSIVE)
-            .description(
-                "Specify which source(s) or key(s) to stack or interleave. "
-                "When stacking sources, the 'Source pattern' is interpreted as a "
-                "regular expression and the 'Key pattern' is interpreted as an "
-                "ordinary string. From all sources matching the source pattern, the "
-                "data under this key (should be array with same dimensions across all "
-                "stacked sources) is stacked in the same order as the sources are "
-                "listed in 'Data sources' and the result is under the same key name in "
-                "a new source named by 'Replacement'. "
-                "When stacking keys, both the 'Source pattern' and the 'Key pattern' "
-                "are regular expressions. Within each source matching the source "
-                "pattern, all keys matching the key pattern are stacked and the result "
-                "is put under the key named by 'Replacement'. "
-                "While source stacking is optimized and can use thread pool, key "
-                "stacking will iterate over all paths in matched sources and naively "
-                "call np.stack for each key pattern. In either case, data that is used "
-                "for stacking is removed from its original location (e.g. key is "
-                "erased from hash). "
-                "In both cases, the data can alternatively be interleaved. This is "
-                "essentially equivalent to stacking except followed by a reshape such "
-                "that the output is shaped like concatenation."
-            )
-            .setColumns(merge_schema())
-            .assignmentOptional()
-            .defaultValue([])
-            .reconfigurable()
-            .commit(),
-
             # order is important for stacking, disable sorting
             OVERWRITE_ELEMENT(expected)
             .key("sortSources")
@@ -217,6 +106,7 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
             .reconfigurable()
             .commit(),
         )
+        StackingFriend.add_schema(expected)
 
     def __init__(self, config):
         if config.get("useInfiniband", default=True):
@@ -226,14 +116,8 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
         super().__init__(config)
 
     def initialization(self):
-        self._stacking_buffers = {}
-        self._source_stacking_indices = {}
-        self._source_stacking_sources = {}
-        self._source_stacking_group_sizes = {}
-        self._key_stacking_sources = {}
-        self._have_prepared_merge_groups = False
-        self._prepare_merge_groups()
         super().initialization()
+        self._stacking_friend = StackingFriend(self.get("merge"), self.get("sources"))
         self._shmem_handler = shmem_utils.ShmemCircularBufferReceiver()
         if self.get("useThreadPool"):
             self._thread_pool = concurrent.futures.ThreadPoolExecutor()
@@ -256,7 +140,7 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
     def preReconfigure(self, conf):
         super().preReconfigure(conf)
         if conf.has("merge") or conf.has("sources"):
-            self._have_prepared_merge_groups = False
+            self._stacking_friend.reconfigure(conf.get("merge"), conf.get("sources"))
             # re-prepare in postReconfigure after sources *and* merge are in self
         if conf.has("useThreadPool"):
             if self._thread_pool is not None:
@@ -274,8 +158,6 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
     def postReconfigure(self):
         super().postReconfigure()
-        if not self._have_prepared_merge_groups:
-            self._prepare_merge_groups()
         if not self._have_prepared_frame_selection:
             self._prepare_frame_selection()
@@ -287,136 +169,14 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
         )
         self._frame_selection_data_keys = list(self.get("frameSelector.dataKeys"))
 
-    def _prepare_merge_groups(self):
-        # not filtering by row["select"] to allow unselected sources to create gaps
-        source_names = [row["source"].partition("@")[0] for row in self.get("sources")]
-        self._stacking_buffers.clear()
-        self._source_stacking_indices.clear()
-        self._source_stacking_sources.clear()
-        self._source_stacking_group_sizes.clear()
-        self._key_stacking_sources.clear()
-        # split by type, prepare regexes
-        for row in self.get("merge"):
-            if not row["select"]:
-                continue
-            group_type = GroupType(row["groupType"])
-            source_re = re.compile(row["sourcePattern"])
-            merge_method = MergeMethod(row["mergeMethod"])
-            axis = row["axis"]
-            if group_type is GroupType.MULTISOURCE:
-                key = row["keyPattern"]
-                new_source = row["replacement"]
-                merge_sources = [
-                    source for source in source_names if source_re.match(source)
-                ]
-                if len(merge_sources) == 0:
-                    self.log.WARN(
-                        f"Group pattern {source_re} did not match any known sources"
-                    )
-                    continue
-                for (i, source) in enumerate(merge_sources):
-                    self._source_stacking_sources.setdefault(source, []).append(
-                        (key, new_source, merge_method, axis)
-                    )
-                    self._source_stacking_indices[(source, new_source, key)] = i
-                self._source_stacking_group_sizes[(new_source, key)] = i + 1
-            else:
-                key_re = re.compile(row["keyPattern"])
-                new_key = row["replacement"]
-                self._key_stacking_sources.setdefault(source, []).append(
-                    (key_re, new_key, merge_method, axis)
-                )
-        self._have_prepared_merge_groups = True
-
-    def _check_stacking_data(self, sources, frame_selection_mask):
-        if frame_selection_mask is not None:
-            orig_size = len(frame_selection_mask)
-            result_size = np.sum(frame_selection_mask)
-        stacking_data_shapes = {}
-        ignore_stacking = {}
-        fill_missed_data = {}
-        for source, keys in self._source_stacking_sources.items():
-            if source not in sources:
-                for key, new_source, _, _ in keys:
-                    missed_sources = fill_missed_data.setdefault((new_source, key), [])
-                    merge_index = self._source_stacking_indices[
-                        (source, new_source, key)
-                    ]
-                    missed_sources.append(merge_index)
-                continue
-            data_hash, timestamp = sources[source]
-            filtering = (
-                frame_selection_mask is not None
-                and self._frame_selection_source_pattern.match(source)
-            )
-            for key, new_source, merge_method, axis in keys:
-                merge_data_shape = None
-                if key in data_hash:
-                    merge_data = data_hash[key]
-                    merge_data_shape = merge_data.shape
-                else:
-                    ignore_stacking[(new_source, key)] = "Some data is missed"
-                    continue
-                if filtering and key in self._frame_selection_data_keys:
-                    # !!! stacking is not expected to be used with filtering
-                    if merge_data_shape[0] == orig_size:
-                        merge_data_shape = (result_size,) + merge_data.shape[1:]
-                expected_shape, _, _, expected_dtype, _ = stacking_data_shapes.setdefault(
-                    (new_source, key),
-                    (merge_data_shape, merge_method, axis, merge_data.dtype, timestamp),
-                )
-                if expected_shape != merge_data_shape or expected_dtype != merge_data.dtype:
-                    ignore_stacking[(new_source, key)] = "Shape or dtype is inconsistent"
-                    del stacking_data_shapes[(new_source, key)]
-        return stacking_data_shapes, ignore_stacking, fill_missed_data
-
-    def _maybe_update_stacking_buffers(
-        self, stacking_data_shapes, fill_missed_data, new_sources_map
-    ):
-        for (new_source, key), attr in stacking_data_shapes.items():
-            merge_data_shape, merge_method, axis, dtype, timestamp = attr
-            group_size = self._source_stacking_group_sizes[(new_source, key)]
-            if merge_method is MergeMethod.STACK:
-                expected_shape = utils.stacking_buffer_shape(
-                    merge_data_shape, group_size, axis=axis
-                )
-            else:
-                expected_shape = utils.interleaving_buffer_shape(
-                    merge_data_shape, group_size, axis=axis
-                )
-            merge_buffer = self._stacking_buffers.get((new_source, key))
-            if merge_buffer is None or merge_buffer.shape != expected_shape:
-                merge_buffer = np.empty(shape=expected_shape, dtype=dtype)
-                self._stacking_buffers[(new_source, key)] = merge_buffer
-            for merge_index in fill_missed_data.get((new_source, key), []):
-                utils.set_on_axis(
-                    merge_buffer,
-                    0,
-                    merge_index
-                    if merge_method is MergeMethod.STACK
-                    else np.index_exp[
-                        slice(
-                            merge_index,
-                            None,
-                            self._source_stacking_group_sizes[(new_source, key)],
-                        )
-                    ],
-                    axis,
-                )
-            if new_source not in new_sources_map:
-                new_sources_map[new_source] = (Hash(), timestamp)
-            new_source_hash = new_sources_map[new_source][0]
-            if not new_source_hash.has(key):
-                new_source_hash[key] = merge_buffer
-
     def _handle_source(
         self,
         source,
         data_hash,
         timestamp,
         new_sources_map,
         frame_selection_mask,
         ignore_stacking,
     ):
         # dereference calng shmem handles
         self._shmem_handler.dereference_shmem_handles(data_hash)
@@ -434,68 +194,7 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
             data_hash[key] = data_hash[key][frame_selection_mask]
 
-        # stack across sources (many sources, same key)
-        for (
-            key,
-            new_source,
-            merge_method,
-            axis,
-        ) in self._source_stacking_sources.get(source, ()):
-            if (new_source, key) in ignore_stacking:
-                continue
-            merge_data = data_hash.get(key)
-            merge_index = self._source_stacking_indices[
-                (source, new_source, key)
-            ]
-            merge_buffer = self._stacking_buffers[(new_source, key)]
-            utils.set_on_axis(
-                merge_buffer,
-                merge_data,
-                merge_index
-                if merge_method is MergeMethod.STACK
-                else np.index_exp[
-                    slice(
-                        merge_index,
-                        None,
-                        self._source_stacking_group_sizes[(new_source, key)],
-                    )
-                ],
-                axis,
-            )
-            data_hash.erase(key)
-
-        # stack keys (multiple keys within this source)
-        for (key_re, new_key, merge_method, axis) in self._key_stacking_sources.get(
-            source, ()
-        ):
-            # note: please no overlap between different key_re
-            # note: if later key_re match earlier new_key, this gets spicy
-            keys = [key for key in data_hash.paths() if key_re.match(key)]
-            try:
-                # note: maybe we could reuse buffers here, too?
-                if merge_method is MergeMethod.STACK:
-                    stacked = np.stack(
-                        [data_hash.get(key) for key in keys], axis=axis
-                    )
-                else:
-                    first = data_hash.get(keys[0])
-                    stacked = np.empty(
-                        shape=utils.stacking_buffer_shape(
-                            first.shape, len(keys), axis=axis
-                        ),
-                        dtype=first.dtype,
-                    )
-                    for i, key in enumerate(keys):
-                        utils.set_on_axis(
-                            stacked,
-                            data_hash.get(key),
-                            np.index_exp[slice(i, None, len(keys))],
-                            axis,
-                        )
-            except Exception as e:
-                self.log.WARN(f"Failed to stack {key_re} for {source}: {e}")
-            else:
-                for key in keys:
-                    data_hash.erase(key)
-                data_hash[new_key] = stacked
+        self._stacking_friend.handle_source(...)
 
     def on_matched_data(self, train_id, sources):
         new_sources_map = {}
@@ -508,11 +207,15 @@ class ShmemTrainMatcher(TrainMatcher.TrainMatcher):
             ).astype(np.bool, copy=False),
 
         # prepare stacking
-        stacking_data_shapes, ignore_stacking, fill_missed_data = (
-            self._check_stacking_data(sources, frame_selection_mask)
-        )
-        self._maybe_update_stacking_buffers(
-            stacking_data_shapes, fill_missed_data, new_sources_map
-        )
+        (
+            stacking_data_shapes,
+            ignore_stacking,
+            fill_missed_data,
+        ) = self._stacking_friend.check_stacking_data(sources, frame_selection_mask)
+        self._stacking_friend.maybe_update_stacking_buffers(
+            stacking_data_shapes, fill_missed_data, new_sources_map
+        )
         for (new_source, key), msg in ignore_stacking.items():
             self.log.WARN(f"Failed to stack {new_source}.{key}: {msg}")
...
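Both the removed methods above and the new StackingFriend below lean on three helpers from calng's utils module — stacking_buffer_shape, interleaving_buffer_shape and set_on_axis — whose implementations are not part of this commit. Judging only from the call sites in the diff, they plausibly behave like the following sketch (the names come from the diff; the bodies are assumptions, not the real calng code):

import numpy as np


def stacking_buffer_shape(shape, group_size, axis=0):
    # a new axis of length group_size inserted at position `axis`,
    # as np.stack would produce
    return tuple(shape[:axis]) + (group_size,) + tuple(shape[axis:])


def interleaving_buffer_shape(shape, group_size, axis=0):
    # the existing axis grows by a factor of group_size, so the result
    # is shaped like a concatenation along that axis
    shape = list(shape)
    shape[axis] *= group_size
    return tuple(shape)


def set_on_axis(target, value, index, axis):
    # write `value` into `target` at `index` along `axis`; per the call
    # sites, `index` is an integer for stacking or an index expression
    # like np.index_exp[i::n] for interleaving
    if not isinstance(index, tuple):
        index = np.index_exp[index]
    target[(slice(None),) * axis + index] = value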
src/calng/stacking_utils.py  (new file, 0 → 100644)  +328 −0
import enum
import re

from karabo.bound import (
    BOOL_ELEMENT,
    INT32_ELEMENT,
    STRING_ELEMENT,
    TABLE_ELEMENT,
    Hash,
    Schema,
)


class GroupType(enum.Enum):
    MULTISOURCE = "sources"  # same key stacked from multiple sources in new source
    MULTIKEY = "keys"  # multiple keys within each matched source is stacked in new key


class MergeMethod(enum.Enum):
    STACK = "stack"
    INTERLEAVE = "interleave"


def merge_schema():
    schema = Schema()
    (
        BOOL_ELEMENT(schema)
        .key("select")
        .displayedName("Select")
        .assignmentOptional()
        .defaultValue(False)
        .reconfigurable()
        .commit(),

        STRING_ELEMENT(schema)
        .key("sourcePattern")
        .displayedName("Source pattern")
        .assignmentOptional()
        .defaultValue("")
        .reconfigurable()
        .commit(),

        STRING_ELEMENT(schema)
        .key("keyPattern")
        .displayedName("Key pattern")
        .assignmentOptional()
        .defaultValue("")
        .reconfigurable()
        .commit(),

        STRING_ELEMENT(schema)
        .key("replacement")
        .displayedName("Replacement")
        .assignmentOptional()
        .defaultValue("")
        .reconfigurable()
        .commit(),

        STRING_ELEMENT(schema)
        .key("groupType")
        .displayedName("Group type")
        .options(",".join(option.value for option in GroupType))
        .assignmentOptional()
        .defaultValue(GroupType.MULTISOURCE.value)
        .reconfigurable()
        .commit(),

        STRING_ELEMENT(schema)
        .key("mergeMethod")
        .displayedName("Merge method")
        .options(",".join(option.value for option in MergeMethod))
        .assignmentOptional()
        .defaultValue(MergeMethod.STACK.value)
        .reconfigurable()
        .commit(),

        INT32_ELEMENT(schema)
        .key("axis")
        .displayedName("Axis")
        .assignmentOptional()
        .defaultValue(0)
        .reconfigurable()
        .commit(),
    )
    return schema


class StackingFriend:
    @staticmethod
    def add_schema(self, schema):
        (
            TABLE_ELEMENT(expected)
            .key("merge")
            .displayedName("Array stacking")
            .allowedStates(State.PASSIVE)
            .description(
                "Specify which source(s) or key(s) to stack or interleave. "
                "When stacking sources, the 'Source pattern' is interpreted as a "
                "regular expression and the 'Key pattern' is interpreted as an "
                "ordinary string. From all sources matching the source pattern, the "
                "data under this key (should be array with same dimensions across all "
                "stacked sources) is stacked in the same order as the sources are "
                "listed in 'Data sources' and the result is under the same key name in "
                "a new source named by 'Replacement'. "
                "When stacking keys, both the 'Source pattern' and the 'Key pattern' "
                "are regular expressions. Within each source matching the source "
                "pattern, all keys matching the key pattern are stacked and the result "
                "is put under the key named by 'Replacement'. "
                "While source stacking is optimized and can use thread pool, key "
                "stacking will iterate over all paths in matched sources and naively "
                "call np.stack for each key pattern. In either case, data that is used "
                "for stacking is removed from its original location (e.g. key is "
                "erased from hash). "
                "In both cases, the data can alternatively be interleaved. This is "
                "essentially equivalent to stacking except followed by a reshape such "
                "that the output is shaped like concatenation."
            )
            .setColumns(merge_schema())
            .assignmentOptional()
            .defaultValue([])
            .reconfigurable()
            .commit(),
        )

    def __init__(self, merge_config, source_config):
        self._stacking_buffers = {}
        self._source_stacking_indices = {}
        self._source_stacking_sources = {}
        self._source_stacking_group_sizes = {}
        self._key_stacking_sources = {}
        self._merge_config = Hash()
        self._source_config = Hash()
        self.reconfigure(merge_config, source_config)

    def reconfigure(self, merge_config, source_config):
        if merge_config is not None:
            self._merge_config.merge(merge_config)
        if source_config is not None:
            self._source_config.merge(source_config)
        # not filtering by row["select"] to allow unselected sources to create gaps
        source_names = [
            row["source"].partition("@")[0] for row in self._source_config
        ]
        self._stacking_buffers.clear()
        self._source_stacking_indices.clear()
        self._source_stacking_sources.clear()
        self._source_stacking_group_sizes.clear()
        self._key_stacking_sources.clear()
        # split by type, prepare regexes
        for row in self._merge_config:
            if not row["select"]:
                continue
            group_type = GroupType(row["groupType"])
            source_re = re.compile(row["sourcePattern"])
            merge_method = MergeMethod(row["mergeMethod"])
            axis = row["axis"]
            if group_type is GroupType.MULTISOURCE:
                key = row["keyPattern"]
                new_source = row["replacement"]
                merge_sources = [
                    source for source in source_names if source_re.match(source)
                ]
                if len(merge_sources) == 0:
                    self.log.WARN(
                        f"Group pattern {source_re} did not match any known sources"
                    )
                    continue
                for (i, source) in enumerate(merge_sources):
                    self._source_stacking_sources.setdefault(source, []).append(
                        (key, new_source, merge_method, axis)
                    )
                    self._source_stacking_indices[(source, new_source, key)] = i
                self._source_stacking_group_sizes[(new_source, key)] = i + 1
            else:
                key_re = re.compile(row["keyPattern"])
                new_key = row["replacement"]
                self._key_stacking_sources.setdefault(source, []).append(
                    (key_re, new_key, merge_method, axis)
                )

    def check_stacking_data(self, sources, frame_selection_mask):
        if frame_selection_mask is not None:
            orig_size = len(frame_selection_mask)
            result_size = np.sum(frame_selection_mask)
        stacking_data_shapes = {}
        ignore_stacking = {}
        fill_missed_data = {}
        for source, keys in self._source_stacking_sources.items():
            if source not in sources:
                for key, new_source, _, _ in keys:
                    missed_sources = fill_missed_data.setdefault((new_source, key), [])
                    merge_index = self._source_stacking_indices[
                        (source, new_source, key)
                    ]
                    missed_sources.append(merge_index)
                continue
            data_hash, timestamp = sources[source]
            filtering = (
                frame_selection_mask is not None
                and self._frame_selection_source_pattern.match(source)
            )
            for key, new_source, merge_method, axis in keys:
                merge_data_shape = None
                if key in data_hash:
                    merge_data = data_hash[key]
                    merge_data_shape = merge_data.shape
                else:
                    ignore_stacking[(new_source, key)] = "Some data is missed"
                    continue
                if filtering and key in self._frame_selection_data_keys:
                    # !!! stacking is not expected to be used with filtering
                    if merge_data_shape[0] == orig_size:
                        merge_data_shape = (result_size,) + merge_data.shape[1:]
                expected_shape, _, _, expected_dtype, _ = stacking_data_shapes.setdefault(
                    (new_source, key),
                    (merge_data_shape, merge_method, axis, merge_data.dtype, timestamp),
                )
                if expected_shape != merge_data_shape or expected_dtype != merge_data.dtype:
                    ignore_stacking[(new_source, key)] = "Shape or dtype is inconsistent"
                    del stacking_data_shapes[(new_source, key)]
        return stacking_data_shapes, ignore_stacking, fill_missed_data

    def maybe_update_stacking_buffers(
        self, stacking_data_shapes, fill_missed_data, new_sources_map
    ):
        for (new_source, key), attr in stacking_data_shapes.items():
            merge_data_shape, merge_method, axis, dtype, timestamp = attr
            group_size = self._source_stacking_group_sizes[(new_source, key)]
            if merge_method is MergeMethod.STACK:
                expected_shape = utils.stacking_buffer_shape(
                    merge_data_shape, group_size, axis=axis
                )
            else:
                expected_shape = utils.interleaving_buffer_shape(
                    merge_data_shape, group_size, axis=axis
                )
            merge_buffer = self._stacking_buffers.get((new_source, key))
            if merge_buffer is None or merge_buffer.shape != expected_shape:
                merge_buffer = np.empty(shape=expected_shape, dtype=dtype)
                self._stacking_buffers[(new_source, key)] = merge_buffer
            for merge_index in fill_missed_data.get((new_source, key), []):
                utils.set_on_axis(
                    merge_buffer,
                    0,
                    merge_index
                    if merge_method is MergeMethod.STACK
                    else np.index_exp[
                        slice(
                            merge_index,
                            None,
                            self._source_stacking_group_sizes[(new_source, key)],
                        )
                    ],
                    axis,
                )
            if new_source not in new_sources_map:
                new_sources_map[new_source] = (Hash(), timestamp)
            new_source_hash = new_sources_map[new_source][0]
            if not new_source_hash.has(key):
                new_source_hash[key] = merge_buffer

    def handle_source(...):
        # stack across sources (many sources, same key)
        for (
            key,
            new_source,
            merge_method,
            axis,
        ) in self._source_stacking_sources.get(source, ()):
            if (new_source, key) in ignore_stacking:
                continue
            merge_data = data_hash.get(key)
            merge_index = self._source_stacking_indices[
                (source, new_source, key)
            ]
            merge_buffer = self._stacking_buffers[(new_source, key)]
            utils.set_on_axis(
                merge_buffer,
                merge_data,
                merge_index
                if merge_method is MergeMethod.STACK
                else np.index_exp[
                    slice(
                        merge_index,
                        None,
                        self._source_stacking_group_sizes[(new_source, key)],
                    )
                ],
                axis,
            )
            data_hash.erase(key)

        # stack keys (multiple keys within this source)
        for (key_re, new_key, merge_method, axis) in self._key_stacking_sources.get(
            source, ()
        ):
            # note: please no overlap between different key_re
            # note: if later key_re match earlier new_key, this gets spicy
            keys = [key for key in data_hash.paths() if key_re.match(key)]
            try:
                # note: maybe we could reuse buffers here, too?
                if merge_method is MergeMethod.STACK:
                    stacked = np.stack(
                        [data_hash.get(key) for key in keys], axis=axis
                    )
                else:
                    first = data_hash.get(keys[0])
                    stacked = np.empty(
                        shape=utils.stacking_buffer_shape(
                            first.shape, len(keys), axis=axis
                        ),
                        dtype=first.dtype,
                    )
                    for i, key in enumerate(keys):
                        utils.set_on_axis(
                            stacked,
                            data_hash.get(key),
                            np.index_exp[slice(i, None, len(keys))],
                            axis,
                        )
            except Exception as e:
                self.log.WARN(f"Failed to stack {key_re} for {source}: {e}")
            else:
                for key in keys:
                    data_hash.erase(key)
                data_hash[new_key] = stacked
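For reference, a row of the "Array stacking" table that this schema defines would look something like the following (the column names are from merge_schema above; the source and key names here are invented purely for illustration):

example_merge_row = {
    "select": True,
    # for groupType "sources", sourcePattern is a regex over the matched sources...
    "sourcePattern": r"DET/CORR/(\d+)CH0:output",
    # ...and keyPattern is a plain key name
    "keyPattern": "image.data",
    # the stacked result appears under the same key in this new source
    "replacement": "DET/CORR/STACKED:output",
    "groupType": "sources",  # GroupType.MULTISOURCE
    "mergeMethod": "stack",  # or "interleave"
    "axis": 0,
}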
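The description's claim that interleaving is "essentially equivalent to stacking except followed by a reshape" can be checked with plain numpy; this snippet stands alone from the Karabo code and inlines the strided write that handle_source performs via set_on_axis:

import numpy as np

# three toy "sources", each contributing a (2, 4) array under the same key
arrays = [np.full((2, 4), i) for i in range(3)]

# mergeMethod "stack" along axis 0: result shape (3, 2, 4)
stacked = np.stack(arrays, axis=0)

# mergeMethod "interleave" along axis 0: result shape (6, 4), rows cycling
# through the sources; this is the strided write from handle_source
interleaved = np.empty((2 * 3, 4), dtype=arrays[0].dtype)
for i, arr in enumerate(arrays):
    interleaved[i::3] = arr

# interleaving equals stacking on the *following* axis plus a reshape to
# the concatenation-like shape
assert np.array_equal(interleaved, np.stack(arrays, axis=1).reshape(6, 4))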
David Hammer (@hammerd) mentioned in commit 209bb3f8766132fa7ceb70e0db9fa573e7b1a56b · 1 year ago