From 5040e702f66c8fca9c7b43fcc36624e2478a3d29 Mon Sep 17 00:00:00 2001 From: Thomas Kluyver <thomas@kluyver.me.uk> Date: Tue, 16 May 2023 13:02:08 +0100 Subject: [PATCH] Add reorder_axes function --- src/cal_tools/tools.py | 8 ++++++++ tests/test_cal_tools.py | 11 +++++++++++ 2 files changed, 19 insertions(+) diff --git a/src/cal_tools/tools.py b/src/cal_tools/tools.py index de53ba77e..c71497d89 100644 --- a/src/cal_tools/tools.py +++ b/src/cal_tools/tools.py @@ -1030,3 +1030,11 @@ def write_compressed_frames( dataset.id.write_direct_chunk(chunk_start, compressed) return dataset + + +def reorder_axes(a, from_order, to_order): + assert len(from_order) == a.ndim + assert sorted(from_order) == sorted(to_order) + from_order = list(from_order) + order = tuple([from_order.index(lbl) for lbl in to_order]) + return a.transpose(order) diff --git a/tests/test_cal_tools.py b/tests/test_cal_tools.py index 36d9d3240..5929944c5 100644 --- a/tests/test_cal_tools.py +++ b/tests/test_cal_tools.py @@ -22,6 +22,7 @@ from cal_tools.tools import ( recursive_update, send_to_db, write_constants_fragment, + reorder_axes, ) # AGIPD operating conditions. @@ -614,3 +615,13 @@ def test_write_constants_fragment(tmp_path: Path): }, } } + + +def test_reorder_axes(): + a = np.zeros((10, 32, 256, 3)) + from_order = ('cells', 'slow_scan', 'fast_scan', 'gain') + to_order = ('slow_scan', 'fast_scan', 'cells', 'gain') + assert reorder_axes(a, from_order, to_order).shape == (32, 256, 10, 3) + + to_order = ('gain', 'fast_scan', 'slow_scan', 'cells') + assert reorder_axes(a, from_order, to_order).shape == (3, 256, 32, 10) -- GitLab