import unittest
import logging
import os
import shutil
import argparse
from time import strftime

import joblib
import numpy as np
import xarray as xr
import extra_data as ed

import toolbox_scs as tb
import toolbox_scs.detectors as tbdet
from toolbox_scs.detectors.dssc_processing import (load_chunk_data,
                                                   create_empty_dataset)
from toolbox_scs.util.exceptions import ToolBoxFileError

logging.basicConfig(level=logging.DEBUG)
log_root = logging.getLogger(__name__)
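
# Usage: the test suites defined below can be listed and run from the
# command line, e.g.
#
#   python test_dssc_methods.py --list-suites
#   python test_dssc_methods.py --run-suites metafunctions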

suites = {"metafunctions": (
                "test_info",
                "test_binners",
                "test_calcindices",
                "test_createpulsemask",
                "test_createempty",
                "test_loadchunkdata",
                "test_processmodule",
                "test_processmodule2",
                "test_storage",
                ),
          }

_temp_dirs = ['./tmp']


def setup_tmp_dir():
    for d in _temp_dirs:
        if not os.path.isdir(d):
            os.mkdir(d)


def cleanup_tmp_dir():
    for d in _temp_dirs:
        shutil.rmtree(d, ignore_errors=True)
        log_root.info(f'remove {d}')


class TestDSSC(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        log_root.info("Start global setup")
        # ---------------------------------------------------------------------

        setup_tmp_dir()

        # ---------------------------------------------------------------------
        log_root.info("Finished global setup, start tests")

    @classmethod
    def tearDownClass(cls):
        log_root.info("Clean up test environment....")
    def test_info(self):
        proposal = 2212
        run_nr = 235
        info = tbdet.load_dssc_info(proposal, run_nr)
        self.assertEqual(info['frames_per_train'], 20)

    def test_binners(self):
        cls = self.__class__
        proposal = 2212
        run_nr = 235
        cls._dssc_info_235 = tbdet.load_dssc_info(proposal, run_nr)
        cls._run_235 = tb.load_run(proposal, run_nr, include='*DA*')

        # create 3 different binners manually
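        # A binner is an xarray DataArray assigning a bin label to every
        # value along one dimension of the DSSC data; during processing the
        # data is grouped by these labels and summed (see the groupby/sum
        # loop in test_loadchunkdata below).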
        bins_trainId_name = 'PP800_PhaseShifter'
        stepsize = .04
        bins_trainId = tb.get_array(cls._run_235,
                                    bins_trainId_name,
                                    stepsize)
        bins_pulse = ['pumped', 'unpumped'] * 10

        bin1 = tbdet.create_dssc_bins("trainId",
                                      cls._dssc_info_235['trainIds'], 
                                      bins_trainId.values)
        bin2 = tbdet.create_dssc_bins("pulse",
                                      np.linspace(0,19,20, dtype=int),
                                      bins_pulse)
        bin3 = tbdet.create_dssc_bins("x",
                                      np.linspace(1,128,128, dtype=int),
                                      np.linspace(1,128,128, dtype=int))
        # create binner with invalid name
        with self.assertRaises(ValueError) as err:
            bin2 = tbdet.create_dssc_bins("blabla",
                                          np.linspace(0,19,20, dtype=int),
                                          bins_pulse)

        # format binner dictionary. The 4th binner is constructed
        # automatically when the data is processed.
        cls._binners_235 = {'trainId': bin1, 'pulse': bin2}
        self.assertIsNotNone(cls._binners_235)

    def test_calcindices(self):
        cls = self.__class__
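        # calc_xgm_frame_indices presumably maps the per-pulse labels onto
        # the corresponding XGM frame positions, so that the XGM data can be
        # given the same 'pulse' coordinate as the DSSC frames.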
        # first test including darks
        bins_pulse = ['pumped', 'unpumped', 'pumped_dark', 'unpumped_dark'] * 5
        xgm_frame_coords = tbdet.calc_xgm_frame_indices(bins_pulse)
        self.assertIsNotNone(xgm_frame_coords)
        
        # another test without darks
        cls._xgm = tbdet.load_xgm(cls._run_235)
        bins_pulse = ['pumped', 'unpumped'] * 10
        xgm_frame_coords = tbdet.calc_xgm_frame_indices(bins_pulse)
        self.assertIsNotNone(xgm_frame_coords)        
        cls._xgm['pulse'] = xgm_frame_coords

    def test_createpulsemask(self):
        cls = self.__class__
        data = np.ones([len(cls._run_235.train_ids),
                        cls._dssc_info_235['frames_per_train']], dtype=bool)
        dimensions = ['trainId', 'pulse']
        coordinates = {'trainId': cls._run_235.train_ids,
                       'pulse': range(cls._dssc_info_235['frames_per_train'])}
        pulsemask = xr.DataArray(data, dims=dimensions, coords=coordinates)
        # XGM intensity window used to flag valid pulses. The limits below
        # are placeholder values; adjust them to the measured XGM range.
        xgm_min = 0
        xgm_max = np.inf
        valid = (cls._xgm > xgm_min) * (cls._xgm < xgm_max)
        cls._pulsemask = valid.combine_first(pulsemask).astype(bool)

        self.assertIsNotNone(cls._pulsemask)


    def test_createempty(self):
        cls = self.__class__
        
        # standard dset
        empty_data = create_empty_dataset(cls._dssc_info_235, cls._binners_235)
        self.assertIsNotNone(empty_data.dims['trainId'])

        # bin along pulse dimension only
        bins_pulse = ['pumped', 'unpumped', 'pumped_dark', 'unpumped_dark'] * 5
        binner_pulse = tbdet.create_dssc_bins("pulse",
                                              np.linspace(0,19,20, dtype=int),
                                              bins_pulse)
        empty_data = create_empty_dataset(cls._dssc_info_235, 
                                          {'pulse':binner_pulse})
        self.assertEqual(empty_data.pulse.values[1], 'pumped_dark')

    def test_loadchunkdata(self):
        cls = self.__class__
        proposal = 2212
        run_nr = 235
        module = 1
        chunksize = 512
        sourcename = f'SCS_DET_DSSC1M-1/DET/{module}CH0:xtdf'

        collection = ed.open_run(proposal, run_nr,
                                 include=f'*DSSC{module:02d}*')
        binners = cls._binners_235
        info = cls._dssc_info_235
        pulsemask = cls._pulsemask

        ntrains = len(collection.train_ids)
        chunks = np.arange(ntrains, step=chunksize)
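        # iterate over the run in slices of `chunksize` trains so that only
        # a bounded amount of DSSC data is held in memory at any time; this
        # test only processes the first two chunks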
        start_index = chunks[0]

        module_data = create_empty_dataset(info, binners)

        for chunk in chunks[0:2]:
            sel = collection.select_trains(
                                ed.by_index[chunk:chunk + chunksize])
            log_root.debug(f"Module {module}: "
                           f"loading trains {chunk}:{chunk + chunksize}")
            chunk_data = load_chunk_data(sel, sourcename)
            self.assertIsNotNone(chunk_data)

            # pulse masking, and creation of related hist subset.
            chunk_hist = xr.full_like(chunk_data[:,:,0,0], fill_value=1)
            if pulsemask is not None:
                chunk_data = chunk_data.where(pulsemask)
                chunk_hist = chunk_hist.where(pulsemask)
            chunk_data = chunk_data.to_dataset(name='data')
            chunk_data['hist'] = chunk_hist
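            # 'hist' counts the frames contributing to each bin, so the
            # summed data can later be normalized to per-frame averages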

            # apply predefined binning
            log_root.debug(f'Module {module}: apply binning to chunk_data')
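            # each binner is attached as a coordinate along its dimension;
            # grouping by that coordinate and summing collapses the
            # dimension into the predefined bins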
            for b in binners:
                chunk_data[b+"_binned"] = binners[b]
                chunk_data = chunk_data.groupby(b+"_binned").sum(b)
                chunk_data = chunk_data.rename(name_dict={b+"_binned":b})
            # ToDo: Avoid creation of unnecessary data when binning along x,y
            log_root.debug(f'Module {module}: merge data into prepared dset')
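            # stack the binned chunk with the running totals along a
            # temporary dimension and sum it away, accumulating chunk by
            # chunk instead of loading the full run at once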
            for var in ['data', 'hist']:
                module_data[var] = xr.concat([module_data[var],
                                             chunk_data[var]],
                                             dim='tmp').sum('tmp')

        #module_data = module_data.transpose('trainId', 'pulse', 'x', 'y')

    def test_processmodule(self):
        cls = self.__class__

        backend = 'multiprocessing'
        mod_list = [1, 15]
        chunksize = 512
        n_jobs = len(mod_list)
        log_root.info(f'processing {chunksize} trains per chunk')
        log_root.info(f'using parallelization backend {backend}')

        module_jobs = []
        for m in mod_list:
            # keyword arguments forwarded to tbdet.bin_data (assumed
            # keyword set; the run/module selection keys mirror the
            # run-89 jobs in test_processmodule2)
            module_jobs.append(dict(
                proposal=2212,
                run_nr=235,
                module=m,
                chunksize=chunksize,
                info=cls._dssc_info_235,
                dssc_binners=cls._binners_235,
                pulsemask=cls._pulsemask,
            ))

        print('start processing modules:', strftime('%X'))
        data = joblib.Parallel(n_jobs=n_jobs, backend=backend) \
            (joblib.delayed(tbdet.bin_data)(**module_jobs[i]) \
             for i in range(len(mod_list)))

        print('finished processing modules:', strftime('%X'))
        data = xr.concat(data, dim='module')
        print(data)
        self.assertIsNotNone(data)
        cls._module_data = data

    def test_processmodule2(self):
        backend = 'multiprocessing'
        mod_list = [15, 3]
        chunksize = 512
        n_jobs = len(mod_list)
        log_root.info(f'processing {chunksize} trains per chunk')
        log_root.info(f'using parallelization backend {backend}')

        info = tbdet.load_dssc_info(2212, 89)
        bin1 = tbdet.create_dssc_bins("trainId",
                                      info['trainIds'],
                                      np.ones(1691))
        binners = {'trainId': bin1}
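        # a single constant bin label (np.ones) sums all 1691 trains of the
        # run into one trainId bin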
        module_jobs = []
        for m in mod_list:
            # keyword arguments forwarded to tbdet.bin_data (assumed
            # keyword set, as above)
            module_jobs.append(dict(
                proposal=2212,
                run_nr=89,
                module=m,
                chunksize=chunksize,
                info=info,
                dssc_binners=binners,
            ))

        print('start processing modules:', strftime('%X'))
        data = joblib.Parallel(n_jobs=n_jobs, backend=backend) \
            (joblib.delayed(tbdet.bin_data)(**module_jobs[i]) \
             for i in range(len(mod_list)))
        print('finished processing modules:', strftime('%X'))

        data = xr.concat(data, dim='module')
        data = data.squeeze()
        print(data)
        self.assertIsNotNone(data)

    def test_storage(self):
        cls = self.__class__
        tbdet.save_xarray('./tmp/run235.h5', cls._module_data)
        tbdet.save_xarray('./tmp/run235.h5', cls._module_data,
                          mode='w')
        run235 = tbdet.load_xarray('./tmp/run235.h5')
        #run235.close()

        with self.assertRaises(ToolBoxFileError) as cm:
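            # appending (mode='a') a dataset whose dimensions no longer
            # match the stored file is expected to raise a ToolBoxFileError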
            tbdet.save_xarray('./tmp/run235.h5',
                              cls._module_data.isel(pulse=0),
                              mode='a')
    print("\nPossible test suites:\n" + "-" * 79)
    print("-" * 79 + "\n")


def suite(*tests):
    suite = unittest.TestSuite()
    for test in tests:
        suite.addTest(TestDSSC(test))
    return suite


def main(*cliargs):
    try:
        for test_suite in cliargs:
            if test_suite in suites:
                runner = unittest.TextTestRunner(verbosity=2)
                runner.run(suite(*suites[test_suite]))
            else:
                log_root.warning(
                    "Unknown suite: '{}'".format(test_suite))
    except Exception as err:
        log_root.error("Unexpected error: {}".format(err),
                       exc_info=True)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--list-suites',
                        action='store_true',
                        help='list possible test suites')
    parser.add_argument('--run-suites', metavar='S',
                        nargs='+', action='store',
                        help='a list of valid test suites')
    args = parser.parse_args()

    if args.list_suites:
        list_suites()

    if args.run_suites:
        main(*args.run_suites)