diff --git a/tests/test_reference_runs/test_pre_deployment.py b/tests/test_reference_runs/test_pre_deployment.py
index dcda91b36b1c997e0c775e01d56d9180edd1ce23..41e02403ca79dcea7a5c81d47077bcbdf5686a7a 100644
--- a/tests/test_reference_runs/test_pre_deployment.py
+++ b/tests/test_reference_runs/test_pre_deployment.py
@@ -71,6 +71,63 @@ def iter_sized_chunks(ds: h5py.Dataset, chunk_size: int):
         yield slice(start, start + chunk_l)
 
 
+def equal_data(ref, test, dtype):
+    """Return whether *ref* and *test* contain equal data.
+
+    Float/complex data (per *dtype*) is compared with np.allclose so
+    that NaNs in matching positions compare equal; any other dtype
+    must match exactly (np.array_equal).
+    """
+    if (
+        np.issubdtype(dtype, np.floating)
+        or np.issubdtype(dtype, np.complexfloating)
+    ):
+        return np.allclose(ref, test, equal_nan=True)
+    return np.array_equal(ref, test)
+
+
+def handle_differences(ref_chunk, out_chunk, ref_ds, out_ds):
+    """Describe differences between *ref_chunk* and *out_chunk*.
+
+    Returns a list of (dataset name, description) tuples — empty when
+    the chunks match.  Compound dtypes are compared field by field.
+    """
+    changes = []
+    dsname = ref_ds.name
+    if out_ds.dtype.names:
+        # Handle compound datasets
+        field_differences = []
+        for field in out_ds.dtype.names:
+            ref_field_chunk = ref_chunk[field]
+            out_field_chunk = out_chunk[field]
+            if not equal_data(
+                    ref_field_chunk, out_field_chunk, ref_field_chunk.dtype):
+                field_differences.append(field)
+
+        if field_differences:
+            if ref_ds.size == 1:
+                # If just 1 entry, show the values for all differing fields
+                for field in field_differences:
+                    r, o = np.squeeze(
+                        ref_chunk[field]), np.squeeze(out_chunk[field])
+                    changes.append((
+                        dsname, f"Field '{field}' Value: {r} -> {o}"))
+            else:
+                changes.append((
+                    dsname, f"Fields changed: {', '.join(field_differences)}"))
+
+    else:
+        if not equal_data(ref_chunk, out_chunk, ref_ds.dtype):
+            # If just 1 entry, show the values
+            if ref_ds.size == 1:
+                r, o = np.squeeze(ref_chunk), np.squeeze(out_chunk)
+                changes.append((ref_ds.name, f"Value: {r} -> {o}"))
+            else:
+                changes.append((ref_ds.name, "Data changed"))
+
+    return changes
+
+
 def validate_file(
     ref_folder: pathlib.PosixPath,
     out_folder: pathlib.PosixPath,
@@ -95,25 +152,15 @@ def validate_file(
                     dsname, f"Dtype: {ref_ds.dtype} -> {out_ds.dtype}"
                 ))
             else:
-                floaty = np.issubdtype(ref_ds.dtype, np.floating) \
-                    or np.issubdtype(ref_ds.dtype, np.complexfloating)
-
                 # Compare data incrementally rather than loading it all at once;
                 # read in blocks of ~64 MB (arbitrary limit) along first axis.
                 for chunk_slice in iter_sized_chunks(ref_ds, 64 * 1024 * 1024):
                     ref_chunk = ref_ds[chunk_slice]
                     out_chunk = out_ds[chunk_slice]
-                    if floaty:
-                        eq = np.allclose(ref_chunk, out_chunk, equal_nan=True)
-                    else:
-                        eq = np.array_equal(ref_chunk, out_chunk)
-
-                    if not eq:
-                        # If just 1 entry, show the values
-                        if ref_ds.size == 1:
-                            r, o = np.squeeze(ref_chunk), np.squeeze(out_chunk)
-                            changed.append((dsname, f"Value: {r} -> {o}"))
-                        else:
-                            changed.append((dsname, "Data changed"))
+                    differences = handle_differences(
+                        ref_chunk, out_chunk, ref_ds, out_ds)
+                    changed += differences
+                    if differences:
+                        break
 
     return ComparisonResult(