From 10f65c4973c900ee6dea6583dedd5c08d2f25e87 Mon Sep 17 00:00:00 2001
From: Philipp Schmidt <philipp.schmidt@xfel.eu>
Date: Wed, 21 Feb 2024 16:45:23 +0100
Subject: [PATCH] Throw pasha at timepix centroiding

---
 .../Compute_Timepix_Event_Centroids.ipynb     | 71 +++++++++++--------
 1 file changed, 40 insertions(+), 31 deletions(-)

diff --git a/notebooks/Timepix/Compute_Timepix_Event_Centroids.ipynb b/notebooks/Timepix/Compute_Timepix_Event_Centroids.ipynb
index 8a713ab8e..50a985dfd 100755
--- a/notebooks/Timepix/Compute_Timepix_Event_Centroids.ipynb
+++ b/notebooks/Timepix/Compute_Timepix_Event_Centroids.ipynb
@@ -62,11 +62,13 @@
     "from datetime import datetime\n",
     "from pathlib import Path\n",
     "from time import monotonic\n",
+    "from os import cpu_count\n",
     "from warnings import warn\n",
     "\n",
     "import numpy as np\n",
     "import scipy.ndimage as nd\n",
     "import h5py\n",
+    "import pasha as psh\n",
     "\n",
     "from sklearn.cluster import DBSCAN\n",
     "from extra_data import RunDirectory\n",
@@ -276,7 +278,36 @@
     "        # handle case of no identified clusters, return empty dictionary with expected keys\n",
     "        return empty_centroid_data()\n",
     "    _centroids = get_centroids(_tpx_data, timewalk_lut=centroiding_timewalk_lut)\n",
-    "    return _centroids"
+    "    return _centroids\n",
+    "\n",
+    "\n",
+    "def process_train(worker_id, index, train_id, data):\n",
+    "    events = data[in_fast_data]\n",
+    "\n",
+    "    sel = np.s_[:events['data.size']]\n",
+    "\n",
+    "    x = events['data.x'][sel]\n",
+    "    y = events['data.y'][sel]\n",
+    "    tot = events['data.tot'][sel]\n",
+    "    toa = events['data.toa'][sel]\n",
+    "\n",
+    "    if raw_timewalk_lut is not None:\n",
+    "        toa -= raw_timewalk_lut[np.int_(tot // 25) - 1] * 1e3\n",
+    "\n",
+    "    centroids = compute_centroids(x, y, toa, tot, **centroiding_kwargs)\n",
+    "\n",
+    "    num_centroids = len(centroids['x'])\n",
+    "    fraction_centroids = np.sum(centroids[\"size\"])/events['data.size'] if events['data.size']>0 else np.nan\n",
+    "    missing_centroids = num_centroids > max_num_centroids\n",
+    "\n",
+    "    if num_centroids > max_num_centroids:\n",
+    "        warn('number of centroids larger than defined maximum, some data cannot be written to disk')\n",
+    "\n",
+    "    for key in centroid_dt.names:\n",
+    "        out_data[index, :num_centroids][key] = centroids[key]\n",
+    "    out_stats[index][\"fraction_px_in_centroids\"] = fraction_centroids\n",
+    "    out_stats[index][\"N_centroids\"] = num_centroids\n",
+    "    out_stats[index][\"missing_centroids\"] = missing_centroids"
    ]
   },
   {
@@ -357,7 +388,10 @@
     "    clustering_n_jobs=clustering_n_jobs,\n",
     "    centroiding_timewalk_lut=centroiding_timewalk_lut)\n",
     "\n",
-    "print('Computing centroids and writing to file', flush=True, end='')\n",
+    "\n",
+    "psh.set_default_context('processes', num_workers=(num_workers := cpu_count() // 4))\n",
+    "    \n",
+    "print(f'Computing centroids with {num_workers} workers and writing to file', flush=True, end='')\n",
     "start = monotonic()\n",
     "\n",
     "for seq_id, seq_dc in enumerate(in_dc.split_trains(trains_per_part=out_seq_len)):\n",
@@ -371,38 +405,13 @@
     "        instrument_channels=[f'{out_fast_data}/data'])\n",
     "    seq_file.create_index(train_ids)\n",
     "    \n",
-    "    out_data = np.empty((len(train_ids), max_num_centroids), dtype=centroid_dt)\n",
+    "    out_data = psh.alloc(shape=(len(train_ids), max_num_centroids), dtype=centroid_dt)\n",
+    "    out_stats = psh.alloc(shape=(len(train_ids),), dtype=centroid_stats_dt)\n",
+    "    \n",
     "    out_data[:] = (np.nan, np.nan, np.nan, np.nan, np.nan, 0, -1)\n",
-    "    out_stats = np.empty((len(train_ids),), dtype=centroid_stats_dt)\n",
     "    out_stats[:] = tuple([centroid_stats_template[key][1] for key in centroid_stats_template])\n",
     "    \n",
-    "    for index, (train_id, data) in enumerate(seq_dc.trains()):\n",
-    "        events = data[in_fast_data]\n",
-    "\n",
-    "        sel = np.s_[:events['data.size']]\n",
-    "\n",
-    "        x = events['data.x'][sel]\n",
-    "        y = events['data.y'][sel]\n",
-    "        tot = events['data.tot'][sel]\n",
-    "        toa = events['data.toa'][sel]\n",
-    "\n",
-    "        if raw_timewalk_lut is not None:\n",
-    "            toa -= raw_timewalk_lut[np.int_(tot // 25) - 1] * 1e3\n",
-    "\n",
-    "        centroids = compute_centroids(x, y, toa, tot, **centroiding_kwargs)\n",
-    "\n",
-    "        num_centroids = len(centroids['x'])\n",
-    "        fraction_centroids = np.sum(centroids[\"size\"])/events['data.size'] if events['data.size']>0 else np.nan\n",
-    "        missing_centroids = num_centroids > max_num_centroids\n",
-    "        \n",
-    "        if num_centroids > max_num_centroids:\n",
-    "            warn('number of centroids larger than definde maximum, some data cannot be written to disk')\n",
-    "        \n",
-    "        for key in centroid_dt.names:\n",
-    "            out_data[key][index, :num_centroids] = centroids[key]\n",
-    "        out_stats[\"fraction_px_in_centroids\"][index] = fraction_centroids\n",
-    "        out_stats[\"N_centroids\"][index] = num_centroids\n",
-    "        out_stats[\"missing_centroids\"][index] = missing_centroids\n",
+    "    psh.map(process_train, seq_dc)\n",
     "    \n",
     "    # Create sources.\n",
     "    cur_slow_data = seq_file.create_control_source(out_device_id)\n",
-- 
GitLab
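
Note on the pattern this patch adopts: pasha allocates output buffers in shared
memory (psh.alloc), then fans a per-item kernel out over worker processes
(psh.map), so workers write results in place and no data is pickled back. A
minimal, self-contained sketch of that pattern, assuming only pasha's public
API; the array and kernel names below are illustrative stand-ins, not code
from the notebook:

    import numpy as np
    import pasha as psh

    # Process-based parallelism, as in the patch; 4 workers is an
    # arbitrary example value.
    psh.set_default_context('processes', num_workers=4)

    data = np.arange(16, dtype=np.float64)

    # Shared-memory output: workers write into it, the parent reads it.
    squared = psh.alloc(shape=data.shape, dtype=data.dtype)

    def square(worker_id, index, value):
        # For array inputs pasha calls the kernel as (worker_id, index,
        # value); mapping over an extra_data DataCollection, as the patch
        # does with process_train, yields (worker_id, index, train_id,
        # data) instead.
        squared[index] = value ** 2

    psh.map(square, data)
    assert np.array_equal(squared, data ** 2)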