def recursive_update(target: dict, source: dict):
    """Merge source into target in place, descending into nested dicts

    A key is copied from source when target has no value for it (missing,
    or stored as None) or already holds an equal value. When both sides
    hold dicts, those are merged recursively. Any other mismatch is a
    conflict: the entry in target is left as it is.

    Returns True if at least one conflict was found, otherwise False.
    """
    found_conflict = False
    for key, incoming in source.items():
        existing = target.get(key)

        # Two nested mappings: merge them rather than comparing wholesale.
        if isinstance(existing, dict) and isinstance(incoming, dict):
            if recursive_update(existing, incoming):
                found_conflict = True
            continue

        # A value is already present and disagrees: keep target's value.
        if existing is not None and existing != incoming:
            found_conflict = True
            continue

        target[key] = incoming

    return found_conflict
def gather_fragments(self):
    """Merge in fragments saved by add_fragment(), then delete them

    Every 'metadata_frag_*.yml' file in the directory of the main metadata
    file is loaded and merged into this object with recursive_update().
    self.save() is called before any fragment is removed, so an error while
    saving leaves all the fragment files in place.

    Fragments which conflict with the metadata merged so far, or which do
    not contain a mapping (e.g. an empty file), are reported with print()
    and left on disk for debugging.
    """
    # Sorted so that the merge order - and therefore which value is kept
    # when two fragments conflict - does not depend on the order in which
    # the directory happens to be listed.
    frag_files = sorted(self._yaml_fn.parent.glob('metadata_frag_*.yml'))
    to_delete = []
    for fn in frag_files:
        with fn.open("r") as fd:
            data = yaml.safe_load(fd)

        # safe_load() gives None for an empty file, and recursive_update()
        # calls .items() on its source, so only mappings can be merged.
        if not isinstance(data, dict):
            print(f"{fn} did not contain a metadata mapping. "
                  f"This file will be left for debugging")
            continue

        if recursive_update(self, data):
            print(f"{fn} contained conflicting metadata. "
                  f"This file will be left for debugging")
        else:
            to_delete.append(fn)

    self.save()

    # Only remove fragments once their content is in the saved main file.
    for fn in to_delete:
        fn.unlink()
def test_recursive_update():
    """recursive_update merges nested dicts and flags conflicting values"""
    # Keys are disjoint at every level: everything is merged, no conflict.
    target = {"a": {"b": 1}, "c": 2}
    had_conflict = recursive_update(target, {"a": {"d": 3}, "e": 4})
    assert had_conflict is False
    assert target == {"a": {"b": 1, "d": 3}, "c": 2, "e": 4}

    # 'a.b' disagrees: it keeps its old value, while 'e' is still merged.
    target = {"a": {"b": 1}, "c": 2}
    had_conflict = recursive_update(target, {"a": {"b": 3}, "e": 4})
    assert had_conflict is True
    assert target == {"a": {"b": 1}, "c": 2, "e": 4}