From 7f4185c692a8ceb652b52f3ce60daba4fdd523b5 Mon Sep 17 00:00:00 2001
From: Cyril Danilevski <cyril.danilevski@xfel.eu>
Date: Mon, 2 Aug 2021 15:55:49 +0200
Subject: [PATCH] Clean logic

---
 reportservice/report_service.py |  2 +-
 setup.py                        |  6 ++--
 src/cal_tools/agipdlib.py       | 23 ++++++------
 src/cal_tools/agipdutils.py     |  2 +-
 src/cal_tools/agipdutils_ff.py  |  3 --
 src/cal_tools/ana_tools.py      |  5 +--
 src/cal_tools/dssclib.py        | 21 +++++------
 src/cal_tools/lpdlib.py         | 13 +++----
 src/cal_tools/metrology.py      | 13 +++----
 src/cal_tools/plotting.py       |  2 +-
 src/cal_tools/tools.py          | 54 ++++++++++++++--------------
 src/xfel_calibrate/calibrate.py | 64 ++++++++++-----------------------
 src/xfel_calibrate/finalize.py  |  2 +-
 tests/test_agipdutils_ff.py     |  2 +-
 14 files changed, 84 insertions(+), 128 deletions(-)

diff --git a/reportservice/report_service.py b/reportservice/report_service.py
index 4e6859f7e..0e3da7e7c 100644
--- a/reportservice/report_service.py
+++ b/reportservice/report_service.py
@@ -64,7 +64,7 @@ async def wait_jobs(joblist):
             for job in joblist:
                 if str(job) in line:
                     found_jobs.add(job)
-        if len(found_jobs) == 0:
+        if not found_jobs:
             logging.info('Jobs are finished')
             break
         await asyncio.sleep(10)
diff --git a/setup.py b/setup.py
index 2b9a51103..e33d55fd1 100644
--- a/setup.py
+++ b/setup.py
@@ -26,10 +26,8 @@ class PreInstallCommand(build):
     def run(self):
         version = check_output(["git", "describe", "--tag"]).decode("utf8")
         version = version.replace("\n", "")
-        file = open("src/xfel_calibrate/VERSION.py", "w")
-        file.write('__version__="{}"'.format(version))
-        file.close()
-
+        with open("src/xfel_calibrate/VERSION.py", "w") as file:
+            file.write('__version__="{}"'.format(version))
         build.run(self)
 
 
diff --git a/src/cal_tools/agipdlib.py b/src/cal_tools/agipdlib.py
index 9e4a82fb2..2881c9589 100644
--- a/src/cal_tools/agipdlib.py
+++ b/src/cal_tools/agipdlib.py
@@ -1,7 +1,7 @@
+import posixpath
 import traceback
 import zlib
 from multiprocessing.pool import ThreadPool
-import posixpath
 from pathlib import Path
 from typing import Any, Dict, Optional, Tuple
 
@@ -249,8 +249,12 @@ class AgipdCorrections:
         self.h5_index_path = h5_index_path
         self.rng_pulses = max_pulses
         # avoid list(range(*[0]]))
-        self.pulses_lst = list(range(*max_pulses)) \
-            if not (len(max_pulses) == 1 and max_pulses[0] == 0) else max_pulses  # noqa
+        self.pulses_lst = (
+            list(range(*max_pulses))
+            if max_pulses != [0]
+            else max_pulses
+        )
+
         self.max_cells = max_cells
         self.gain_mode = gain_mode
         self.comp_threads = comp_threads
