diff --git a/src/toolbox_scs/detectors/dssc.py b/src/toolbox_scs/detectors/dssc.py index c40968474aa4f859ec16b5f6cb20fbe108302c3f..e6de1aa718b948dcc79c8afe0422d3227f94e709 100644 --- a/src/toolbox_scs/detectors/dssc.py +++ b/src/toolbox_scs/detectors/dssc.py @@ -46,6 +46,7 @@ def load_dssc_info(proposal, run_nr): module = _open_run(proposal, run_nr, include='*DSSC00*') info = module.detector_info('SCS_DET_DSSC1M-1/DET/0CH0:xtdf') + log.debug("Fetched information for DSSC module nr. 0.") return info diff --git a/src/toolbox_scs/load.py b/src/toolbox_scs/load.py index b6c8438d727c19f84c030ceba75a72c4d133d3c3..ff8604261d0852c11b516398bbddaa353f629b09 100644 --- a/src/toolbox_scs/load.py +++ b/src/toolbox_scs/load.py @@ -205,7 +205,7 @@ def load_scan_variable(run, mnemonic, stepsize=None): Loads the given scan variable and rounds scan positions to integer multiples of stepsize for consistent grouping (except for stepsize=None). - Returns a dummy scan if scan_variable is set to None. + Returns a dummy scan if mnemonic is set to None. Parameters ---------- @@ -234,11 +234,17 @@ def load_scan_variable(run, mnemonic, stepsize=None): """ try: - if mnemonic not in _mnemonics_ld: + if mnemonic is None: + data = xr.DataArray( + np.ones(len(run.train_ids), dtype=np.int16), + dims=['trainId'], coords={'trainId': run.train_ids}) + elif mnemonic in _mnemonics_ld: + mnem = _mnemonics_ld[mnemonic] + data = run.get_array(mnem['source'], + mnem['key'], mnem['dim']) + else: raise ToolBoxValueError("Invalid mnemonic given", mnemonic) - mnem = _mnemonics_ld[mnemonic] - data = run.get_array(mnem['source'], - mnem['key'], mnem['dim']) + if stepsize is not None: data = stepsize * np.round(data / stepsize) data.name = 'scan_variable' diff --git a/src/toolbox_scs/test/test_detectors_common.py b/src/toolbox_scs/test/test_detectors_common.py index 2c38198e38f82916cf420ce73a3012af4f93fdcb..4a139505d75957078c98a822b6f3ba5ca7db42e7 100644 --- a/src/toolbox_scs/test/test_detectors_common.py +++ b/src/toolbox_scs/test/test_detectors_common.py @@ -31,14 +31,23 @@ def list_suites(): class TestDetectors(unittest.TestCase): - def setUp(self): - self.run = tb.run_by_proposal(2212, 235) + @classmethod + def setUpClass(cls): + log_root.info("Start global setup.") + cls._run = tb.run_by_proposal(2212, 235) fields = ["sase1", "sase3", "npulses_sase3", - "npulses_sase1", "MCP2apd", "SCS_SA3", "nrj"] - self.tb_data = tb.load(fields, 235, 2212) + "npulses_sase1", "MCP2apd", "SCS_SA3", "nrj"] + cls._tb_data = tb.load(fields, 235, 2212) + + log_root.info("Finished global setup, start tests.") + + @classmethod + def tearDownClass(cls): + pass - log_root.info("Finished setup, start tests.") + def setUp(self): + pass def tearDown(self): pass @@ -47,19 +56,23 @@ class TestDetectors(unittest.TestCase): self.assertEqual(tbdet.__name__, "toolbox_scs.detectors") def test_loadxgm(self): - xgm_data = tbdet.load_xgm(self.run) + cls = self.__class__ + xgm_data = tbdet.load_xgm(cls._run) self.assertTrue(xgm_data.values[0][-1]) def test_cleanxgm(self): - data = tbdet.cleanXGMdata(self.tb_data) + cls = self.__class__ + data = tbdet.cleanXGMdata(cls._tb_data) self.assertEqual(data['sa3_pId'].values[-1], 19) def test_matchxgmtim(self): - data = tbdet.matchXgmTimPulseId(self.tb_data) + cls = self.__class__ + data = tbdet.matchXgmTimPulseId(cls._tb_data) self.assertEqual(data['npulses_sase3'].values[0], 20) def test_loadtim(self): - data = tbdet.load_TIM(self.run) + cls = self.__class__ + data = tbdet.load_TIM(cls._run) self.assertEqual(data.name, 'MCP2apd') diff --git a/src/toolbox_scs/test/test_detectors_dssc.py b/src/toolbox_scs/test/test_detectors_dssc.py index 05c2e14d42a6d4b6e485907761cae66881e4a0ec..df974b6bd261aebac854946591eedc622b3dac0d 100644 --- a/src/toolbox_scs/test/test_detectors_dssc.py +++ b/src/toolbox_scs/test/test_detectors_dssc.py @@ -3,6 +3,7 @@ import logging import os import sys import argparse +import shutil import toolbox_scs as tb @@ -12,7 +13,21 @@ from toolbox_scs.util.exceptions import * logging.basicConfig(level=logging.DEBUG) log_root = logging.getLogger(__name__) -suites = {"dssc-preparations": ( + +# --------------------------------------------------------------------------------- +# global test settings (based on https://github.com/dscran/dssc_process/blob/master +# /example_image_process_pulsemask.ipynb) +# --------------------------------------------------------------------------------- +proposal = 2212 +run_nr = 235 +is_dark = False +framepattern = ['pumped', 'unpumped'] +maxframes = None +stepsize = .03 +scan_variable = 'PP800_PhaseShifter' +# --------------------------------------------------------------------------------- + +suites = {"preparation": ( "test_info", "test_prepareempty", ) @@ -26,17 +41,43 @@ def list_suites(): print("-------------------------\n") -class TestDetectors(unittest.TestCase): - def setUp(self): - pass +def setup_tmp_dir(): + for d in ['tmp', 'images', 'processed_runs']: + if not os.path.isdir(d): + os.mkdir(d) + +def cleanup_tmp_dir(): + for d in ['tmp', 'images', 'processed_runs']: + shutil.rmtree(d, ignore_errors=True) + + +class TestDSSC(unittest.TestCase): + @classmethod + def setUpClass(cls): + log_root.info("Start global setup") + + setup_tmp_dir() + + + cls._run = tb.run_by_proposal(proposal, run_nr, include='*DA*') + cls._scanfile = './tmp/scan.h5' + cls._maskfile = './tmp/mask.h5' + + cls._scan_variable = tb.load_scan_variable( + cls._run, scan_variable, stepsize) + #cls._scan_variable.to_netcdf(cls._scanfile, group='data', mode='w') + + log_root.info("Finished global setup, start tests") + + @classmethod + def tearDownClass(cls): + cleanup_tmp_dir() - def tearDown(self): - pass def test_info(self): - info = tbdet.load_dssc_info(2212, 235) - self.assertEqual(info['total_frames'], 180180) - + info = tbdet.load_dssc_info(proposal, run_nr) + self.assertEqual(info['frames_per_train'], 20) + def test_prepareempty(self): pass @@ -45,7 +86,7 @@ class TestDetectors(unittest.TestCase): def suite(*tests): suite = unittest.TestSuite() for test in tests: - suite.addTest(TestDetectors(test)) + suite.addTest(TestDSSC(test)) return suite diff --git a/src/toolbox_scs/test/test_top_level.py b/src/toolbox_scs/test/test_top_level.py index 8b1d82d5900535f6c7c2d232a9c2afdbee7024e8..74f7829eb061f2e2921202826d649c971928ece7 100644 --- a/src/toolbox_scs/test/test_top_level.py +++ b/src/toolbox_scs/test/test_top_level.py @@ -35,16 +35,26 @@ def list_suites(): class TestToolbox(unittest.TestCase): + @classmethod + def setUpClass(cls): + log_root.info("Start global setup") + cls._mnentry = 'SCS_RR_UTC/MDL/BUNCH_DECODER' + cls._ed_run = ed.open_run(2212, 235) + log_root.info("Finished global setup, start tests") + + @classmethod + def tearDownClass(cls): + pass + def setUp(self): - self.mnentry = 'SCS_RR_UTC/MDL/BUNCH_DECODER' - self.ed_run = ed.open_run(2212, 235) - log_root.info("Finished setup, start tests") + pass def tearDown(self): pass def test_constant(self): - self.assertEqual(tb.mnemonics['sase3']['source'],self.mnentry) + cls = self.__class__ + self.assertEqual(tb.mnemonics['sase3']['source'],cls._mnentry) def test_load(self): proposalNB = 2511 @@ -55,9 +65,9 @@ class TestToolbox(unittest.TestCase): self.assertEqual(run_tb['npulses_sase3'].values[0], 42) def test_openrun(self): - self.run = tb.run_by_proposal(2212, 235) + run = tb.run_by_proposal(2212, 235) src = 'SCS_DET_DSSC1M-1/DET/0CH0:xtdf' - self.assertTrue(src in self.run.all_sources) + self.assertTrue(src in run.all_sources) def test_openrunpath(self): run = tb.run_by_path( @@ -66,15 +76,17 @@ class TestToolbox(unittest.TestCase): self.assertTrue(src in run.all_sources) def test_loadscanvariable1(self): + cls = self.__class__ mnemonic = 'PP800_PhaseShifter' - scan_variable = tb.load_scan_variable(self.ed_run, mnemonic, 0.5) + scan_variable = tb.load_scan_variable(cls._ed_run, mnemonic, 0.5) self.assertTrue = (scan_variable) def test_loadscanvariable2(self): + cls = self.__class__ mnemonic = 'blabla' scan_variable = None with self.assertRaises(ToolBoxValueError) as cm: - scan_variable = tb.load_scan_variable(self.ed_run, mnemonic, 0.5) + scan_variable = tb.load_scan_variable(cls._ed_run, mnemonic, 0.5) excp = cm.exception err_msg = "Invalid mnemonic given" self.assertEqual(excp.message, err_msg) diff --git a/src/toolbox_scs/test/test_utils.py b/src/toolbox_scs/test/test_utils.py index 451782adf3341f6ed5be7a0b7af83f6f3ec49e0f..c17344ca0ce3eb01160d96754ff965ff55314368 100644 --- a/src/toolbox_scs/test/test_utils.py +++ b/src/toolbox_scs/test/test_utils.py @@ -26,6 +26,14 @@ def list_suites(): class TestDataAccess(unittest.TestCase): + @classmethod + def setUpClass(cls): + pass + + @classmethod + def tearDownClass(cls): + pass + def setUp(self): pass