{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0393ed18-b4e5-499b-bdd3-c9e5f24b9627",
   "metadata": {},
   "source": [
    "# Timepix3\n",
    "\n",
    "Author: Björn Senfftleben / Philipp Schmidt, Version: 1.0\n",
    "\n",
    "The following notebook provides centroiding for data acquired with the Timepix3 camera detector (ASI TPX3CAM)."
   ]
  },
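  {
   "cell_type": "markdown",
   "id": "9e2d7c54-1a3f-4b8e-b6d2-5c0a9f1e3b47",
   "metadata": {},
   "source": [
    "The corrected sequence files written by this notebook expose a new instrument source (by default `SQS_EXP_TIMEPIX/CAL/TIMEPIX3:daqOutput.chip0`) with the keys `data.centroids`, `data.labels` and `data.stats`. As a rough sketch, the results can later be inspected with EXtra-data; the run path below merely repeats the default `out_folder` parameter and serves only as an illustration:\n",
    "\n",
    "```python\n",
    "# Illustrative sketch only - output path and source name follow the default parameters below.\n",
    "from extra_data import RunDirectory\n",
    "\n",
    "out_run = RunDirectory('/gpfs/exfel/exp/SQS/202430/p900421/scratch/cal_test')\n",
    "centroids = out_run['SQS_EXP_TIMEPIX/CAL/TIMEPIX3:daqOutput.chip0', 'data.centroids'].ndarray()\n",
    "\n",
    "print(centroids.dtype.names)  # ('x', 'y', 'toa', 'tot', 'tot_avg', 'tot_max', 'size')\n",
    "print(centroids.shape)        # (number of trains, max_num_centroids)\n",
    "```"
   ]
  },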
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9484ee10",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Data selection parameters.\n",
    "run = 307  # required\n",
    "in_folder = '/gpfs/exfel/exp/SQS/202430/p900421/raw'  # required\n",
    "out_folder = '/gpfs/exfel/exp/SQS/202430/p900421/scratch/cal_test'  # required\n",
    "proposal = ''  # Proposal, leave empty for auto detection based on in_folder\n",
    "\n",
    "# These parameters are required by xfel-calibrate but ignored in this notebook.\n",
    "cal_db_timeout = 0  # Calibration DB timeout, currently not used.\n",
    "cal_db_interface = 'foo'  # Calibration DB interface, currently not used.\n",
    "karabo_da = 'bar'  # Karabo data aggregator name, currently not used\n",
    "\n",
    "karabo_id = 'SQS_EXP_TIMEPIX'\n",
    "in_fast_data = '{karabo_id}/DET/TIMEPIX3:daqOutput.chip0'\n",
    "out_device_id = '{karabo_id}/CAL/TIMEPIX3'\n",
    "out_fast_data = '{karabo_id}/CAL/TIMEPIX3:daqOutput.chip0'\n",
    "out_aggregator = 'TIMEPIX01'\n",
    "out_seq_len = 2000\n",
    "\n",
    "max_num_centroids = 10000  # Maximum number of centroids per train\n",
    "chunks_centroids = [1, 5000]  # Chunking of centroid data\n",
    "dataset_compression = 'gzip'  # HDF compression method.\n",
    "dataset_compression_opts = 3  # HDF GZIP compression level.\n",
    "\n",
    "clustering_epsilon = 2.0  # centroiding: The maximum distance between two samples for one to be considered as in the neighborhood of the other\n",
    "clustering_tof_scale = 1e7  # centroiding: Scaling factor for the ToA axis so that the epsilon parameter in DB scan works in all 3 dimensions\n",
    "clustering_min_samples = 2  # centroiding: minimum number of samples necessary for a cluster\n",
    "threshold_tot = 0 # raw data: minimum ToT necessary for a pixel to contain valid data\n",
    "\n",
    "raw_timewalk_lut_filepath = ''  # fpath to look up table for timewalk correction relative to proposal path or empty string,\n",
    "centroiding_timewalk_lut_filepath = ''  # fpath to look up table for timewalk correction relative to proposal path or empty string."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "524fe654-e112-4abe-813c-a0be9b3a3034",
   "metadata": {},
   "outputs": [],
   "source": [
    "from datetime import datetime\n",
    "from pathlib import Path\n",
    "from time import monotonic\n",
    "from os import cpu_count\n",
    "from warnings import warn\n",
    "\n",
    "import numpy as np\n",
    "import scipy.ndimage as nd\n",
    "import h5py\n",
    "import pasha as psh\n",
    "\n",
    "from sklearn.cluster import DBSCAN\n",
    "from extra_data import RunDirectory\n",
    "from extra_data.read_machinery import find_proposal\n",
    "\n",
    "from cal_tools.files import DataFile, sequence_pulses\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e36f997c-4b66-4b11-99a8-5887e3572f56",
   "metadata": {},
   "outputs": [],
   "source": [
    "# centroiding\n",
    "error_msgs = {\n",
    "    -1: \"tpx_data has an invalid structure - ignore provided data\",\n",
    "    -2: \"tpx_data arrays are of invalid lengths - ignore provided data\",\n",
    "    -3: \"tpx_data arrays are empty\"\n",
    "}\n",
    "\n",
    "\n",
    "def check_data(tpx_data):\n",
    "    required_keys = [\"x\", \"y\", \"toa\", \"tot\"]\n",
    "    for key in required_keys:\n",
    "        if key not in tpx_data.keys():\n",
    "            warn(\"tpx data must contain the keys %s, but key %s not in tpx data keys (%s)\" % (required_keys, key, list(tpx_data.keys())),\n",
    "                 category=UserWarning)\n",
    "            return -1\n",
    "\n",
    "    reference_n_samples_key = \"x\"\n",
    "    n_samples = len(tpx_data[reference_n_samples_key])\n",
    "    for key in tpx_data.keys():\n",
    "        if n_samples != len(tpx_data[key]):\n",
    "            warn(\"arrays in tpx data must be of same length ( len(tpx_data[%s])=%i!=%i=(len(tpx_data[%s]) )\" % (reference_n_samples_key, n_samples, len(tpx_data[key]), key),\n",
    "                 category=UserWarning)\n",
    "            return -2\n",
    "    if n_samples == 0:\n",
    "        warn(\"no samples were provides with tpx data\", category=UserWarning)\n",
    "        return -3\n",
    "    return 0\n",
    "\n",
    "\n",
    "def apply_single_filter(tpx_data, _filter):\n",
    "    \"\"\"\n",
    "    Simple function to apply a selecting or sorting filter to a dictionary of equally sized arrays\n",
    "    Note: at no point a copy of the dictionary is made, as they are mutable, the input array is changed in memory!\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    tpx_data: dictionary with timepix data, all arrays behind each key must be of same length\n",
    "    _filter:  1d array or list of integers or booleans or np.s_ to select or sort data like a = a[_filter]\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    tpx_data: like input tpx_data but with applied filter\n",
    "\n",
    "    \"\"\"\n",
    "    try:\n",
    "        for key in tpx_data.keys():\n",
    "            tpx_data[key] = np.array(tpx_data[key])[_filter]\n",
    "    except Exception as e:\n",
    "        print(_filter)\n",
    "        print(_filter.dtype)\n",
    "        print(_filter.shape)\n",
    "        print(tpx_data[key].shape)\n",
    "        raise e\n",
    "    return tpx_data\n",
    "\n",
    "\n",
    "\n",
    "def clustering(tpx_data, epsilon=2, tof_scale=1e7, min_samples=3, n_jobs=1):\n",
    "    \"\"\"\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    tpx_data       Dictionary with timepix data, all arrays behind each key must be of same length, now with key labels\n",
    "    epsilon        The maximum distance between two samples for one to be considered as in the neighborhood of the other. \n",
    "                   This is not a maximum bound on the distances of points within a cluster. This is the most important \n",
    "                   DBSCAN parameter to choose appropriately for your data set and distance function.\n",
    "    tof_scale      Scaling factor for the ToA data so that the epsilon parameter in DB scan works not only in the x/y \n",
    "                   axes, but also in the ToA axis. So it converts ToA in s into \"ToA pixels\" -> e.g. tof_scale=1e7 means,\n",
    "                   that 100 ns is considered comparable to 1 spatial pixel. \n",
    "    min_samples    The number of samples (or total weight) in a neighborhood for a point to be considered as a core point. \n",
    "                   This includes the point itself.\n",
    "    n_jobs         The number of parallel jobs to run. None means 1 unless in a joblib.parallel_backend context. \n",
    "                   -1 means using all processors. See Glossary for more details.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "\n",
    "    \"\"\"\n",
    "    coords = np.column_stack((tpx_data[\"x\"], tpx_data[\"y\"], tpx_data[\"toa\"]*tof_scale))\n",
    "    dist = DBSCAN(eps=epsilon, min_samples=min_samples, metric=\"euclidean\", n_jobs=n_jobs).fit(coords)\n",
    "    return dist.labels_\n",
    "\n",
    "def empty_centroid_data():\n",
    "    return {\n",
    "        \"x\": np.array([]),\n",
    "        \"y\": np.array([]),\n",
    "        \"toa\": np.array([]),\n",
    "        \"tot\": np.array([]),\n",
    "        \"tot_avg\": np.array([]),\n",
    "        \"tot_max\": np.array([]),\n",
    "        \"size\": np.array([]),\n",
    "    }\n",
    "\n",
    "def get_centroids(tpx_data, timewalk_lut=None):\n",
    "    centroid_data = empty_centroid_data()\n",
    "    cluster_labels, cluster_size = np.unique(tpx_data[\"labels\"], return_counts=True)\n",
    "\n",
    "    cluster_tot_peaks = np.array(nd.maximum_position(tpx_data[\"tot\"], labels=tpx_data[\"labels\"], index=cluster_labels)).ravel()\n",
    "    cluster_tot_integrals = nd.sum(tpx_data[\"tot\"], labels=tpx_data[\"labels\"], index=cluster_labels)\n",
    "\n",
    "    # compute centroid center through weighted average\n",
    "    centroid_data[\"x\"] = np.array(nd.sum(tpx_data[\"x\"] * tpx_data[\"tot\"], labels=tpx_data[\"labels\"], index=cluster_labels) / cluster_tot_integrals).ravel()\n",
    "    centroid_data[\"y\"] = np.array(nd.sum(tpx_data[\"y\"] * tpx_data[\"tot\"], labels=tpx_data[\"labels\"], index=cluster_labels) / cluster_tot_integrals).ravel()\n",
    "    centroid_data[\"toa\"] = np.array(nd.sum(tpx_data[\"toa\"] * tpx_data[\"tot\"], labels=tpx_data[\"labels\"], index=cluster_labels) / cluster_tot_integrals).ravel()\n",
    "\n",
    "    # intensity & size information\n",
    "    centroid_data[\"tot_avg\"] = np.array(nd.mean(tpx_data[\"tot\"], labels=tpx_data[\"labels\"], index=cluster_labels))\n",
    "    centroid_data[\"tot_max\"] = tpx_data[\"tot\"][cluster_tot_peaks]\n",
    "    centroid_data[\"tot\"] = np.array(cluster_tot_integrals)\n",
    "    centroid_data[\"size\"] = cluster_size\n",
    "\n",
    "    # train ID information\n",
    "    # ~ centroid_data[\"tid\"] = tpx_data[\"tid\"][cluster_tot_peaks]\n",
    "\n",
    "    # correct for timewalk if provided\n",
    "    if timewalk_lut is not None:\n",
    "        centroid_data[\"toa\"] -= timewalk_lut[np.int_(centroid_data[\"tot_max\"] // 25) - 1] * 1e3\n",
    "    return centroid_data\n",
    "\n",
    "\n",
    "def compute_centroids(x, y, tof, tot,\n",
    "                      threshold_tot=0,\n",
    "                      clustering_epsilon=2,\n",
    "                      clustering_tof_scale=1e7,\n",
    "                      clustering_min_samples=3,\n",
    "                      centroiding_timewalk_lut=None):\n",
    "    # format input data\n",
    "    _tpx_data = {\n",
    "        \"x\": x.astype(float),\n",
    "        \"y\": y.astype(float),\n",
    "        \"toa\": tof.astype(float),\n",
    "        \"tot\": tot.astype(float)\n",
    "    }\n",
    "\n",
    "    # ensure that valid data is available\n",
    "    data_validation = check_data(_tpx_data)\n",
    "    if data_validation < 0:\n",
    "        if data_validation in error_msgs.keys():\n",
    "            print(\"Data validation failed with message: %s\" % error_msgs[data_validation])\n",
    "        else:\n",
    "            print(\"Data validation failed: unknown reason\")\n",
    "        return np.array([]), empty_centroid_data()\n",
    "\n",
    "    # clustering (identify clusters in 2d data (x,y,tof) that belong to a single hit,\n",
    "    # each sample belonging to a cluster is labeled with an integer cluster id no)\n",
    "    if threshold_tot > 0:\n",
    "        _tpx_data = apply_single_filter(_tpx_data, _tpx_data[\"tot\"] >= threshold_tot)    \n",
    "\n",
    "    labels = clustering(_tpx_data, epsilon=clustering_epsilon, tof_scale=clustering_tof_scale, min_samples=clustering_min_samples)\n",
    "    _tpx_data[\"labels\"] = labels\n",
    "    \n",
    "    if labels is not None:\n",
    "        _tpx_data = apply_single_filter(_tpx_data, labels >= 0)\n",
    "    \n",
    "    # compute centroid data (reduce cluster of samples to a single point with properties)\n",
    "    if labels is None or len(_tpx_data['x']) == 0:\n",
    "        # handle case of no identified clusters, return empty dictionary with expected keys\n",
    "        return np.array([]), empty_centroid_data()\n",
    "    return labels, get_centroids(_tpx_data, timewalk_lut=centroiding_timewalk_lut)\n",
    "\n",
    "\n",
    "def process_train(worker_id, index, train_id, data):\n",
    "    events = data[in_fast_data]\n",
    "\n",
    "    sel = np.s_[:events['data.size']]\n",
    "\n",
    "    x = events['data.x'][sel]\n",
    "    y = events['data.y'][sel]\n",
    "    tot = events['data.tot'][sel]\n",
    "    toa = events['data.toa'][sel]\n",
    "\n",
    "    if raw_timewalk_lut is not None:\n",
    "        toa -= raw_timewalk_lut[np.int_(tot // 25) - 1] * 1e3\n",
    "\n",
    "    labels, centroids = compute_centroids(x, y, toa, tot, **centroiding_kwargs)\n",
    "\n",
    "    num_centroids = len(centroids['x'])\n",
    "    fraction_centroids = np.sum(centroids[\"size\"])/events['data.size'] if events['data.size']>0 else np.nan\n",
    "    missing_centroids = num_centroids > max_num_centroids\n",
    "\n",
    "    if num_centroids > max_num_centroids:\n",
    "        warn('Number of centroids is larger than the defined maximum, some data cannot be written to disk')\n",
    "\n",
    "    for key in centroid_dt.names:\n",
    "        out_data[index, :num_centroids][key] = centroids[key]\n",
    "    out_labels[index, :len(labels)] = labels\n",
    "    out_stats[index][\"fraction_px_in_centroids\"] = fraction_centroids\n",
    "    out_stats[index][\"N_centroids\"] = num_centroids\n",
    "    out_stats[index][\"missing_centroids\"] = missing_centroids"
   ]
  },
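  {
   "cell_type": "markdown",
   "id": "b7c4e1a9-6f2d-4a3b-8e5c-1d9f0a2b4c68",
   "metadata": {},
   "source": [
    "As a quick illustration of how the functions above interact, consider a toy sketch with made-up hit coordinates (not executed as part of the correction): three neighbouring pixels firing within tens of nanoseconds are grouped into a single cluster by DBSCAN and reduced to one ToT-weighted centroid.\n",
    "\n",
    "```python\n",
    "# Toy example with assumed values for the functions defined above.\n",
    "import numpy as np\n",
    "\n",
    "x = np.array([10, 11, 10], dtype=np.uint16)\n",
    "y = np.array([10, 10, 11], dtype=np.uint16)\n",
    "toa = np.array([1.00e-6, 1.01e-6, 1.02e-6])  # seconds, i.e. a spread of 20 ns\n",
    "tot = np.array([100, 50, 25], dtype=np.uint16)\n",
    "\n",
    "labels, centroids = compute_centroids(\n",
    "    x, y, toa, tot,\n",
    "    clustering_epsilon=2.0,    # maximum distance between neighbouring samples\n",
    "    clustering_tof_scale=1e7,  # 100 ns in ToA corresponds to 1 pixel\n",
    "    clustering_min_samples=2)\n",
    "\n",
    "print(labels)             # [0 0 0] - all three hits belong to one cluster\n",
    "print(centroids['x'])     # ToT-weighted mean x position, ~10.29\n",
    "print(centroids['size'])  # [3]\n",
    "```"
   ]
  },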
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "56306c15-513e-4e7f-9c47-c52ca61b27a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "dc = RunDirectory(Path(in_folder) / f'r{run:04d}', inc_suspect_trains=True)\n",
    "proposal = list(filter(None, in_folder.strip('/').split('/')))[-2]\n",
    "base_path=find_proposal(proposal)\n",
    "\n",
    "if raw_timewalk_lut_filepath:\n",
    "    raw_timewalk_lut_filepath_full = (Path(base_path) / Path(raw_timewalk_lut_filepath)).resolve()\n",
    "    raw_timewalk_lut = np.load(raw_timewalk_lut_filepath_full)\n",
    "else:\n",
    "    raw_timewalk_lut = None\n",
    "\n",
    "if centroiding_timewalk_lut_filepath:\n",
    "    centroiding_timewalk_lut_filepath_full = (Path(base_path) / Path(centroiding_timewalk_lut_filepath)).resolve()\n",
    "    centroiding_timewalk_lut = np.load(centroiding_timewalk_lut_filepath_full)\n",
    "else:\n",
    "    centroiding_timewalk_lut = None"
   ]
  },
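  {
   "cell_type": "markdown",
   "id": "c3a8f2d1-4e6b-49a7-b1c5-7d0e9f3a5b82",
   "metadata": {},
   "source": [
    "The timewalk look-up tables are plain NumPy arrays loaded with `np.load`. Since the correction code indexes them as `lut[tot // 25 - 1]` and subtracts the (scaled) result from the ToA values, a one-dimensional array is expected. A quick sanity check could look like the sketch below (the file name is hypothetical):\n",
    "\n",
    "```python\n",
    "# Sketch only - 'timewalk_lut.npy' is a hypothetical file name.\n",
    "import numpy as np\n",
    "\n",
    "lut = np.load('timewalk_lut.npy')\n",
    "assert lut.ndim == 1, 'timewalk LUT is expected to be a 1d array indexed by ToT // 25 - 1'\n",
    "print(lut.shape, lut.dtype)\n",
    "```"
   ]
  },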
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8c28a4c8-961c-496b-80da-7fd867e5b0d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "in_fast_data = in_fast_data.format(karabo_id=karabo_id)\n",
    "out_device_id = out_device_id.format(karabo_id=karabo_id)\n",
    "out_fast_data = out_fast_data.format(karabo_id=karabo_id)\n",
    "\n",
    "Path(out_folder).mkdir(exist_ok=True, parents=True)\n",
    "\n",
    "in_dc = dc.select(in_fast_data, require_all=True)\n",
    "\n",
    "dataset_kwargs = {k[8:]: v for k, v in locals().items() if k.startswith('dataset_compression')}\n",
    "\n",
    "centroid_dt = np.dtype([('x', np.float64),\n",
    "                        ('y', np.float64),\n",
    "                        ('toa', np.float64),\n",
    "                        ('tot', np.float64),\n",
    "                        ('tot_avg', np.float64),\n",
    "                        ('tot_max', np.uint16),\n",
    "                        ('size', np.int16)])\n",
    "\n",
    "pixel_shape = in_dc[in_fast_data]['data.x'].entry_shape\n",
    "\n",
    "centroid_settings_template = {\n",
    "    'timewalk_correction.raw_applied': (np.bool, bool(raw_timewalk_lut_filepath)),\n",
    "    'timewalk_correction.raw_file': (\"S100\", str(raw_timewalk_lut_filepath)[-100:]),\n",
    "    'timewalk_correction.centroiding_applied': (np.bool, bool(centroiding_timewalk_lut_filepath)),\n",
    "    'timewalk_correction.centroiding_file': (\"S100\", str(centroiding_timewalk_lut_filepath)[-100:]),\n",
    "    'clustering.epsilon': (np.float64, float(clustering_epsilon)),\n",
    "    'clustering.tof_scale': (np.float64, float(clustering_tof_scale)),\n",
    "    'clustering.min_samples': (np.int16, int(clustering_min_samples)),\n",
    "    'threshold_tot': (np.int16, int(threshold_tot)),\n",
    "}\n",
    "\n",
    "centroid_stats_template = {\n",
    "    'N_centroids': (np.int, -1),\n",
    "    'missing_centroids': (np.bool, False),\n",
    "    'fraction_px_in_centroids': (np.float64, np.nan),\n",
    "}\n",
    "\n",
    "centroid_settings_dt = np.dtype([(key, centroid_settings_template[key][0]) for key in centroid_settings_template])\n",
    "centroid_stats_dt = np.dtype([(key, centroid_stats_template[key][0]) for key in centroid_stats_template])\n",
    "\n",
    "centroiding_kwargs = dict(\n",
    "    threshold_tot=threshold_tot,\n",
    "    clustering_epsilon=clustering_epsilon,\n",
    "    clustering_tof_scale=clustering_tof_scale,\n",
    "    clustering_min_samples=clustering_min_samples,\n",
    "    centroiding_timewalk_lut=centroiding_timewalk_lut)\n",
    "\n",
    "\n",
    "psh.set_default_context('processes', num_workers=(num_workers := cpu_count() // 4))\n",
    "    \n",
    "print(f'Computing centroids with {num_workers} workers and writing to file', flush=True, end='')\n",
    "start = monotonic()\n",
    "\n",
    "for seq_id, seq_dc in enumerate(in_dc.split_trains(trains_per_part=out_seq_len)):\n",
    "    train_ids = seq_dc.train_ids\n",
    "    m_data_sources = []\n",
    "    \n",
    "    with DataFile.from_details(out_folder, out_aggregator, run, seq_id) as seq_file:                                                                                                    \n",
    "        # No support needed for old EXDF files.\n",
    "        seq_file.create_metadata(like=in_dc, sequence=seq_id,\n",
    "                                 control_sources=[out_device_id],\n",
    "                                 instrument_channels=[f'{out_fast_data}/data'])\n",
    "        seq_file.create_index(train_ids)\n",
    "        \n",
    "        out_labels = psh.alloc(shape=(len(train_ids),) + pixel_shape, dtype=np.int32)\n",
    "        out_data = psh.alloc(shape=(len(train_ids), max_num_centroids), dtype=centroid_dt)\n",
    "        out_stats = psh.alloc(shape=(len(train_ids),), dtype=centroid_stats_dt)\n",
    "        \n",
    "        out_labels[:] = -1\n",
    "        out_data[:] = (np.nan, np.nan, np.nan, np.nan, np.nan, 0, -1)\n",
    "        out_stats[:] = tuple([centroid_stats_template[key][1] for key in centroid_stats_template])\n",
    "        \n",
    "        psh.map(process_train, seq_dc)\n",
    "        \n",
    "        # Create sources.\n",
    "        cur_slow_data = seq_file.create_control_source(out_device_id)\n",
    "        cur_fast_data = seq_file.create_instrument_source(out_fast_data)\n",
    "\n",
    "        # Add source indices.\n",
    "        cur_slow_data.create_index(len(train_ids))\n",
    "        cur_fast_data.create_index(data=np.ones_like(train_ids))\n",
    "        \n",
    "        for key, (type_, data) in centroid_settings_template.items():\n",
    "            cur_slow_data.create_run_key(f'settings.{key}', data)\n",
    "        \n",
    "        cur_fast_data.create_key('data.labels', data=out_labels,\n",
    "                                 chunks=(1,) + pixel_shape, **dataset_kwargs)\n",
    "        cur_fast_data.create_key('data.centroids', out_data,\n",
    "                                 chunks=tuple(chunks_centroids),\n",
    "                                 **dataset_kwargs)\n",
    "        cur_fast_data.create_key('data.stats', out_stats)\n",
    "        \n",
    "    print('.', flush=True, end='')\n",
    "    \n",
    "end = monotonic()\n",
    "print('')\n",
    "\n",
    "print(f'{end-start:.01f}s')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pycal",
   "language": "python",
   "name": "pycal"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}