
[Tests] clearer comparison of HDF5 files

Merged Thomas Kluyver requested to merge test/compare-h5-files into master
1 file  +22  −15
-import hashlib
 import io
 import logging
 import multiprocessing
 import pathlib
-import tempfile
 import time
 from contextlib import redirect_stdout
+from dataclasses import dataclass
 from datetime import datetime
 from functools import partial
 from subprocess import PIPE, run
@@ -22,253 +21,122 @@ from .callab_tests import automated_test_config
LOGGER = logging.getLogger(__name__)
-def file_md5(
-    tested_file: str,
-    block_size: int = 2 ** 20,
-) -> bytes:
-    """Generating MD5 checksum for a file.
-
-    Args:
-        tested_file: File to be tested.
-        block_size (_type_, optional): Block size for reading the file.
-          Defaults to 2**20.
-    """
-    f = open(tested_file, "rb")
-    md5 = hashlib.md5()
-    while True:
-        data = f.read(block_size)
-        if not data:
-            break
-        md5.update(data)
-    f.close()
-    return md5.digest()
-
-
-def collect_attrs(groups, datasets, objects, exclude_attrs, name, node):
-    """Collect h5 attrs in groups, datasets, and objects lists."""
-    if node.name not in exclude_attrs:
-        if isinstance(node, h5py.Group):
-            groups.append(name)
-        elif isinstance(node, h5py.Dataset):
-            if node.dtype == 'object':
-                objects.append(name)
-            else:
-                datasets.append(name)
-
-
-def compare_datasets(
-    file1,
-    file2,
-    datasets: list
-):
-    """Compare the values of datasets in two h5 files."""
-    h5_diff = []
-    for d in datasets:
-        try:
-            if not np.allclose(file1[d][()], file2[d][()], equal_nan=True):
-                h5_diff.append(d)
-        except ValueError as e:
-            LOGGER.error(f"ValueError: {e}, {d}")
-            h5_diff.append(d)
-        except AttributeError as e:
-            LOGGER.error(f"AttributeError: {e}, {d}")
-            h5_diff.append(d)
-    return h5_diff
-
-
-def compare_objects(
-    file1,
-    file2,
-    objects: list
-):
-    """Compare the objects in two h5 files."""
-    h5_diff = []
-    for d in objects:
-        try:
-            if isinstance(file1[d][()], bytes):
-                if (
-                    file1[d][()].decode('utf-8') != file2[d][()].decode('utf-8')  # noqa
-                ):
-                    h5_diff.append(d)
-            elif (
-                file1[d][()].dtype != file1[d][()].dtype and
-                not file1[d][()] != file2[d][()]
-            ):  # pnccd files has only list of bytes
-                h5_diff.append(d)
-        except ValueError as e:
-            LOGGER.error(f"ValueError: {e}, {d}")
-            h5_diff.append(d)
-        except AttributeError as e:
-            LOGGER.error(f"AttributeError: {e}, {d}, "
-                         f"{file1[d][()].decode('utf-8')}")
-            h5_diff.append(d)
-    return h5_diff
-
-
-def find_differences(
-    test_file,
-    reference_file,
-    exclude_attrs,
-):
-    """
-    Find difference in groups, datasets, and objects between two h5files.
-
-    Args:
-        file1: first h5 file.
-        file2: second h5 file.
-    """
-    groups_f1 = []
-    datasets_f1 = []
-    objects_f1 = []
-
-    groups_f2 = []
-    datasets_f2 = []
-    objects_f2 = []
-
-    with h5py.File(test_file, 'r') as file1, h5py.File(reference_file, 'r') as file2:  # noqa
-        # Fill groups, datasets, and objects list
-        # to compare both h5files' attrs.
-        file1.visititems(
-            partial(
-                collect_attrs,
-                groups_f1,
-                datasets_f1,
-                objects_f1,
-                exclude_attrs,
-            ))
-        file2.visititems(
-            partial(
-                collect_attrs,
-                groups_f2,
-                datasets_f2,
-                objects_f2,
-                exclude_attrs,
-            ))
-
-        start_time = time.perf_counter()
-        # Compare groups, datasets, and objects to have the same content.
-        assert set(groups_f1) == set(groups_f2), f"{test_file} and {reference_file} consist of different groups."  # noqa
-        assert set(datasets_f1) == set(datasets_f2), f"{test_file} and {reference_file} consist of different datasets."  # noqa
-        assert set(objects_f1) == set(objects_f2), f"{test_file} and {reference_file} consist of different objects."  # noqa
-        duration = time.perf_counter() - start_time
-        LOGGER.debug("Elapsed time comparing groups, "
-                     f"datasets, and objects: {duration} seconds")
-        LOGGER.debug("Groups, datasets, and objects have the same content.")
-
-        # Compare datasets and objects.
-        start_time = time.perf_counter()
-        h5_diff_datasets = compare_datasets(file1, file2, datasets_f1)
-        duration = time.perf_counter() - start_time
-        LOGGER.debug(f"Elapsed time comparing datasets: {duration} seconds")
-        start_time = time.perf_counter()
-        h5_diff_objects = compare_objects(file1, file2, objects_f1)
-        LOGGER.debug(f"Elapsed time comparing objects: {duration} seconds")
-
-        assert not h5_diff_datasets, f"{[d for d in h5_diff_datasets]} datasets contain different values for {test_file} and {reference_file}"  # noqa
-        LOGGER.debug("Datasets are validated.")
-        assert not h5_diff_objects, f"{[d for d in h5_diff_objects]} objects contain different values for {test_file} and {reference_file}"  # noqa
-        LOGGER.debug("Objects are validated.")
-
-
-def validate_files(
-    ref_folder: pathlib.PosixPath,
-    out_folder: pathlib.PosixPath,
-    exclude_attrs: list,
-    test_file: pathlib.PosixPath,
-) -> Tuple[bool, pathlib.PosixPath]:
-    """Validate file similarities. Create temporary files to exclude
-    h5 attributes known to be different, e.g. `report` for constants.
-    If both files are not identical, the function is able to loop over
-    both files and find and fail on the difference.
-
-    Args:
-        ref_folder: The reference folder for validating the files.
-        out_folder: The output folder for the test constant files.
-        test_file: The output file to be validated.
-        exclude_attrs: A list of datasets and groups to exclude
-          from validated files.
-
-    Returns:
-        result: validation result for metadata.
-        test_file: The validated file.
-    """
-    import h5py
-    start_validating = time.perf_counter()
-
-    def exclude_sources(source_file, dest, excluded_sources):
-        # Open the source file in read-only mode
-        with h5py.File(source_file, 'r') as source:
-
-            # Recursively visit all objects in the source file
-            def visit_func(name, obj):
-                # Check if the object should be excluded
-                if name in excluded_sources:
-                    return
-
-                # Check if the object is a dataset
-                if isinstance(obj, h5py.Dataset):
-                    # Create a new dataset in the destination
-                    # file and copy the data
-                    dest.create_dataset(name, data=obj[()])
-
-            # Visit all objects in the source file and
-            # copy them to the destination file
-            source.visititems(visit_func)
-
-    with tempfile.NamedTemporaryFile(
-        dir=out_folder,
-        suffix=".tmp",
-        prefix="cal_",
-        delete=True,
-    ) as out_tf, tempfile.NamedTemporaryFile(
-        dir=out_folder,
-        suffix=".tmp",
-        prefix="cal_",
-        delete=True,
-    ) as ref_tf:
-        # Create temporary HDF5 files for validation
-        with h5py.File(out_tf.name, 'a') as hp1, h5py.File(ref_tf.name, 'a') as hp2:  # noqa
-            start_time = time.perf_counter()
-            # Copy h5 files for validation and exclude selected attrs.
-            exclude_sources(test_file, hp1, exclude_attrs)
-            duration = time.perf_counter() - start_time
-            LOGGER.debug(f"Elapsed time copying {test_file}: "
-                         f"{duration} seconds")
-            start_time = time.perf_counter()
-            exclude_sources(ref_folder / test_file.name, hp2, exclude_attrs)
-            duration = time.perf_counter() - start_time
-            LOGGER.debug(f"Elapsed time copying {ref_folder / test_file.name}: "
-                         f"{duration} seconds")
-            start_time = time.perf_counter()
-            result = file_md5(out_tf.name) == file_md5(ref_tf.name)
-            LOGGER.debug(f"MD5 validation for {test_file}: {duration} seconds")
-    duration = time.perf_counter() - start_validating
-    return result, test_file
+@dataclass
+class ComparisonResult:
+    filename: str
+    new_dsets: list
+    missing_dsets: list
+    changed_dsets: list
+
+    def found_differences(self):
+        return bool(self.new_dsets or self.missing_dsets or self.changed_dsets)
+
+    def show(self):
+        if not self.found_differences():
+            print(f"{self.filename} - ✓ no changes")
+            return
+
+        print(self.filename)
+        for ds in self.new_dsets:
+            print(f" + NEW: {ds}")
+        for ds in self.missing_dsets:
+            print(f" - MISSING: {ds}")
+        for ds, detail in self.changed_dsets:
+            print(f" ~ CHANGED: {ds} ({detail})")
+
+
+def gather_dsets(f: h5py.File):
+    res = set()
+
+    def visitor(name, obj):
+        if isinstance(obj, h5py.Dataset):
+            res.add(name)
+
+    f.visititems(visitor)
+    return res
+
+
+def iter_sized_chunks(ds: h5py.Dataset, chunk_size: int):
+    """Make slices of the dataset along the first axis
+
+    Aims for chunk_size bytes per block"""
+    if ds.ndim == 0:  # Scalar
+        yield ()
+        return
+
+    # At least one row per block: max(), not min() - min() would cap the
+    # block at a single row, or 0 rows (breaking range()) for rows wider
+    # than chunk_size; int() keeps the step integral for 1D datasets.
+    chunk_l = max(chunk_size // (ds.dtype.itemsize * int(np.prod(ds.shape[1:]))), 1)
+    for start in range(0, ds.shape[0], chunk_l):
+        yield slice(start, start + chunk_l)
+
+
+def validate_file(
+    ref_folder: pathlib.PosixPath,
+    out_folder: pathlib.PosixPath,
+    exclude_dsets: set,
+    test_file: str,
+) -> ComparisonResult:
+    ref_file = ref_folder / test_file
+    out_file = out_folder / test_file
+    with h5py.File(ref_file) as fref, h5py.File(out_file) as fout:
+        ref_dsets = gather_dsets(fref)
+        out_dsets = gather_dsets(fout)
+        changed = []
+        for dsname in sorted((ref_dsets & out_dsets) - exclude_dsets):
+            ref_ds = fref[dsname]
+            out_ds = fout[dsname]
+            if out_ds.shape != ref_ds.shape:
+                changed.append((
+                    dsname, f"Shape: {ref_ds.shape} -> {out_ds.shape}"
+                ))
+            elif out_ds.dtype != ref_ds.dtype:
+                changed.append((
+                    dsname, f"Dtype: {ref_ds.dtype} -> {out_ds.dtype}"
+                ))
+            else:
+                floaty = np.issubdtype(ref_ds.dtype, np.floating) \
+                    or np.issubdtype(ref_ds.dtype, np.complexfloating)
+                # Compare data incrementally rather than loading it all at once;
+                # read in blocks of ~64 MB (arbitrary limit) along first axis.
+                for chunk_slice in iter_sized_chunks(ref_ds, 64 * 1024 * 1024):
+                    ref_chunk = ref_ds[chunk_slice]
+                    out_chunk = out_ds[chunk_slice]
+                    if floaty:
+                        eq = np.allclose(ref_chunk, out_chunk, equal_nan=True)
+                    else:
+                        eq = np.array_equal(ref_chunk, out_chunk)
+                    if not eq:
+                        # If just 1 entry, show the values
+                        if ref_ds.size == 1:
+                            r, o = np.squeeze(ref_chunk), np.squeeze(out_chunk)
+                            changed.append((dsname, f"Value: {r} -> {o}"))
+                        else:
+                            changed.append((dsname, "Data changed"))
+                        break
+
+    return ComparisonResult(
+        test_file,
+        new_dsets=sorted(out_dsets - ref_dsets),
+        missing_dsets=sorted(ref_dsets - out_dsets),
+        changed_dsets=changed,
+    )
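To see what the new comparison reports, here is a minimal sketch (not part of this MR; all paths, file and dataset names are invented) that assumes the ComparisonResult and validate_file definitions above are importable:

# Demo of the new comparison - illustrative only.
import pathlib
import tempfile

import h5py
import numpy as np

base = pathlib.Path(tempfile.mkdtemp())
(base / "ref").mkdir()
(base / "out").mkdir()

with h5py.File(base / "ref" / "corr.h5", "w") as f:
    f["data"] = np.arange(10.0)      # identical in both files
    f["scalar"] = np.float64(1.5)
    f["only_ref"] = np.arange(3)

with h5py.File(base / "out" / "corr.h5", "w") as f:
    f["data"] = np.arange(10.0)
    f["scalar"] = np.float64(2.5)    # value differs -> "~ CHANGED"
    f["only_out"] = np.arange(4)     # extra dataset -> "+ NEW"

res = validate_file(base / "ref", base / "out", set(), "corr.h5")
res.show()
# Prints a report along these lines:
#   corr.h5
#    + NEW: only_out
#    - MISSING: only_ref
#    ~ CHANGED: scalar (Value: 1.5 -> 2.5)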
 def parse_config(
-    cmd: List[str],
-    config: Dict[str, Any],
-    out_folder: str
+    cmd: List[str], config: Dict[str, Any], out_folder: str
 ) -> List[str]:
    """Convert a dictionary to a list of arguments.

    Values that are not strings will be cast.
    Lists will be converted to several strings following their `--key`
    flag.
    Booleans will be converted to a `--key` flag, where `key` is the
    dictionary key.
    """
    for key, value in config.items():
-        if ' ' in key or (isinstance(value, str) and ' ' in value):
-            raise ValueError('Spaces are not allowed', key, value)
+        if " " in key or (isinstance(value, str) and " " in value):
+            raise ValueError("Spaces are not allowed", key, value)
        if isinstance(value, list):
            cmd.append(f"--{key}")
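From the docstring, a config dict expands to CLI arguments roughly as follows. The handling of out_folder lives in the part of the body this hunk elides, so the result shown is only the expected shape, with made-up values:

# Illustrative only - parse_config per its docstring.
config = {"karabo-id": "TEST_DET", "modules": [0, 1], "no-plots": True}
cmd = parse_config(["xfel-calibrate", "TEST", "CORRECT"], config, "/tmp/out")
# Expected shape of the result:
#   ["xfel-calibrate", "TEST", "CORRECT",
#    "--karabo-id", "TEST_DET",
#    "--modules", "0", "1",     # list -> strings after the --key flag
#    "--no-plots"]              # True -> bare --key flag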
@@ -286,11 +154,9 @@ def parse_config(
 def validate_hdf5_files(
-    test_key: str,
     out_folder: pathlib.Path,
     reference_folder: pathlib.Path,
     cal_type: str,
-    find_difference: bool
 ):
     """Apply HDF5 data validation.
@@ -301,59 +167,48 @@ def validate_hdf5_files(
          the reference data to validate against
        cal_type (str): The type of calibration processing.
          e.g. dark or correct.
-        find_difference (bool): A flag indicating a need to find the
-          difference between two files if tested data was
-          not identical to the reference data.
    """
-    # 3rd Check number of produced h5 files.
-    h5files = list(out_folder.glob("*.h5"))
-    expected_h5files = list(reference_folder.glob("*.h5"))
-    assert len(h5files) == len(expected_h5files), f"{test_key} failure, number of files are not as expected."  # noqa
-    LOGGER.info(f"{test_key}'s calibration h5files numbers are as expected.")
print("\n--- Compare HDF5 files ----")
print("REF:", reference_folder)
print("NEW:", out_folder)
ok = True
result_h5files = {p.name for p in out_folder.glob("*.h5")}
ref_h5files = {p.name for p in reference_folder.glob("*.h5")}
missing_files = ref_h5files - result_h5files
if missing_files:
print("Files missing from result (*.h5):", ", ".join(missing_files))
ok = False
new_files = result_h5files - ref_h5files
if new_files:
print("New files in result (*.h5):", ", ".join(new_files))
ok = False
files_to_check = sorted(result_h5files & ref_h5files)
-    non_valid_files = []
    # Hard coded datasets to exclude from numerical validation.
    # These datasets are known to be updated every time.
    if cal_type.lower() == "correct":
-        exclude_attrs = ["METADATA/creationDate", "METADATA/updateDate"]
+        exclude_attrs = {"METADATA/creationDate", "METADATA/updateDate"}
    else:
-        exclude_attrs = ["report"]
+        exclude_attrs = {"report"}
-    # 4th check that test and reference h5files are identical.
-    _validate_files = partial(
-        validate_files,
+    _validate_file = partial(
+        validate_file,
        reference_folder,
        out_folder,
        exclude_attrs,
    )
-    with multiprocessing.pool.ThreadPool(processes=8) as executor:
-        result = executor.map(_validate_files, h5files)
-    # Collect non-valid files, if any, to display them in the error message.
-    for valid, file in result:
-        if not valid:
-            non_valid_files.append(file)
-    if len(non_valid_files) > 0:
-        if find_difference:
-            LOGGER.error(f"Found non valid files: {non_valid_files}. "
-                         f"Checking differences for {non_valid_files[0]}")
-            find_differences(
-                non_valid_files[0],
-                reference_folder / non_valid_files[0].name,
-                exclude_attrs
-            )
-            LOGGER.info(f"No difference found for {non_valid_files[0]}")
-        else:
-            assert len(non_valid_files) == 0, f"{test_key} failure, while validating metadata for {non_valid_files}"  # noqa
-    LOGGER.info(f"{test_key}'s calibration h5files"
-                " are validated successfully.")
+    with multiprocessing.Pool(processes=8) as pool:
+        for comparison in pool.imap(_validate_file, files_to_check):
+            comparison.show()
+            if comparison.found_differences():
+                ok = False
+    return ok
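With the new boolean return value, callers assert on the overall result instead of passing a find_difference flag; a minimal hypothetical driver (paths are placeholders):

# Hypothetical use of the new validation entry point.
import pathlib

ok = validate_hdf5_files(
    pathlib.Path("/tmp/test_output"),   # out_folder
    pathlib.Path("/tmp/reference"),     # reference_folder
    "correct",                          # cal_type
)
assert ok, "HDF5 files changed - see details above"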
-def slurm_watcher(
-    test_key: str,
-    std_out: str
-):
+def slurm_watcher(test_key: str, std_out: str):
    """
    Watch for submitted slurm jobs and wait for them to finish.
    After they finish apply first test and check
@@ -380,19 +235,25 @@ def slurm_watcher(
        res = run(cmd, stdout=PIPE)
        states = res.stdout.decode().split("\n")[2:-1]
-        if not any(s.strip() in [
-            "COMPLETING",
-            "RUNNING",
-            "CONFIGURING",
-            "PENDING",
-        ] for s in states):
+        if not any(
+            s.strip()
+            in [
+                "COMPLETING",
+                "RUNNING",
+                "CONFIGURING",
+                "PENDING",
+            ]
+            for s in states
+        ):
            slurm_watcher = False
        else:
            time.sleep(2)
    # 1st check that all jobs were COMPLETED without errors.
    states = res.stdout.decode().split("\n")[2:-1]
-    assert all(s.strip() == "COMPLETED" for s in states), f"{test_key} failure, calibration jobs were not completed. {jobids}: {states}"  # noqa
+    assert all(
+        s.strip() == "COMPLETED" for s in states
+    ), f"{test_key} failure, calibration jobs were not completed. {jobids}: {states}"  # noqa
    LOGGER.info(f"{test_key}'s jobs were COMPLETED")
@@ -403,10 +264,9 @@ def slurm_watcher(
    ids=list(automated_test_config.keys()),
)
def test_xfel_calibrate(
-    test_key: str, val_dict: dict,
-    release_test_config: Tuple[bool, bool, bool, bool]
+    test_key: str, val_dict: dict, release_test_config: Tuple
):
-    """ Test xfel calibrate detectors and calibrations written
+    """Test xfel calibrate detectors and calibrations written
    in the given callab_test YAML file.

    Args:
        test_key : Key for the xfel-calibrate test.
@@ -416,9 +276,14 @@ def test_xfel_calibrate(
"""
(
detectors, calibration, picked_test,
skip_numerical_validation, only_validate, find_difference,
use_slurm, reference_dir_base, out_dir_base,
detectors,
calibration,
picked_test,
skip_numerical_validation,
only_validate,
use_slurm,
reference_dir_base,
out_dir_base,
) = release_test_config
cal_type = val_dict["cal_type"]
@@ -426,10 +291,9 @@ def test_xfel_calibrate(
    if not picked_test:
        # Skip non-selected detectors
-        if (
-            detectors != ["all"] and
-            det_type.lower() not in [d.lower() for d in detectors]
-        ):
+        if detectors != ["all"] and det_type.lower() not in [
+            d.lower() for d in detectors
+        ]:
            pytest.skip()

        # Skip non-selected calibration
@@ -444,32 +308,35 @@ def test_xfel_calibrate(
    cal_conf = val_dict["config"]

    out_folder = pathlib.Path(cal_conf["out-folder"].format(
-        out_dir_base, cal_conf["karabo-id"], test_key))
-    reference_folder = pathlib.Path(val_dict["reference-folder"].format(
-        reference_dir_base, cal_conf["karabo-id"], test_key))
+        out_dir_base, cal_conf["karabo-id"], test_key
+    ))
+    reference_folder = pathlib.Path(
+        val_dict["reference-folder"].format(
+            reference_dir_base, cal_conf["karabo-id"], test_key
+        )
+    )

-    report_name = (
-        out_folder /
-        f"{test_key}_{datetime.now().strftime('%y%m%d_%H%M%S')}")
+    report_name = out_folder / f"{test_key}_{datetime.now():%y%m%d_%H%M%S}"
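    # Note: the two report_name spellings are equivalent - an f-string
    # format spec calls datetime.__format__, so
    #   f"{datetime(2024, 1, 31, 13, 5, 9):%y%m%d_%H%M%S}"
    # and
    #   datetime(2024, 1, 31, 13, 5, 9).strftime("%y%m%d_%H%M%S")
    # both give "240131_130509".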
cal_conf["report-to"] = str(report_name)
cmd = parse_config(cmd, cal_conf, out_folder)
if only_validate:
-        validate_hdf5_files(
-            test_key,
-            out_folder,
-            reference_folder,
-            cal_type,
-            find_difference,
-        )
+        assert validate_hdf5_files(
+            out_folder, reference_folder, cal_type
+        ), "HDF5 files changed - see details above"
        return

    if not use_slurm:  # e.g. for Gitlab CI.
        cmd += ["--no-cluster-job"]
cmd += ["--slurm-name", test_key, "--cal-db-interface", "tcp://max-exfl-cal001:8015#8045"]
cmd += [
"--slurm-name",
test_key,
"--cal-db-interface",
"tcp://max-exfl-cal001:8015#8045",
]
f = io.StringIO()
LOGGER.info(f"Submitting CL: {cmd}")
with redirect_stdout(f):
@@ -495,10 +362,6 @@ def test_xfel_calibrate(
    # Stop tests at this point, if desired.
    if not skip_numerical_validation:
-        validate_hdf5_files(
-            test_key,
-            out_folder,
-            reference_folder,
-            cal_type,
-            find_difference,
-        )
+        assert validate_hdf5_files(
+            out_folder, reference_folder, cal_type
+        ), "HDF5 files changed - see details above"