From 0c2f1d8ae0b2e8166feba0915b6d8c72e81da231 Mon Sep 17 00:00:00 2001
From: Egor Sobolev <egor.sobolev@xfel.eu>
Date: Thu, 28 Mar 2024 20:27:27 +0100
Subject: [PATCH] Add retrieving calibration constants from DB in
 Correct_DynamicFF_NBC.ipynb

---
 .../DynamicFF/Correct_DynamicFF_NBC.ipynb     | 262 +++++++++++-------
 1 file changed, 155 insertions(+), 107 deletions(-)

diff --git a/notebooks/DynamicFF/Correct_DynamicFF_NBC.ipynb b/notebooks/DynamicFF/Correct_DynamicFF_NBC.ipynb
index acbb79ac8..c3a875eca 100644
--- a/notebooks/DynamicFF/Correct_DynamicFF_NBC.ipynb
+++ b/notebooks/DynamicFF/Correct_DynamicFF_NBC.ipynb
@@ -18,35 +18,22 @@
    "outputs": [],
    "source": [
     "in_folder = \"/gpfs/exfel/exp/SPB/202430/p900425/raw\"  # input folder, required\n",
-    "out_folder = '/gpfs/exfel/data/scratch/esobolev/test/shimadzu'  # output folder, required\n",
+    "out_folder = \"/gpfs/exfel/exp/SPB/202430/p900425/scratch/proc/r0003\"  # output folder, required\n",
     "metadata_folder = \"\"  # Directory containing calibration_metadata.yml when run by xfel-calibrate\n",
     "run = 3  # which run to read data from, required\n",
     "\n",
     "# Data files parameters.\n",
-    "karabo_da = ['HPVX01/1', 'HPVX01/2']  # data aggregators\n",
-    "karabo_id = \"SPB_EHD_MIC\"  # karabo prefix of Shimadzu HPV-X2 devices\n",
-    "#receiver_id = \"PNCCD_FMT-0\" # inset for receiver devices\n",
-    "#path_template = 'RAW-R{:04d}-{}-S{{:05d}}.h5'  # the template to use to access data\n",
-    "instrument_source_template = 'SPB_EHD_MIC/CAM/HPVX2_{module}:daqOutput'  # data source path in h5file.\n",
-    "image_key = \"data.image.pixels\"  # image data key in Karabo or exdf notation\n",
+    "karabo_da = ['-1']  # data aggregators\n",
+    "karabo_id = \"SPB_MIC_HPVX2\"  # karabo prefix of Shimadzu HPV-X2 devices\n",
     "\n",
     "# Database access parameters.\n",
-    "use_dir_creation_date = True  # use dir creation date as data production reference date\n",
     "cal_db_interface = \"tcp://max-exfl-cal001:8021\"  # calibration DB interface to use\n",
-    "cal_db_timeout = 300000  # timeout on caldb requests\n",
-    "db_output = False  # if True, the notebook sends dark constants to the calibration database\n",
-    "local_output = True  # if True, the notebook saves dark constants locally\n",
-    "creation_time = \"\"  # To overwrite the measured creation_time. Required Format: YYYY-MM-DD HR:MN:SC.00 e.g. 2019-07-04 11:02:41.00\n",
     "\n",
+    "# Correction parameters\n",
     "n_components = 20  # number of principal components of flat-field to use in correction\n",
     "downsample_factors = [1, 1]  # list of downsample factors for each image dimention (y, x)\n",
     "\n",
-    "constants_folder = \"/gpfs/exfel/data/scratch/esobolev/test/shimadzu\"\n",
-    "db_module_template = \"Shimadzu_HPVX2_{}\"\n",
-    "\n",
-    "num_proc = 32  # number of processes running correction in parallel\n",
-    "\n",
-    "corrected_source_template = 'SPB_EHD_MIC/CORR/HPVX2_{module}:output'  # data source path in h5file."
+    "num_proc = 32  # number of processes running correction in parallel"
    ]
   },
   {
@@ -57,15 +44,26 @@
    "source": [
     "import os\n",
     "import h5py\n",
+    "import warnings\n",
+    "from logging import warning\n",
+    "\n",
+    "warnings.filterwarnings('ignore')\n",
+    "\n",
     "import numpy as np\n",
     "import matplotlib.pyplot as plt\n",
     "from IPython.display import display, Markdown\n",
+    "from datetime import datetime\n",
     "\n",
     "from extra_data import RunDirectory, by_id\n",
     "\n",
     "%matplotlib inline\n",
     "from cal_tools.step_timing import StepTimer\n",
     "from cal_tools.files import sequence_trains, DataFile\n",
+    "from cal_tools.tools import get_dir_creation_date\n",
+    "\n",
+    "from cal_tools.restful_config import calibration_client, restful_config\n",
+    "from cal_tools.calcat_interface2 import CalibrationData, setup_client\n",
+    "from cal_tools.shimadzu import ShimadzuHPVX2\n",
     "\n",
     "from dynflatfield import (\n",
     "    DynamicFlatFieldCorrectionCython as DynamicFlatFieldCorrection,\n",
@@ -80,32 +78,42 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "index_group = image_key.partition('.')[0]\n",
-    "instrument, part, component = karabo_id.split('_')\n",
+    "creation_time = get_dir_creation_date(in_folder, run)\n",
+    "print(f\"Creation time is {creation_time}\")\n",
     "\n",
-    "aggregators = {}\n",
-    "sources = {}\n",
-    "source_to_db = {}\n",
-    "print(\"Sources:\")\n",
-    "for da in karabo_da:\n",
-    "    aggr, _, module = da.partition('/')\n",
-    "    instrument_source_name = instrument_source_template.format(\n",
-    "        instrument=instrument, part=part, component=component,\n",
-    "        module=module\n",
-    "    )\n",
-    "    corrected_source_name = corrected_source_template.format(\n",
-    "        instrument=instrument, part=part, component=component,\n",
-    "        module=module\n",
-    "    )\n",
-    "    aggregators.setdefault(aggr, []).append(\n",
-    "        (instrument_source_name, corrected_source_name))\n",
-    "    sources[instrument_source_name] = aggr\n",
-    "    source_to_db[instrument_source_name] = db_module_template.format(module)\n",
-    "    print('-', instrument_source_name)\n",
-    "print()\n",
+    "cc = calibration_client()\n",
+    "pdus = cc.get_all_phy_det_units_from_detector(\n",
+    "    {\"detector_identifier\": karabo_id})\n",
+    "\n",
+    "if not pdus[\"success\"]:\n",
+    "    raise ValueError(\"Failed to retrieve PDUs\")\n",
+    "\n",
+    "detector_info = pdus['data'][0]['detector']\n",
+    "detector = ShimadzuHPVX2(detector_info[\"source_name_pattern\"])\n",
+    "index_group = detector.image_index_group\n",
+    "image_key = detector.image_key\n",
     "\n",
+    "print(f\"Instrument {detector.instrument}\")\n",
     "print(f\"Detector in use is {karabo_id}\")\n",
-    "print(f\"Instrument {instrument}\")\n",
+    "\n",
+    "modules = {}\n",
+    "for pdu in pdus[\"data\"]:\n",
+    "    db_module = pdu[\"physical_name\"]\n",
+    "    module = pdu[\"module_number\"]\n",
+    "    da = pdu[\"karabo_da\"]\n",
+    "    if karabo_da[0] != \"-1\" and da not in karabo_da:\n",
+    "        continue\n",
+    "\n",
+    "    instrument_source_name = detector.instrument_source(module)\n",
+    "    corrected_source_name = detector.corrected_source(module)\n",
+    "    print('-', da, db_module, module, instrument_source_name)\n",
+    "    \n",
+    "    modules[da] = dict(\n",
+    "        db_module=db_module,\n",
+    "        module=module,\n",
+    "        raw_source_name=instrument_source_name,\n",
+    "        corrected_source_name=corrected_source_name,\n",
+    "    )\n",
     "\n",
     "step_timer = StepTimer()"
    ]
@@ -123,44 +131,53 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "requested_conditions = {\n",
-    "    \"frame_size\": 1.0,\n",
-    "}\n",
-    "\n",
+    "# !!! REMOVE IT for production\n",
+    "# ---------------------------------------------------\n",
+    "from cal_tools.restful_config import restful_config\n",
+    "from cal_tools.calcat_interface2 import setup_client\n",
+    "\n",
+    "calcat_config = restful_config.get('calcat')\n",
+    "setup_client(  # won't be needed in production\n",
+    "    #base_url=calcat_config['base-api-url'].rpartition('/')[0],\n",
+    "    base_url='https://in.xfel.eu/test_calibration',\n",
+    "    client_id=calcat_config['user-id'],\n",
+    "    client_secret=calcat_config['user-secret'],\n",
+    "    user_email=calcat_config['user-email'],\n",
+    ")\n",
+    "caldb_root = \"/gpfs/exfel/d/cal_tst/caldb_store\"\n",
+    "creation_time = datetime.now()\n",
+    "# ===================================================\n",
     "step_timer.start()\n",
     "\n",
+    "dc = RunDirectory(f\"{in_folder}/r{run:04d}\")\n",
+    "conditions = detector.conditions(dc)\n",
+    "\n",
+    "caldata = CalibrationData.from_condition(\n",
+    "    conditions, karabo_id, event_at=creation_time)\n",
+    "\n",
+    "aggregators = {}\n",
     "corrections = {}\n",
-    "constant_types = [\"Offset\", \"DynamicFF\"]\n",
-    "for source, db_module in source_to_db.items():\n",
-    "    constants = {}\n",
-    "    for constant_name in constant_types:\n",
-    "        const_file = f\"{constants_folder}/const_{constant_name}_{db_module}.h5\"\n",
-    "        if not os.path.isfile(const_file):\n",
-    "            raise FileNotFoundError(f\"{constant_name} constants are not found for {karabo_id}.\")\n",
-    "\n",
-    "        with h5py.File(const_file, 'r') as f:\n",
-    "            conditions = dict(\n",
-    "                frame_size=int(f[\"condition/Frame Size/value\"][()])\n",
-    "            )\n",
-    "            data = f[\"data\"][:]\n",
-    "            data_creation_time = f[\"creation_time\"][()].decode()\n",
-    "            \n",
-    "        if not all(conditions[key] == value for key, value in requested_conditions.items()):\n",
-    "            raise ValueError(\"Conditions for {constant_name} are not match\")\n",
-    "\n",
-    "        print(f\"{source} {db_module} {constant_name}: {data_creation_time}\")\n",
-    "        constants[constant_name] = data\n",
-    "\n",
-    "    dark = constants[\"Offset\"]\n",
-    "    flat = constants[\"DynamicFF\"][0]\n",
-    "    components = constants[\"DynamicFF\"][1:][:n_components]\n",
-    "\n",
-    "    dffc = DynamicFlatFieldCorrection.from_constants(\n",
-    "        dark, flat, components, downsample_factors)\n",
-    "\n",
-    "    corrections[source] = dffc\n",
-    "\n",
-    "step_timer.done_step(\"Load calibration constants\")"
+    "for da in modules:\n",
+    "    try:\n",
+    "        # !!! REMOVE caldb_root for production\n",
+    "        dark = caldata[\"Offset\", da].ndarray(caldb_root=caldb_root)\n",
+    "        flat = caldata[\"DynamicFF\", da].ndarray(caldb_root=caldb_root)\n",
+    "        \n",
+    "        components = flat[1:][:n_components]\n",
+    "        flat = flat[0]\n",
+    "\n",
+    "        dffc = DynamicFlatFieldCorrection.from_constants(\n",
+    "            dark, flat, components, downsample_factors)\n",
+    "\n",
+    "        corrections[da] = dffc\n",
+    "        \n",
+    "        file_da, _, _ = da.partition('/')\n",
+    "        aggregators.setdefault(file_da, []).append(da)\n",
+    "    except (KeyError, FileNotFoundError):\n",
+    "        warning(f\"Constants are not found for module {da}. \"\n",
+    "                \"The module will not be calibrated\")\n",
+    "\n",
+    "step_timer.done_step(\"Load calibration constants\")"
    ]
   },
   {
@@ -176,19 +193,29 @@
    "metadata": {},
    "outputs": [],
    "source": [
+    "# Output Folder Creation:\n",
+    "os.makedirs(out_folder, exist_ok=True)\n",
+    "\n",
     "report = []\n",
-    "for aggr, sources in aggregators.items():\n",
-    "    dc = RunDirectory(f\"{in_folder}/r{run:04d}\", f\"RAW-R{run:04d}-{aggr}-S*.h5\")\n",
+    "for file_da, file_modules in aggregators.items():\n",
+    "    dc = RunDirectory(f\"{in_folder}/r{run:04d}\", f\"RAW-R{run:04d}-{file_da}-S*.h5\")\n",
     "\n",
+    "    # build train IDs\n",
     "    train_ids = set()\n",
-    "    keydata_cache = {}\n",
-    "    for instrument_source, corrected_source in sources:\n",
-    "        keydata = dc[instrument_source][image_key].drop_empty_trains()\n",
-    "        train_ids.update(keydata.train_ids)\n",
-    "        keydata_cache[instrument_source] = keydata\n",
+    "    process_modules = []\n",
+    "    for da in file_modules:\n",
+    "        instrument_source = modules[da][\"raw_source_name\"]\n",
+    "        if instrument_source in dc.all_sources:\n",
+    "            keydata = dc[instrument_source][image_key].drop_empty_trains()\n",
+    "            train_ids.update(keydata.train_ids)\n",
+    "            process_modules.append(da)\n",
+    "        else:\n",
+    "            print(f\"Source {instrument_source} for module {da} is missed\")\n",
+    "        \n",
     "    train_ids = np.array(sorted(train_ids))\n",
     "    ts = dc.select_trains(by_id[train_ids]).train_timestamps().astype(np.uint64)\n",
     "\n",
+    "    # correct and write sequence files\n",
     "    for seq_id, train_mask in sequence_trains(train_ids, 200):\n",
     "        step_timer.start()\n",
     "        print('* sequience', seq_id)\n",
@@ -198,15 +225,18 @@
     "        ntrains = len(seq_train_ids)\n",
     "\n",
     "        # create output file\n",
-    "        channels = [f\"{s[1]}/{index_group}\" for s in sources]\n",
+    "        channels = [f\"{modules[da]['corrected_source_name']}/{index_group}\"\n",
+    "                    for da in process_modules]\n",
     "\n",
-    "        f = DataFile.from_details(out_folder, aggr, run, seq_id)\n",
+    "        f = DataFile.from_details(out_folder, file_da, run, seq_id)\n",
     "        f.create_metadata(like=dc, instrument_channels=channels)\n",
     "        f.create_index(seq_train_ids, timestamps=seq_timestamps)\n",
     "\n",
+    "        # create file structure\n",
     "        seq_report = {}\n",
-    "        image_datasets = {}\n",
-    "        for instrument_source, corrected_source in sources:\n",
+    "        file_datasets = {}\n",
+    "        for da in process_modules:\n",
+    "            instrument_source = modules[da][\"raw_source_name\"]\n",
     "            keydata = dc_seq[instrument_source][image_key].drop_empty_trains()\n",
     "            count = keydata.data_counts()\n",
     "            i = np.flatnonzero(count.values)\n",
@@ -216,19 +246,31 @@
     "            shape = keydata.shape\n",
     "            count = np.in1d(seq_train_ids, keydata.train_ids).astype(int)\n",
     "\n",
+    "            corrected_source = modules[da][\"corrected_source_name\"]\n",
     "            src = f.create_instrument_source(corrected_source)\n",
     "            src.create_index(index_group=count)\n",
     "\n",
+    "            # create key for images\n",
     "            ds_data = src.create_key(image_key, shape=shape, dtype=np.float32)\n",
-    "            image_datasets[corrected_source] = ds_data\n",
+    "            module_datasets = {image_key: ds_data}\n",
+    "\n",
+    "            # create keys for image parameters\n",
+    "            for key in detector.copy_keys:\n",
+    "                keydata = dc_seq[instrument_source][key].drop_empty_trains()\n",
+    "                module_datasets[key] = (keydata, src.create_key(\n",
+    "                    key, shape=keydata.shape, dtype=keydata.dtype))\n",
+    "\n",
+    "            file_datasets[da] = module_datasets\n",
     "\n",
     "        step_timer.done_step(\"Create output file\")\n",
     "\n",
-    "        for instrument_source, corrected_source in sources:\n",
+    "        # correct and write data to file\n",
+    "        for da in process_modules:\n",
     "            step_timer.start()\n",
     "            dc_seq = dc.select_trains(by_id[seq_train_ids])\n",
     "\n",
-    "            dffc = corrections[instrument_source]\n",
+    "            dffc = corrections[da]\n",
+    "            instrument_source = modules[da][\"raw_source_name\"]\n",
     "            proc = FlatFieldCorrectionFileProcessor(dffc, num_proc, instrument_source, image_key)\n",
     "\n",
     "            proc.start_workers()\n",
@@ -237,9 +279,14 @@
     "\n",
     "            # not pulse resolved\n",
     "            corrected_images = np.stack(proc.rdr.results, 0)\n",
-    "            image_datasets[corrected_source][:] = corrected_images\n",
+    "            file_datasets[da][image_key][:] = corrected_images\n",
+    "\n",
+    "            # copy image parameters\n",
+    "            for key in detector.copy_keys:\n",
+    "                keydata, ds = file_datasets[da][key]\n",
+    "                ds[:] = keydata.ndarray()\n",
     "\n",
-    "            seq_report[instrument_source] = (raw_images[0, 0], corrected_images[:20, 0])\n",
+    "            seq_report[da] = (raw_images[0, 0], corrected_images[:20, 0])\n",
     "            step_timer.done_step(\"Correct flat-field\")\n",
     "\n",
     "        f.close()\n",
@@ -255,21 +302,22 @@
    "outputs": [],
    "source": [
     "step_timer.start()\n",
-    "\n",
-    "for source, (raw_image, corrected_images) in report[0].items():\n",
-    "    display(Markdown(f\"# {source}\"))\n",
-    "\n",
-    "    display(Markdown(\"## The first raw image\"))\n",
-    "    plot_camera_image(raw_images[0, 0])\n",
-    "    plt.show()\n",
-    "\n",
-    "    display(Markdown(\"## The first corrected image\"))\n",
-    "    plot_camera_image(corrected_images[0])\n",
-    "    plt.show()\n",
-    "\n",
-    "    display(Markdown(\"## The first corrected images in the trains (up to 20)\"))\n",
-    "    plot_images(corrected_images, figsize=(13, 8))\n",
-    "    plt.show()\n",
+    "if report:\n",
+    "    for da, (raw_image, corrected_images) in report[0].items():\n",
+    "        source = modules[da][\"raw_source_name\"]\n",
+    "        display(Markdown(f\"## {source}\"))\n",
+    "\n",
+    "        display(Markdown(\"### The first raw image\"))\n",
+    "        plot_camera_image(raw_image)\n",
+    "        plt.show()\n",
+    "\n",
+    "        display(Markdown(\"### The first corrected image\"))\n",
+    "        plot_camera_image(corrected_images[0])\n",
+    "        plt.show()\n",
+    "\n",
+    "        display(Markdown(\"### The first corrected images in the trains (up to 20)\"))\n",
+    "        plot_images(corrected_images, figsize=(13, 8))\n",
+    "        plt.show()\n",
     "\n",
     "step_timer.done_step(\"Draw images\")"
    ]
-- 
GitLab