Skip to content
Snippets Groups Projects
Commit 66969e22 authored by Laurent Mercadier's avatar Laurent Mercadier
Browse files

Concatenate: simplifies sorting and add attrs of each run

parent 61e4a8d5
No related branches found
No related tags found
No related merge requests found
...@@ -453,7 +453,7 @@ def load(fields, runNB, proposalNB, subFolder='raw', display=False, validate=Fal ...@@ -453,7 +453,7 @@ def load(fields, runNB, proposalNB, subFolder='raw', display=False, validate=Fal
return result return result
def concatenateRuns(runs): def concatenateRuns(runs):
""" Concatenate a list of two runs with identical data variables along the """ Sorts and concatenate a list of runs with identical data variables along the
trainId dimension. trainId dimension.
Input: Input:
...@@ -461,12 +461,16 @@ def concatenateRuns(runs): ...@@ -461,12 +461,16 @@ def concatenateRuns(runs):
Output: Output:
a concatenated xarray Dataset a concatenated xarray Dataset
""" """
keys = sorted(runs[0].keys()) firstTid = {i: int(run.trainId[0].values) for i,run in enumerate(runs)}
for run in runs[1:]: orderedDict = dict(sorted(firstTid.items(), key=lambda t: t[1]))
if sorted(run.keys()) != keys: orderedRuns = [runs[i] for i in orderedDict]
keys = orderedRuns[0].keys()
for run in orderedRuns[1:]:
if run.keys() != keys:
print('data fields between different runs are not identical. Cannot combine runs.') print('data fields between different runs are not identical. Cannot combine runs.')
return return
result = xr.concat(runs, dim='trainId') result = xr.concat(orderedRuns, dim='trainId')
result = result.sortby(result.trainId) result.attrs['run'] = [run.attrs['run'] for run in orderedRuns]
result.attrs['runFolder'] = [run.attrs['runFolder'] for run in orderedRuns]
return result return result
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment