From 7f4185c692a8ceb652b52f3ce60daba4fdd523b5 Mon Sep 17 00:00:00 2001 From: Cyril Danilevski <cyril.danilevski@xfel.eu> Date: Mon, 2 Aug 2021 15:55:49 +0200 Subject: [PATCH] Clean logic --- reportservice/report_service.py | 2 +- setup.py | 6 ++-- src/cal_tools/agipdlib.py | 23 ++++++------ src/cal_tools/agipdutils.py | 2 +- src/cal_tools/agipdutils_ff.py | 3 -- src/cal_tools/ana_tools.py | 5 +-- src/cal_tools/dssclib.py | 21 +++++------ src/cal_tools/lpdlib.py | 13 +++---- src/cal_tools/metrology.py | 13 +++---- src/cal_tools/plotting.py | 2 +- src/cal_tools/tools.py | 54 ++++++++++++++-------------- src/xfel_calibrate/calibrate.py | 64 ++++++++++----------------------- src/xfel_calibrate/finalize.py | 2 +- tests/test_agipdutils_ff.py | 2 +- 14 files changed, 84 insertions(+), 128 deletions(-) diff --git a/reportservice/report_service.py b/reportservice/report_service.py index 4e6859f7e..0e3da7e7c 100644 --- a/reportservice/report_service.py +++ b/reportservice/report_service.py @@ -64,7 +64,7 @@ async def wait_jobs(joblist): for job in joblist: if str(job) in line: found_jobs.add(job) - if len(found_jobs) == 0: + if not found_jobs: logging.info('Jobs are finished') break await asyncio.sleep(10) diff --git a/setup.py b/setup.py index 2b9a51103..e33d55fd1 100644 --- a/setup.py +++ b/setup.py @@ -26,10 +26,8 @@ class PreInstallCommand(build): def run(self): version = check_output(["git", "describe", "--tag"]).decode("utf8") version = version.replace("\n", "") - file = open("src/xfel_calibrate/VERSION.py", "w") - file.write('__version__="{}"'.format(version)) - file.close() - + with open("src/xfel_calibrate/VERSION.py", "w") as file: + file.write('__version__="{}"'.format(version)) build.run(self) diff --git a/src/cal_tools/agipdlib.py b/src/cal_tools/agipdlib.py index 9e4a82fb2..2881c9589 100644 --- a/src/cal_tools/agipdlib.py +++ b/src/cal_tools/agipdlib.py @@ -1,7 +1,7 @@ +import posixpath import traceback import zlib from multiprocessing.pool import ThreadPool 
-import posixpath from pathlib import Path from typing import Any, Dict, Optional, Tuple @@ -249,8 +249,12 @@ class AgipdCorrections: self.h5_index_path = h5_index_path self.rng_pulses = max_pulses # avoid list(range(*[0]])) - self.pulses_lst = list(range(*max_pulses)) \ - if not (len(max_pulses) == 1 and max_pulses[0] == 0) else max_pulses # noqa + self.pulses_lst = ( + list(range(*max_pulses)) + if max_pulses != [0] + else max_pulses + ) + self.max_cells = max_cells self.gain_mode = gain_mode self.comp_threads = comp_threads @@ -870,11 +874,7 @@ class AgipdCorrections: """ # Calculate the pulse step from the chosen max_pulse range - if len(self.rng_pulses) == 3: - pulse_step = self.rng_pulses[2] - else: - pulse_step = 1 - + pulse_step = self.rng_pulses[2] if len(self.rng_pulses) == 3 else 1 # Validate selected pulses range: # 1) Make sure the range max doesn't have non-valid idx. if self.pulses_lst[-1] + pulse_step > int(allpulses[-1]): @@ -1018,10 +1018,7 @@ class AgipdCorrections: if diff: if i < len(cntsv): cntsv = np.insert(cntsv, i, 0) - if i == 0: - fidxv = np.insert(fidxv, i, 0) - else: - fidxv = np.insert(fidxv, i, fidxv[i]) + fidxv = np.insert(fidxv, i, 0) if i == 0 else np.insert(fidxv, i, fidxv[i]) else: # append if at the end of the array cntsv = np.append(cntsv, 0) @@ -1199,7 +1196,7 @@ class AgipdCorrections: # This will handle some historical data in a different format # constant dimension injected first - if slopesPC.shape[0] == 10 or slopesPC.shape[0] == 11: + if slopesPC.shape[0] in [10, 11]: slopesPC = np.moveaxis(slopesPC, 0, 3) slopesPC = np.moveaxis(slopesPC, 0, 2) diff --git a/src/cal_tools/agipdutils.py b/src/cal_tools/agipdutils.py index e7cbd4667..280f07324 100644 --- a/src/cal_tools/agipdutils.py +++ b/src/cal_tools/agipdutils.py @@ -111,7 +111,7 @@ def get_shadowed_stripe(data, threshold, fraction): for idx, i in enumerate(A[1:-1]): if i - 1 not in A: continue - if len(tmp_idx) == 0: + if not tmp_idx: tmp_idx.append(i) continue if 
tmp_idx[-1] + 1 == i and ( diff --git a/src/cal_tools/agipdutils_ff.py b/src/cal_tools/agipdutils_ff.py index 6f3234ff3..ff57587d7 100644 --- a/src/cal_tools/agipdutils_ff.py +++ b/src/cal_tools/agipdutils_ff.py @@ -228,9 +228,6 @@ def get_mask(fit_summary: Dict[str, Any], d01 = fit_summary['g1mean'] - m0 mask = 0 - if not fit_summary['is_valid']: - mask |= BadPixelsFF.FIT_FAILED - if not fit_summary['has_accurate_covar']: mask |= BadPixelsFF.ACCURATE_COVAR diff --git a/src/cal_tools/ana_tools.py b/src/cal_tools/ana_tools.py index edc60e75d..4680e3b42 100644 --- a/src/cal_tools/ana_tools.py +++ b/src/cal_tools/ana_tools.py @@ -172,10 +172,7 @@ def combine_lists(*args, names=None): if isinstance(names, (list, tuple)): assert len(names) == len(args) - d_possible_params = [] - for par in possible_params: - d_possible_params.append(dict(zip(names, par))) - return d_possible_params + return [dict(zip(names, par)) for par in possible_params] return possible_params diff --git a/src/cal_tools/dssclib.py b/src/cal_tools/dssclib.py index ab7d39a6e..59ca874d2 100644 --- a/src/cal_tools/dssclib.py +++ b/src/cal_tools/dssclib.py @@ -69,12 +69,12 @@ def get_dssc_ctrl_data(in_folder, slow_data_pattern, if os.path.exists(f): ctrlDataFiles[quadrant + 1] = f - if len(ctrlDataFiles) == 0: + if not ctrlDataFiles: print("ERROR: no Slow Control Data found!") return targetGainAll, encodedGainAll, operatingFreqAll daq_format = None - + tGain = {} encodedGain = {} operatingFreqs = {} @@ -86,10 +86,7 @@ def get_dssc_ctrl_data(in_folder, slow_data_pattern, if not daq_format: tGain[quadrant] = 0.0 # 0.0 is default value for TG - if iramp_path in h5file: - irampSettings = h5file[iramp_path][0] - else: - irampSettings = "Various" + irampSettings = h5file[iramp_path][0] if iramp_path in h5file else "Various" else: epcConfig = h5file[f'/RUN/{slow_data_path}{quadrant}/epcRegisterFilePath/value'][0]\ .decode("utf-8") @@ -101,10 +98,14 @@ def get_dssc_ctrl_data(in_folder, slow_data_pattern, 
targGain) if targGain is not None else 0.0 irampSettings = h5file[iramp_path][0].decode("utf-8") - gainSettingsMap = {} - for coarseParam in ['fcfEnCap', 'csaFbCap', 'csaResistor']: - gainSettingsMap[coarseParam] = int( - h5file[f'/RUN/{slow_data_path}{quadrant}/gain/{coarseParam}/value'][0]) + gainSettingsMap = { + coarseParam: int( + h5file[ + f'/RUN/{slow_data_path}{quadrant}/gain/{coarseParam}/value' + ][0] + ) + for coarseParam in ['fcfEnCap', 'csaFbCap', 'csaResistor'] + } gainSettingsMap['trimmed'] = np.int64( 1) if irampSettings == "Various" else np.int64(0) diff --git a/src/cal_tools/lpdlib.py b/src/cal_tools/lpdlib.py index 217c769fc..7ede9b52d 100644 --- a/src/cal_tools/lpdlib.py +++ b/src/cal_tools/lpdlib.py @@ -1,5 +1,5 @@ import copy -from typing import Optional, Tuple +from typing import List, Optional, Tuple import h5py import numpy as np @@ -50,8 +50,6 @@ class LpdCorrections: :param infile: to be corrected h5py input file :param outfile: writeable h5py output file - :param max_cell: maximum number of memory cells to handle, e.g. 
if - calibration constants only exist for a subset of cells :param channel: module/channel to correct :param max_pulses: maximum pulse id to consider for preview histograms :param bins_gain_vs_signal: number of bins for gain vs signal histogram @@ -178,7 +176,8 @@ class LpdCorrections: self.create_output_datasets() self.initialized = True - def split_gain(self, d): + @staticmethod + def split_gain(d): """ Split gain information off 16-bit LPD data Gain information can be found in bits 12 and 13 (0-based) @@ -558,7 +557,7 @@ class LpdCorrections: self.hists_gain_vs_signal), (self.low_edges, self.high_edges, self.signal_edges)) - def initialize_from_db(self, dbparms: Tuple['DBParms', 'DBParms_timeout'], + def initialize_from_db(self, dbparms: List[Tuple['DBParms', 'DBParms_timeout']], karabo_id: str, karabo_da: str, only_dark: Optional[bool] = False): """ Initialize calibration constants from the calibration database @@ -750,11 +749,7 @@ class LpdCorrections: """ offsets = None - rel_gains = None - rel_gains_b = None - bpixels = None noises = None - flat_fields = None with h5py.File(filename, "r") as calfile: bpixels = calfile["{}/{}/data".format(qm, "BadPixelsCI")][()] bpix = calfile["{}/{}/data".format(qm, "BadPixelsFF")][()] diff --git a/src/cal_tools/metrology.py b/src/cal_tools/metrology.py index 618888b05..18330747d 100644 --- a/src/cal_tools/metrology.py +++ b/src/cal_tools/metrology.py @@ -62,14 +62,14 @@ def getModulePosition(metrologyFile, moduleId): with h5py.File(metrologyFile, "r") as fh: # Check if the keys actually appear in the metrology file for key in h5Keys: - if not key in fh: + if key not in fh: raise ValueError("Invalid key '{}'".format(key)) # Extract the positions from the hdf5 groups corresponding # to a module, if the module has dataset 'Position'. positions = [ np.asarray(fh[key]['Position']) for key in h5Keys if 'Position' in fh[key] ] - if len(positions) == 0: + if not positions: # This is the case when requesting a quadrant; e.g. 
# getModulePosition('Q1'). Key is valid, but quadrant # has no location (yet). @@ -115,10 +115,7 @@ def translateToModuleBL(tilePositions): # In the clockwise order of LPD tiles, the 8th # tile in the list is the bottom left tile bottomLeft8th = np.asarray([0., moduleCoords[8][1]]) - # Translate coordinates to the bottom left corner - # of the bottom left tile - bottomLeft = moduleCoords - bottomLeft8th - return bottomLeft + return moduleCoords - bottomLeft8th def plotSupermoduleData(tileData, metrologyPositions, zoom=1., vmin=100., vmax=6000.): @@ -427,7 +424,7 @@ def positionFileList(filelist, datapath, geometry_file, quad_pos, nImages='all', indices += list(np.arange(first, first+count)) - if len(indices) == 0: + if not indices: continue indices = np.unique(np.sort(np.array(indices).astype(np.int))) indices = indices[indices < f[datapath.format(ch)].shape[0]] @@ -534,7 +531,7 @@ def matchedFileList(filelist, datapath, nImages='all', trainIds=None, nwa=False, indices += list(np.arange(first, first+count)) - if len(indices) == 0: + if not indices: continue indices = np.unique(np.sort(np.array(indices).astype(np.int))) indices = indices[indices < f[datapath.format(ch)].shape[0]] diff --git a/src/cal_tools/plotting.py b/src/cal_tools/plotting.py index d2434ffaf..a7907f926 100644 --- a/src/cal_tools/plotting.py +++ b/src/cal_tools/plotting.py @@ -283,7 +283,7 @@ def show_processed_modules(dinstance: str, constants: Optional[Dict[str, Any]], # Create a dict that contains the range of tiles, in the figure, # that belong to a module. 
- ranges = dict() + ranges = {} tile_count = 0 for quadrant in range(1, quadrants+1): for module in range(1, modules+1): diff --git a/src/cal_tools/tools.py b/src/cal_tools/tools.py index 9b8b8f3a7..6c743ebf2 100644 --- a/src/cal_tools/tools.py +++ b/src/cal_tools/tools.py @@ -141,8 +141,8 @@ def map_modules_from_files(filelist, file_inset, quadrants, modules_per_quad): total_file_size = 0 module_files = {} mod_ids = {} - for quadrant in range(0, quadrants): - for module in range(0, modules_per_quad): + for quadrant in range(quadrants): + for module in range(modules_per_quad): name = "Q{}M{}".format(quadrant + 1, module + 1) module_files[name] = Queue() num = quadrant * 4 + module @@ -201,8 +201,7 @@ def get_notebook_name(): params={'token': ss.get('token', '')}) for nn in json.loads(response.text): if nn['kernel']['id'] == kernel_id: - relative_path = nn['notebook']['path'] - return relative_path + return nn['notebook']['path'] except: return environ.get("CAL_NOTEBOOK_NAME", "Unknown Notebook") @@ -333,26 +332,30 @@ def save_const_to_h5(db_module: str, karabo_id: str, metadata.calibration_constant_version.raw_data_location = file_loc - dpar = {} - for parm in metadata.detector_condition.parameters: - dpar[parm.name] = {'lower_deviation_value': parm.lower_deviation, - 'upper_deviation_value': parm.upper_deviation, - 'value': parm.value, - 'flg_logarithmic': parm.logarithmic} + dpar = { + parm.name: { + 'lower_deviation_value': parm.lower_deviation, + 'upper_deviation_value': parm.upper_deviation, + 'value': parm.value, + 'flg_logarithmic': parm.logarithmic, + } + for parm in metadata.detector_condition.parameters + } creation_time = metadata.calibration_constant_version.begin_at raw_data = metadata.calibration_constant_version.raw_data_location constant_name = metadata.calibration_constant.__class__.__name__ - data_to_store = {} - data_to_store['condition'] = dpar - data_to_store['db_module'] = db_module - data_to_store['karabo_id'] = karabo_id - 
data_to_store['constant'] = constant_name - data_to_store['data'] = data - data_to_store['creation_time'] = creation_time - data_to_store['file_loc'] = raw_data - data_to_store['report'] = report + data_to_store = { + 'condition': dpar, + 'db_module': db_module, + 'karabo_id': karabo_id, + 'constant': constant_name, + 'data': data, + 'creation_time': creation_time, + 'file_loc': raw_data, + 'report': report, + } ofile = f"{out_folder}/const_{constant_name}_{db_module}.h5" if isfile(ofile): @@ -710,15 +713,12 @@ def get_constant_from_db_and_time(karabo_id: str, karabo_da: str, condition, empty_constant, cal_db_interface, creation_time, int(print_once), timeout, ntries) - if m: - if m.comm_db_success: - return data, m.calibration_constant_version.begin_at - else: - # retun none for injection time if communication with db failed. - # reasons (no constant or condition found, - # or network problem) - return data, None + if m and m.comm_db_success: + return data, m.calibration_constant_version.begin_at else: + # return None for injection time if communication with db failed. + # reasons (no constant or condition found, + # or network problem) return data, None diff --git a/src/xfel_calibrate/calibrate.py b/src/xfel_calibrate/calibrate.py index 92863a1ed..ad73d0aaa 100755 --- a/src/xfel_calibrate/calibrate.py +++ b/src/xfel_calibrate/calibrate.py @@ -158,10 +158,7 @@ def consolize_name(name): def deconsolize_args(args): """ Variable names have underscores """ - new_args = {} - for k, v in args.items(): - new_args[k.replace("-", "_")] = v - return new_args + return {k.replace("-", "_"): v for k, v in args.items()} def extract_title_author_version(nb): @@ -312,7 +309,7 @@ def balance_sequences(in_folder: str, run: int, sequences: List[int], sequence_files.extend(in_path.glob(f"*{k_da}-S*.h5")) # Extract sequences from input files. 
- seq_nums = set([int(sf.stem[-5:]) for sf in sequence_files]) + seq_nums = {int(sf.stem[-5:]) for sf in sequence_files} # Validate selected sequences with sequences in in_folder if sequences != [-1]: @@ -467,12 +464,8 @@ def add_args_from_nb(nb, parser, cvar=None, no_required=False): default = p.value if (not required) else None - if p.type == list or p.name == cvar: - if p.type is list: - ltype = type(p.value[0]) - else: - ltype = p.type - + if issubclass(p.type, list) or p.name == cvar: + ltype = type(p.value[0]) if issubclass(p.type, list) else p.type range_allowed = "RANGE ALLOWED" in p.comment.upper() if p.comment else False pars_group.add_argument(f"--{consolize_name(p.name)}", nargs='+', @@ -481,7 +474,7 @@ def add_args_from_nb(nb, parser, cvar=None, no_required=False): help=helpstr, required=required, action=make_intelli_list(ltype) if range_allowed else None) - elif p.type == bool: + elif issubclass(p.type, bool): # For a boolean, make --XYZ and --no-XYZ options. alt_group = pars_group.add_mutually_exclusive_group(required=required) alt_group.add_argument(f"--{consolize_name(p.name)}", @@ -529,14 +522,6 @@ def extend_params(nb, extend_func_name): fcc["source"] += "\n" + extension -def has_parm(parms, name): - """ Check if a parameter of `name` exists in parms """ - for p in parms: - if p.name == name: - return True - return False - - def get_par_attr(parms, key, attr, default=None): """ Return the type of parameter with name key @@ -558,15 +543,14 @@ def flatten_list(l): :param l: List or a string :return: Same string or string with first and last entry of a list """ - if isinstance(l, list): - if len(l) > 1: - return '{}-{}'.format(l[0], l[-1]) - elif len(l) == 1: - return '{}'.format(l[0]) - else: - return '' - else: + if not isinstance(l, list): return str(l) + if len(l) > 1: + return '{}-{}'.format(l[0], l[-1]) + elif len(l) == 1: + return '{}'.format(l[0]) + else: + return '' def set_figure_format(nb, enable_vector_format): @@ -762,7 +746,7 @@ def 
concurrent_run( # first convert the notebook parms = extract_parameters(nb, lang='python') - if has_parm(parms, "cluster_profile"): + if any(p.name == "cluster_profile" for p in parms): cluster_profile = f"{args['cluster_profile']}_{suffix}" else: # Don't start ipcluster if there's no cluster_profile parameter @@ -844,7 +828,7 @@ def make_par_table(parms, run_tmp_path: str): if len(value) > max_len[1]: len_parms[1] = max_len[1] value = split_len(value, max_len[1]) - if p.type is str: + if issubclass(p.type, str): value = "``{}''".format(value) comment = tex_escape(str(p.comment)[1:]) l_parms.append([name, value, comment]) @@ -944,7 +928,7 @@ def run(): run_uuid = f"t{datetime.now().strftime('%y%m%d_%H%M%S')}" # check if concurrency parameter is given and we run concurrently - if not has_parm(parms, concurrency["parameter"]) and concurrency["parameter"] is not None: + if not any(p.name == concurrency["parameter"] for p in parms) and concurrency["parameter"] is not None: msg = "Notebook cannot be run concurrently: no {} parameter".format( concurrency["parameter"]) warnings.warn(msg, RuntimeWarning) @@ -1047,21 +1031,13 @@ def run(): if defcval is not None: print(f"Concurrency parameter '{cvar}' " f"is taken from notebooks.py") - if not isinstance(defcval, (list, tuple)): - cvals = range(defcval) - else: - cvals = defcval - + cvals = defcval if isinstance(defcval, (list, tuple)) else range(defcval) if cvals is None: defcval = get_par_attr(parms, cvar, 'value') if defcval is not None: print(f"Concurrency parameter '{cvar}' " f"is taken from '{notebook}'") - if not isinstance(defcval, (list, tuple)): - cvals = [defcval] - else: - cvals = defcval - + cvals = defcval if isinstance(defcval, (list, tuple)) else [defcval] if con_func: func = get_notebook_function(nb, con_func) if func is None: @@ -1073,12 +1049,10 @@ def run(): f = df[con_func] import inspect sig = inspect.signature(f) - callargs = [] if cvals: # in case default needs to be used for function call args[cvar] = cvals - for 
arg in sig.parameters: - callargs.append(args[arg]) + callargs = [args[arg] for arg in sig.parameters] cvals = f(*callargs) print(f"Split concurrency into {cvals}") @@ -1123,7 +1097,7 @@ def run(): sequential=sequential, )) - if not all([j is None for j in joblist]): + if any(j is not None for j in joblist): print("Submitted the following SLURM jobs: {}".format(",".join(joblist))) diff --git a/src/xfel_calibrate/finalize.py b/src/xfel_calibrate/finalize.py index 851296ae8..8da388037 100644 --- a/src/xfel_calibrate/finalize.py +++ b/src/xfel_calibrate/finalize.py @@ -397,7 +397,7 @@ def finalize(joblist, finaljob, run_path, out_path, project, calibration, for job in joblist: if str(job) in line: found_jobs.add(job) - if len(found_jobs) == 0: + if not found_jobs: break sleep(10) diff --git a/tests/test_agipdutils_ff.py b/tests/test_agipdutils_ff.py index 19947eb35..dbc06d3ad 100644 --- a/tests/test_agipdutils_ff.py +++ b/tests/test_agipdutils_ff.py @@ -134,7 +134,7 @@ def test_set_par_limits(): set_par_limits(parameters, peak_range, peak_norm_range, peak_width_range) assert parameters.keys() == expected.keys() - for key in parameters.keys(): + for key in parameters: if isinstance(parameters[key], np.ndarray): assert np.all(parameters[key] == expected[key]) else: -- GitLab