Skip to content
Snippets Groups Projects
Commit 57619a36 authored by Karim Ahmed's avatar Karim Ahmed
Browse files

fix: Adapt the test script for compound datatypes e.g. Timepix centroids

parent 84cdbcf3
No related branches found
No related tags found
1 merge request!1038[Test] Configure and Include new test runs
...@@ -71,6 +71,56 @@ def iter_sized_chunks(ds: h5py.Dataset, chunk_size: int): ...@@ -71,6 +71,56 @@ def iter_sized_chunks(ds: h5py.Dataset, chunk_size: int):
yield slice(start, start + chunk_l) yield slice(start, start + chunk_l)
def equal_data(ref, test, dtype):
    """Compare two arrays for equality based on their dtype.

    For floating-point and complex dtypes, NaN values in matching
    positions are treated as equal (exact comparison would report
    NaN != NaN even for identical data).

    Parameters
    ----------
    ref, test : array_like
        The reference and test data to compare.
    dtype : numpy dtype
        The dtype used to decide which comparison strategy applies.

    Returns
    -------
    bool
        True if the arrays are considered equal.
    """
    if (
        np.issubdtype(dtype, np.floating) or
        np.issubdtype(dtype, np.complexfloating)
    ):
        # equal_nan=True so identical NaN patterns compare as equal.
        return np.allclose(ref, test, equal_nan=True)
    else:
        return np.array_equal(ref, test)
def handle_differences(ref_chunk, out_chunk, ref_ds, out_ds):
    """Describe differences between a reference chunk and an output chunk.

    Handles both plain and compound (structured) dtypes. For compound
    data each field is compared separately so the report can name the
    fields that differ.

    Parameters
    ----------
    ref_chunk, out_chunk : ndarray
        The slices of data read from the reference and output datasets.
    ref_ds, out_ds : h5py.Dataset
        The datasets the chunks were read from; used for the dataset
        name, dtype and size.

    Returns
    -------
    list of (str, str)
        One ``(dataset name, description)`` tuple per reported change;
        empty if the chunks match.
    """
    dsname = ref_ds.name
    single_entry = ref_ds.size == 1

    if out_ds.dtype.names:
        # Compound dataset: compare field by field.
        differing = [
            f for f in out_ds.dtype.names
            if not equal_data(ref_chunk[f], out_chunk[f], ref_chunk[f].dtype)
        ]
        if not differing:
            return []
        if single_entry:
            # If just 1 entry, show the values for all differing fields
            return [
                (dsname, f"Field '{field}' Value: "
                         f"{np.squeeze(ref_chunk[field])} -> "
                         f"{np.squeeze(out_chunk[field])}")
                for field in differing
            ]
        return [(dsname, f"Fields changed: {', '.join(differing)}")]

    # Plain dtype: compare the whole chunk at once.
    if equal_data(ref_chunk, out_chunk, ref_ds.dtype):
        return []
    if single_entry:
        # If just 1 entry, show the values
        r, o = np.squeeze(ref_chunk), np.squeeze(out_chunk)
        return [(ref_ds.name, f"Value: {r} -> {o}")]
    return [(ref_ds.name, "Data changed")]
def validate_file( def validate_file(
ref_folder: pathlib.PosixPath, ref_folder: pathlib.PosixPath,
out_folder: pathlib.PosixPath, out_folder: pathlib.PosixPath,
...@@ -95,25 +145,15 @@ def validate_file( ...@@ -95,25 +145,15 @@ def validate_file(
dsname, f"Dtype: {ref_ds.dtype} -> {out_ds.dtype}" dsname, f"Dtype: {ref_ds.dtype} -> {out_ds.dtype}"
)) ))
else: else:
floaty = np.issubdtype(ref_ds.dtype, np.floating) \
or np.issubdtype(ref_ds.dtype, np.complexfloating)
# Compare data incrementally rather than loading it all at once; # Compare data incrementally rather than loading it all at once;
# read in blocks of ~64 MB (arbitrary limit) along first axis. # read in blocks of ~64 MB (arbitrary limit) along first axis.
for chunk_slice in iter_sized_chunks(ref_ds, 64 * 1024 * 1024): for chunk_slice in iter_sized_chunks(ref_ds, 64 * 1024 * 1024):
ref_chunk = ref_ds[chunk_slice] ref_chunk = ref_ds[chunk_slice]
out_chunk = out_ds[chunk_slice] out_chunk = out_ds[chunk_slice]
if floaty: differences = handle_differences(
eq = np.allclose(ref_chunk, out_chunk, equal_nan=True) ref_chunk, out_chunk, ref_ds, out_ds)
else: changed += differences
eq = np.array_equal(ref_chunk, out_chunk) if differences:
if not eq:
# If just 1 entry, show the values
if ref_ds.size == 1:
r, o = np.squeeze(ref_chunk), np.squeeze(out_chunk)
changed.append((dsname, f"Value: {r} -> {o}"))
else:
changed.append((dsname, "Data changed"))
break break
return ComparisonResult( return ComparisonResult(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment