add more numpy testing funcs

This commit is contained in:
Kevin Turcios 2025-01-23 03:28:45 -05:00
parent 5c91e99602
commit 8ee83085c3

View file

@ -316,9 +316,21 @@ class InjectPerfAndLogging(ast.NodeTransformer):
elif isinstance(stmt, ast.Expr) and isinstance( elif isinstance(stmt, ast.Expr) and isinstance(
stmt.value, ast.Call stmt.value, ast.Call
): # handles https://linear.app/codeflash-ai/issue/CF-500/numpy-compatibility#comment-6d956d3e ): # handles https://linear.app/codeflash-ai/issue/CF-500/numpy-compatibility#comment-6d956d3e
numpy_test_funcs = (
"assert_array_almost_equal",
"assert_array_equal",
"assert_allclose",
"assert_almost_equal",
"assert_equal",
"assert_array_almost_equal_nulp",
"assert_array_max_ulp",
"assert_array_less",
"assert_string_equal",
"assert_approx_equal",
)
if ( if (
isinstance(stmt.value.func, ast.Attribute) isinstance(stmt.value.func, ast.Attribute)
and stmt.value.func.attr in ("assert_array_almost_equal", "assert_array_equal") and stmt.value.func.attr in numpy_test_funcs
and stmt.value.args and stmt.value.args
): ):
stmt.value.args[0] = ast.Name(id="codeflash_return_value", ctx=ast.Load()) stmt.value.args[0] = ast.Name(id="codeflash_return_value", ctx=ast.Load())