From 10f65c4973c900ee6dea6583dedd5c08d2f25e87 Mon Sep 17 00:00:00 2001
From: Philipp Schmidt <philipp.schmidt@xfel.eu>
Date: Wed, 21 Feb 2024 16:45:23 +0100
Subject: [PATCH] Parallelise Timepix event centroiding with pasha

---
 .../Compute_Timepix_Event_Centroids.ipynb     | 71 +++++++++++--------
 1 file changed, 40 insertions(+), 31 deletions(-)

diff --git a/notebooks/Timepix/Compute_Timepix_Event_Centroids.ipynb b/notebooks/Timepix/Compute_Timepix_Event_Centroids.ipynb
index 8a713ab8e..50a985dfd 100755
--- a/notebooks/Timepix/Compute_Timepix_Event_Centroids.ipynb
+++ b/notebooks/Timepix/Compute_Timepix_Event_Centroids.ipynb
@@ -62,11 +62,13 @@
     "from datetime import datetime\n",
     "from pathlib import Path\n",
     "from time import monotonic\n",
+    "from os import cpu_count\n",
     "from warnings import warn\n",
     "\n",
     "import numpy as np\n",
     "import scipy.ndimage as nd\n",
     "import h5py\n",
+    "import pasha as psh\n",
     "\n",
     "from sklearn.cluster import DBSCAN\n",
     "from extra_data import RunDirectory\n",
@@ -276,7 +278,36 @@
     "        # handle case of no identified clusters, return empty dictionary with expected keys\n",
     "        return empty_centroid_data()\n",
     "    _centroids = get_centroids(_tpx_data, timewalk_lut=centroiding_timewalk_lut)\n",
-    "    return _centroids"
+    "    return _centroids\n",
+    "\n",
+    "\n",
+    "def process_train(worker_id, index, train_id, data):\n",
+    "    events = data[in_fast_data]\n",
+    "\n",
+    "    sel = np.s_[:events['data.size']]\n",
+    "\n",
+    "    x = events['data.x'][sel]\n",
+    "    y = events['data.y'][sel]\n",
+    "    tot = events['data.tot'][sel]\n",
+    "    toa = events['data.toa'][sel]\n",
+    "\n",
+    "    if raw_timewalk_lut is not None:\n",
+    "        toa -= raw_timewalk_lut[np.int_(tot // 25) - 1] * 1e3\n",
+    "\n",
+    "    centroids = compute_centroids(x, y, toa, tot, **centroiding_kwargs)\n",
+    "\n",
+    "    num_centroids = len(centroids['x'])\n",
+    "    fraction_centroids = np.sum(centroids[\"size\"])/events['data.size'] if events['data.size']>0 else np.nan\n",
+    "    missing_centroids = num_centroids > max_num_centroids\n",
+    "\n",
+    "    if num_centroids > max_num_centroids:\n",
+    "        warn('number of centroids larger than definde maximum, some data cannot be written to disk')\n",
+    "\n",
+    "    for key in centroid_dt.names:\n",
+    "        out_data[index, :num_centroids][key] = centroids[key]\n",
+    "    out_stats[index][\"fraction_px_in_centroids\"] = fraction_centroids\n",
+    "    out_stats[index][\"N_centroids\"] = num_centroids\n",
+    "    out_stats[index][\"missing_centroids\"] = missing_centroids"
    ]
   },
   {
@@ -357,7 +388,10 @@
     "    clustering_n_jobs=clustering_n_jobs,\n",
     "    centroiding_timewalk_lut=centroiding_timewalk_lut)\n",
     "\n",
-    "print('Computing centroids and writing to file', flush=True, end='')\n",
+    "\n",
+    "psh.set_default_context('processes', num_workers=(num_workers := cpu_count() // 4))\n",
+    "    \n",
+    "print(f'Computing centroids with {num_workers} workers and writing to file', flush=True, end='')\n",
     "start = monotonic()\n",
     "\n",
     "for seq_id, seq_dc in enumerate(in_dc.split_trains(trains_per_part=out_seq_len)):\n",
@@ -371,38 +405,13 @@
     "                                 instrument_channels=[f'{out_fast_data}/data'])\n",
     "        seq_file.create_index(train_ids)\n",
     "            \n",
-    "        out_data = np.empty((len(train_ids), max_num_centroids), dtype=centroid_dt)\n",
+    "        out_data = psh.alloc(shape=(len(train_ids), max_num_centroids), dtype=centroid_dt)\n",
+    "        out_stats = psh.alloc(shape=(len(train_ids),), dtype=centroid_stats_dt)\n",
+    "        \n",
     "        out_data[:] = (np.nan, np.nan, np.nan, np.nan, np.nan, 0, -1)\n",
-    "        out_stats = np.empty((len(train_ids),), dtype=centroid_stats_dt)\n",
     "        out_stats[:] = tuple([centroid_stats_template[key][1] for key in centroid_stats_template])\n",
     "        \n",
-    "        for index, (train_id, data) in enumerate(seq_dc.trains()):\n",
-    "            events = data[in_fast_data]\n",
-    "\n",
-    "            sel = np.s_[:events['data.size']]\n",
-    "\n",
-    "            x = events['data.x'][sel]\n",
-    "            y = events['data.y'][sel]\n",
-    "            tot = events['data.tot'][sel]\n",
-    "            toa = events['data.toa'][sel]\n",
-    "\n",
-    "            if raw_timewalk_lut is not None:\n",
-    "                toa -= raw_timewalk_lut[np.int_(tot // 25) - 1] * 1e3\n",
-    "\n",
-    "            centroids = compute_centroids(x, y, toa, tot, **centroiding_kwargs)\n",
-    "\n",
-    "            num_centroids = len(centroids['x'])\n",
-    "            fraction_centroids = np.sum(centroids[\"size\"])/events['data.size'] if events['data.size']>0 else np.nan\n",
-    "            missing_centroids = num_centroids > max_num_centroids\n",
-    "    \n",
-    "            if num_centroids > max_num_centroids:\n",
-    "                warn('number of centroids larger than definde maximum, some data cannot be written to disk')\n",
-    "            \n",
-    "            for key in centroid_dt.names:\n",
-    "                out_data[key][index, :num_centroids] = centroids[key]\n",
-    "            out_stats[\"fraction_px_in_centroids\"][index] = fraction_centroids\n",
-    "            out_stats[\"N_centroids\"][index] = num_centroids\n",
-    "            out_stats[\"missing_centroids\"][index] = missing_centroids\n",
+    "        psh.map(process_train, seq_dc)\n",
     "        \n",
     "        # Create sources.\n",
     "        cur_slow_data = seq_file.create_control_source(out_device_id)\n",
-- 
GitLab