diff --git a/tests/test_reference_runs/test_pre_deployment.py b/tests/test_reference_runs/test_pre_deployment.py index 97c768ef3b3406a2165af8709a25b4f7c5c50662..ee4bda713c86426f0ddd0d38b7ebf6b4b5c43091 100644 --- a/tests/test_reference_runs/test_pre_deployment.py +++ b/tests/test_reference_runs/test_pre_deployment.py @@ -75,7 +75,9 @@ def validate_file( eq = False # One is an array, the other not elif isinstance(ref_arr, np.ndarray): # Both arrays - eq = np.array_equal(ref_arr, out_arr, equal_nan=True) + nanable = np.issubdtype(ref_arr.dtype, np.floating) \ + or np.issubdtype(ref_arr.dtype, np.complexfloating) + eq = np.array_equal(ref_arr, out_arr, equal_nan=nanable) else: # Both single values eq = ref_arr == out_arr