@@ -870,11 +874,7 @@ class AgipdCorrections:
         """
 
         # Calculate the pulse step from the chosen max_pulse range
-        if len(self.rng_pulses) == 3:
-            pulse_step = self.rng_pulses[2]
-        else:
-            pulse_step = 1
-
+        pulse_step = self.rng_pulses[2] if len(self.rng_pulses) == 3 else 1
         # Validate selected pulses range:
         # 1) Make sure the range max doesn't have non-valid idx.
         if self.pulses_lst[-1] + pulse_step > int(allpulses[-1]):
@@ -1018,10 +1018,7 @@ class AgipdCorrections:
                 if diff:
                     if i < len(cntsv):
                         cntsv = np.insert(cntsv, i, 0)
-                        if i == 0:
-                            fidxv = np.insert(fidxv, i, 0)
-                        else:
-                            fidxv = np.insert(fidxv, i, fidxv[i])
+                        fidxv = np.insert(fidxv, i, 0) if i == 0 else np.insert(fidxv, i, fidxv[i])
                     else:
                         # append if at the end of the array
                         cntsv = np.append(cntsv, 0)
@@ -1199,7 +1196,7 @@ class AgipdCorrections:
 
                 # This will handle some historical data in a different format
                 # constant dimension injected first
-                if slopesPC.shape[0] == 10 or slopesPC.shape[0] == 11:
+                if slopesPC.shape[0] in [10, 11]:
                     slopesPC = np.moveaxis(slopesPC, 0, 3)
                     slopesPC = np.moveaxis(slopesPC, 0, 2)
 
diff --git a/src/cal_tools/agipdutils.py b/src/cal_tools/agipdutils.py
index e7cbd4667..280f07324 100644
--- a/src/cal_tools/agipdutils.py
+++ b/src/cal_tools/agipdutils.py
@@ -111,7 +111,7 @@ def get_shadowed_stripe(data, threshold, fraction):
     for idx, i in enumerate(A[1:-1]):
         if i - 1 not in A:
             continue
-        if len(tmp_idx) == 0:
+        if not tmp_idx:
             tmp_idx.append(i)
             continue
         if tmp_idx[-1] + 1 == i and (
diff --git a/src/cal_tools/agipdutils_ff.py b/src/cal_tools/agipdutils_ff.py
index 6f3234ff3..ff57587d7 100644
--- a/src/cal_tools/agipdutils_ff.py
+++ b/src/cal_tools/agipdutils_ff.py
@@ -228,9 +228,6 @@ def get_mask(fit_summary: Dict[str, Any],
     d01 = fit_summary['g1mean'] - m0
 
     mask = 0
-    if not fit_summary['is_valid']:
-        mask |= BadPixelsFF.FIT_FAILED
-
     if not fit_summary['has_accurate_covar']:
         mask |= BadPixelsFF.ACCURATE_COVAR
 
diff --git a/src/cal_tools/ana_tools.py b/src/cal_tools/ana_tools.py
index edc60e75d..4680e3b42 100644
--- a/src/cal_tools/ana_tools.py
+++ b/src/cal_tools/ana_tools.py
@@ -172,10 +172,7 @@ def combine_lists(*args, names=None):
 
     if isinstance(names, (list, tuple)):
         assert len(names) == len(args)
-        d_possible_params = []
-        for par in possible_params:
-            d_possible_params.append(dict(zip(names, par)))
-        return d_possible_params
+        return [dict(zip(names, par)) for par in possible_params]
     return possible_params
 
 
diff --git a/src/cal_tools/dssclib.py b/src/cal_tools/dssclib.py
index ab7d39a6e..59ca874d2 100644
--- a/src/cal_tools/dssclib.py
+++ b/src/cal_tools/dssclib.py
@@ -69,12 +69,12 @@ def get_dssc_ctrl_data(in_folder, slow_data_pattern,
         if os.path.exists(f):
             ctrlDataFiles[quadrant + 1] = f
 
-    if len(ctrlDataFiles) == 0:
+    if not ctrlDataFiles:
         print("ERROR: no Slow Control Data found!")
         return targetGainAll, encodedGainAll, operatingFreqAll
 
     daq_format = None
-    
+
     tGain = {}
     encodedGain = {}
     operatingFreqs = {}
@@ -86,10 +86,7 @@ def get_dssc_ctrl_data(in_folder, slow_data_pattern,
                 if not daq_format:
                     tGain[quadrant] = 0.0  # 0.0 is default value for TG
 
-                    if iramp_path in h5file:
-                        irampSettings = h5file[iramp_path][0]
-                    else:
-                        irampSettings = "Various"
+                    irampSettings = h5file[iramp_path][0] if iramp_path in h5file else "Various"
                 else:
                     epcConfig = h5file[f'/RUN/{slow_data_path}{quadrant}/epcRegisterFilePath/value'][0]\
                         .decode("utf-8")
@@ -101,10 +98,14 @@ def get_dssc_ctrl_data(in_folder, slow_data_pattern,
                         targGain) if targGain is not None else 0.0
                     irampSettings = h5file[iramp_path][0].decode("utf-8")
 
-                gainSettingsMap = {}
-                for coarseParam in ['fcfEnCap', 'csaFbCap', 'csaResistor']:
-                    gainSettingsMap[coarseParam] = int(
-                        h5file[f'/RUN/{slow_data_path}{quadrant}/gain/{coarseParam}/value'][0])
+                gainSettingsMap = {
+                    coarseParam: int(
+                        h5file[
+                            f'/RUN/{slow_data_path}{quadrant}/gain/{coarseParam}/value'
+                        ][0]
+                    )
+                    for coarseParam in ['fcfEnCap', 'csaFbCap', 'csaResistor']
+                }
 
                 gainSettingsMap['trimmed'] = np.int64(
                     1) if irampSettings == "Various" else np.int64(0)
diff --git a/src/cal_tools/lpdlib.py b/src/cal_tools/lpdlib.py
index 217c769fc..7ede9b52d 100644
--- a/src/cal_tools/lpdlib.py
+++ b/src/cal_tools/lpdlib.py
@@ -1,5 +1,5 @@
 import copy
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple
 
 import h5py
 import numpy as np
@@ -50,8 +50,6 @@ class LpdCorrections:
 
         :param infile: to be corrected h5py input file
         :param outfile: writeable h5py output file
-        :param max_cell: maximum number of memory cells to handle, e.g. if
-                         calibration constants only exist for a subset of cells
         :param channel: module/channel to correct
         :param max_pulses: maximum pulse id to consider for preview histograms
         :param bins_gain_vs_signal: number of bins for gain vs signal histogram
@@ -178,7 +176,8 @@ class LpdCorrections:
                 self.create_output_datasets()
             self.initialized = True
 
-    def split_gain(self, d):
+    @staticmethod
+    def split_gain(d):
         """ Split gain information off 16-bit LPD data
 
         Gain information can be found in bits 12 and 13 (0-based)
