From 5040e702f66c8fca9c7b43fcc36624e2478a3d29 Mon Sep 17 00:00:00 2001
From: Thomas Kluyver <thomas@kluyver.me.uk>
Date: Tue, 16 May 2023 13:02:08 +0100
Subject: [PATCH] Add reorder_axes function

---
 src/cal_tools/tools.py  |  8 ++++++++
 tests/test_cal_tools.py | 11 +++++++++++
 2 files changed, 19 insertions(+)

diff --git a/src/cal_tools/tools.py b/src/cal_tools/tools.py
index de53ba77e..c71497d89 100644
--- a/src/cal_tools/tools.py
+++ b/src/cal_tools/tools.py
@@ -1030,3 +1030,11 @@ def write_compressed_frames(
             dataset.id.write_direct_chunk(chunk_start, compressed)
 
     return dataset
+
+
+def reorder_axes(a, from_order, to_order):
+    assert len(from_order) == a.ndim
+    assert sorted(from_order) == sorted(to_order)
+    from_order = list(from_order)
+    order = tuple([from_order.index(lbl) for lbl in to_order])
+    return a.transpose(order)
diff --git a/tests/test_cal_tools.py b/tests/test_cal_tools.py
index 36d9d3240..5929944c5 100644
--- a/tests/test_cal_tools.py
+++ b/tests/test_cal_tools.py
@@ -22,6 +22,7 @@ from cal_tools.tools import (
     recursive_update,
     send_to_db,
     write_constants_fragment,
+    reorder_axes,
 )
 
 # AGIPD operating conditions.
@@ -614,3 +615,13 @@ def test_write_constants_fragment(tmp_path: Path):
                 },
             }
         }
+
+
+def test_reorder_axes():
+    a = np.zeros((10, 32, 256, 3))
+    from_order = ('cells', 'slow_scan', 'fast_scan', 'gain')
+    to_order = ('slow_scan', 'fast_scan', 'cells', 'gain')
+    assert reorder_axes(a, from_order, to_order).shape == (32, 256, 10, 3)
+
+    to_order = ('gain', 'fast_scan', 'slow_scan', 'cells')
+    assert reorder_axes(a, from_order, to_order).shape == (3, 256, 32, 10)
-- 
GitLab