@@ -558,7 +557,7 @@ class LpdCorrections:
                  self.hists_gain_vs_signal),
                 (self.low_edges, self.high_edges, self.signal_edges))
 
-    def initialize_from_db(self, dbparms: Tuple['DBParms', 'DBParms_timeout'],
+    def initialize_from_db(self, dbparms: List[Tuple['DBParms', 'DBParms_timeout']],
                            karabo_id: str, karabo_da: str,
                            only_dark: Optional[bool] = False):
         """ Initialize calibration constants from the calibration database
@@ -750,11 +749,7 @@ class LpdCorrections:
 
         """
         offsets = None
-        rel_gains = None
-        rel_gains_b = None
-        bpixels = None
         noises = None
-        flat_fields = None
         with h5py.File(filename, "r") as calfile:
             bpixels = calfile["{}/{}/data".format(qm, "BadPixelsCI")][()]
             bpix = calfile["{}/{}/data".format(qm, "BadPixelsFF")][()]
diff --git a/src/cal_tools/metrology.py b/src/cal_tools/metrology.py
index 618888b05..18330747d 100644
--- a/src/cal_tools/metrology.py
+++ b/src/cal_tools/metrology.py
@@ -62,14 +62,14 @@ def getModulePosition(metrologyFile, moduleId):
     with h5py.File(metrologyFile, "r") as fh:
         # Check if the keys actually appear in the metrology file
         for key in h5Keys:
-            if not key in fh:
+            if key not in fh:
                 raise ValueError("Invalid key '{}'".format(key))
         # Extract the positions from the hdf5 groups corresponding
         # to a module, if the module has dataset 'Position'.
         positions = [
             np.asarray(fh[key]['Position']) for key in h5Keys if 'Position' in fh[key]
         ]
-    if len(positions) == 0:
+    if not positions:
         # This is the case when requesting a quadrant; e.g.
         # getModulePosition('Q1'). Key is valid, but quadrant
         # has no location (yet).
@@ -115,10 +115,7 @@ def translateToModuleBL(tilePositions):
     # In the clockwise order of LPD tiles, the 8th
     # tile in the list is the bottom left tile
     bottomLeft8th = np.asarray([0., moduleCoords[8][1]])
-    # Translate coordinates to the bottom left corner
-    # of the bottom left tile
-    bottomLeft = moduleCoords - bottomLeft8th
-    return bottomLeft
+    return moduleCoords - bottomLeft8th
 
 
 def plotSupermoduleData(tileData, metrologyPositions, zoom=1., vmin=100., vmax=6000.):
@@ -427,7 +424,7 @@ def positionFileList(filelist, datapath, geometry_file, quad_pos, nImages='all',
 
                         indices += list(np.arange(first, first+count))
 
-                    if len(indices) == 0:
+                    if not indices:
                         continue
                     indices = np.unique(np.sort(np.array(indices).astype(np.int)))
                     indices = indices[indices < f[datapath.format(ch)].shape[0]]
@@ -534,7 +531,7 @@ def matchedFileList(filelist, datapath, nImages='all', trainIds=None, nwa=False,
 
                         indices += list(np.arange(first, first+count))
 
-                    if len(indices) == 0:
+                    if not indices:
                         continue
                     indices = np.unique(np.sort(np.array(indices).astype(np.int)))
                     indices = indices[indices < f[datapath.format(ch)].shape[0]]
diff --git a/src/cal_tools/plotting.py b/src/cal_tools/plotting.py
index d2434ffaf..a7907f926 100644
--- a/src/cal_tools/plotting.py
+++ b/src/cal_tools/plotting.py
@@ -283,7 +283,7 @@ def show_processed_modules(dinstance: str, constants: Optional[Dict[str, Any]],
 
     # Create a dict that contains the range of tiles, in the figure,
     # that belong to a module.
-    ranges = dict()
+    ranges = {}
     tile_count = 0
     for quadrant in range(1, quadrants+1):
         for module in range(1, modules+1):
diff --git a/src/cal_tools/tools.py b/src/cal_tools/tools.py
index 9b8b8f3a7..6c743ebf2 100644
--- a/src/cal_tools/tools.py
+++ b/src/cal_tools/tools.py
@@ -141,8 +141,8 @@ def map_modules_from_files(filelist, file_inset, quadrants, modules_per_quad):
     total_file_size = 0
     module_files = {}
     mod_ids = {}
-    for quadrant in range(0, quadrants):
-        for module in range(0, modules_per_quad):
+    for quadrant in range(quadrants):
+        for module in range(modules_per_quad):
             name = "Q{}M{}".format(quadrant + 1, module + 1)
             module_files[name] = Queue()
             num = quadrant * 4 + module
@@ -201,8 +201,7 @@ def get_notebook_name():
                                     params={'token': ss.get('token', '')})
             for nn in json.loads(response.text):
                 if nn['kernel']['id'] == kernel_id:
-                    relative_path = nn['notebook']['path']
-                    return relative_path
+                    return nn['notebook']['path']
     except:
         return environ.get("CAL_NOTEBOOK_NAME", "Unknown Notebook")
 
@@ -333,26 +332,30 @@ def save_const_to_h5(db_module: str, karabo_id: str,
 
     metadata.calibration_constant_version.raw_data_location = file_loc
 
-    dpar = {}
-    for parm in metadata.detector_condition.parameters:
-        dpar[parm.name] = {'lower_deviation_value': parm.lower_deviation,
-                           'upper_deviation_value': parm.upper_deviation,
-                           'value': parm.value,
-                           'flg_logarithmic': parm.logarithmic}
+    dpar = {
+        parm.name: {
+            'lower_deviation_value': parm.lower_deviation,
+            'upper_deviation_value': parm.upper_deviation,
+            'value': parm.value,
+            'flg_logarithmic': parm.logarithmic,
+        }
+        for parm in metadata.detector_condition.parameters
+    }
 
     creation_time = metadata.calibration_constant_version.begin_at
     raw_data = metadata.calibration_constant_version.raw_data_location
     constant_name = metadata.calibration_constant.__class__.__name__
 
-    data_to_store = {}
-    data_to_store['condition'] = dpar
-    data_to_store['db_module'] = db_module
-    data_to_store['karabo_id'] = karabo_id
-    data_to_store['constant'] = constant_name
-    data_to_store['data'] = data
-    data_to_store['creation_time'] = creation_time
-    data_to_store['file_loc'] = raw_data
-    data_to_store['report'] = report
+    data_to_store = {
+        'condition': dpar,
+        'db_module': db_module,
+        'karabo_id': karabo_id,
+        'constant': constant_name,
+        'data': data,
+        'creation_time': creation_time,
+        'file_loc': raw_data,
+        'report': report,
+    }
 
     ofile = f"{out_folder}/const_{constant_name}_{db_module}.h5"
     if isfile(ofile):
@@ -710,15 +713,12 @@ def get_constant_from_db_and_time(karabo_id: str, karabo_da: str,
                           condition, empty_constant,
                           cal_db_interface, creation_time,
                           int(print_once), timeout, ntries)
-    if m:
-        if m.comm_db_success:
-            return data, m.calibration_constant_version.begin_at
-        else:
-            # retun none for injection time if communication with db failed.
-            # reasons (no constant or condition found,
-            # or network problem)
-            return data, None
+    if m and m.comm_db_success:
+        return data, m.calibration_constant_version.begin_at
     else:
+        # return None for injection time if communication with db failed.
+        # reasons (no constant or condition found,
+        # or network problem)
         return data, None
 
 
diff --git a/src/xfel_calibrate/calibrate.py b/src/xfel_calibrate/calibrate.py
index 92863a1ed..ad73d0aaa 100755
--- a/src/xfel_calibrate/calibrate.py
+++ b/src/xfel_calibrate/calibrate.py
@@ -158,10 +158,7 @@ def consolize_name(name):
 
 def deconsolize_args(args):
     """ Variable names have underscores """
-    new_args = {}
-    for k, v in args.items():
-        new_args[k.replace("-", "_")] = v
-    return new_args
+    return {k.replace("-", "_"): v for k, v in args.items()}
 
 
 def extract_title_author_version(nb):
@@ -312,7 +309,7 @@ def balance_sequences(in_folder: str, run: int, sequences: List[int],
         sequence_files.extend(in_path.glob(f"*{k_da}-S*.h5"))
 
     # Extract sequences from input files.
-    seq_nums = set([int(sf.stem[-5:]) for sf in sequence_files])
+    seq_nums = {int(sf.stem[-5:]) for sf in sequence_files}
 
     # Validate selected sequences with sequences in in_folder
     if sequences != [-1]:
@@ -467,12 +464,8 @@ def add_args_from_nb(nb, parser, cvar=None, no_required=False):
 
         default = p.value if (not required) else None
 
-        if p.type == list or p.name == cvar:
-            if p.type is list:
-                ltype = type(p.value[0])
-            else:
-                ltype = p.type
-
+        if issubclass(p.type, list) or p.name == cvar:
+            ltype = type(p.value[0]) if issubclass(p.type, list) else p.type
             range_allowed = "RANGE ALLOWED" in p.comment.upper() if p.comment else False
             pars_group.add_argument(f"--{consolize_name(p.name)}",
                                     nargs='+',
@@ -481,7 +474,7 @@ def add_args_from_nb(nb, parser, cvar=None, no_required=False):
                                     help=helpstr,
                                     required=required,
                                     action=make_intelli_list(ltype) if range_allowed else None)
-        elif p.type == bool:
+        elif issubclass(p.type, bool):
             # For a boolean, make --XYZ and --no-XYZ options.
             alt_group = pars_group.add_mutually_exclusive_group(required=required)
             alt_group.add_argument(f"--{consolize_name(p.name)}",
@@ -529,14 +522,6 @@ def extend_params(nb, extend_func_name):
     fcc["source"] += "\n" + extension
 
 
-def has_parm(parms, name):
-    """ Check if a parameter of `name` exists in parms """
-    for p in parms:
-        if p.name == name:
-            return True
-    return False
-
-
 def get_par_attr(parms, key, attr, default=None):
     """
     Return the type of parameter with name key
@@ -558,15 +543,14 @@ def flatten_list(l):
     :param l: List or a string
     :return: Same string or string with first and last entry of a list
     """
-    if isinstance(l, list):
-        if len(l) > 1:
-            return '{}-{}'.format(l[0], l[-1])
-        elif len(l) == 1:
-            return '{}'.format(l[0])
-        else:
-            return ''
-    else:
+    if not isinstance(l, list):
         return str(l)
+    if len(l) > 1:
+        return '{}-{}'.format(l[0], l[-1])
+    elif len(l) == 1:
+        return '{}'.format(l[0])
+    else:
+        return ''
 
 
 def set_figure_format(nb, enable_vector_format):
@@ -762,7 +746,7 @@ def concurrent_run(
     # first convert the notebook
     parms = extract_parameters(nb, lang='python')
 
-    if has_parm(parms, "cluster_profile"):
+    if any(p.name == "cluster_profile" for p in parms):
         cluster_profile = f"{args['cluster_profile']}_{suffix}"
     else:
         # Don't start ipcluster if there's no cluster_profile parameter
@@ -844,7 +828,7 @@ def make_par_table(parms, run_tmp_path: str):
         if len(value) > max_len[1]:
             len_parms[1] = max_len[1]
             value = split_len(value, max_len[1])
-        if p.type is str:
+        if issubclass(p.type, str):
             value = "``{}''".format(value)
         comment = tex_escape(str(p.comment)[1:])
         l_parms.append([name, value, comment])
@@ -944,7 +928,7 @@ def run():
     run_uuid = f"t{datetime.now().strftime('%y%m%d_%H%M%S')}"
 
     # check if concurrency parameter is given and we run concurrently
-    if not has_parm(parms, concurrency["parameter"]) and concurrency["parameter"] is not None:
+    if not any(p.name == "parameter" for p in parms) and concurrency["parameter"] is not None:
         msg = "Notebook cannot be run concurrently: no {} parameter".format(
             concurrency["parameter"])
         warnings.warn(msg, RuntimeWarning)
@@ -1047,21 +1031,13 @@ def run():
             if defcval is not None:
                 print(f"Concurrency parameter '{cvar}' "
                       f"is taken from notebooks.py")
-                if not isinstance(defcval, (list, tuple)):
-                    cvals = range(defcval)
-                else:
-                    cvals = defcval
-
+                cvals = defcval if isinstance(defcval, (list, tuple)) else range(defcval)
         if cvals is None:
             defcval = get_par_attr(parms, cvar, 'value')
             if defcval is not None:
                 print(f"Concurrency parameter '{cvar}' "
                       f"is taken from '{notebook}'")
-                if not isinstance(defcval, (list, tuple)):
-                    cvals = [defcval]
-                else:
-                    cvals = defcval
-
+                cvals = defcval if isinstance(defcval, (list, tuple)) else [defcval]
         if con_func:
             func = get_notebook_function(nb, con_func)
             if func is None:
@@ -1073,12 +1049,10 @@ def run():
                 f = df[con_func]
                 import inspect
                 sig = inspect.signature(f)
-                callargs = []
                 if cvals:
                     # in case default needs to be used for function call
                     args[cvar] = cvals
-                for arg in sig.parameters:
-                    callargs.append(args[arg])
+                callargs = [args[arg] for arg in sig.parameters]
                 cvals = f(*callargs)
                 print(f"Split concurrency into {cvals}")
 
@@ -1123,7 +1097,7 @@ def run():
         sequential=sequential,
     ))
 
-    if not all([j is None for j in joblist]):
+    if any(j is not None for j in joblist):
         print("Submitted the following SLURM jobs: {}".format(",".join(joblist)))
 
 
diff --git a/src/xfel_calibrate/finalize.py b/src/xfel_calibrate/finalize.py
index 851296ae8..8da388037 100644
--- a/src/xfel_calibrate/finalize.py
+++ b/src/xfel_calibrate/finalize.py
@@ -397,7 +397,7 @@ def finalize(joblist, finaljob, run_path, out_path, project, calibration,
             for job in joblist:
                 if str(job) in line:
                     found_jobs.add(job)
-        if len(found_jobs) == 0:
+        if not found_jobs:
             break
         sleep(10)
 
diff --git a/tests/test_agipdutils_ff.py b/tests/test_agipdutils_ff.py
index 19947eb35..dbc06d3ad 100644
--- a/tests/test_agipdutils_ff.py
+++ b/tests/test_agipdutils_ff.py
@@ -134,7 +134,7 @@ def test_set_par_limits():
 
     set_par_limits(parameters, peak_range, peak_norm_range, peak_width_range)
     assert parameters.keys() == expected.keys()
-    for key in parameters.keys():
+    for key in parameters:
         if isinstance(parameters[key], np.ndarray):
             assert np.all(parameters[key] == expected[key])
         else:
-- 
GitLab