add more numpy testing funcs
This commit is contained in:
parent
5c91e99602
commit
8ee83085c3
1 changed files with 13 additions and 1 deletions
|
|
@ -316,9 +316,21 @@ class InjectPerfAndLogging(ast.NodeTransformer):
|
||||||
elif isinstance(stmt, ast.Expr) and isinstance(
|
elif isinstance(stmt, ast.Expr) and isinstance(
|
||||||
stmt.value, ast.Call
|
stmt.value, ast.Call
|
||||||
): # handles https://linear.app/codeflash-ai/issue/CF-500/numpy-compatibility#comment-6d956d3e
|
): # handles https://linear.app/codeflash-ai/issue/CF-500/numpy-compatibility#comment-6d956d3e
|
||||||
|
numpy_test_funcs = (
|
||||||
|
"assert_array_almost_equal",
|
||||||
|
"assert_array_equal",
|
||||||
|
"assert_allclose",
|
||||||
|
"assert_almost_equal",
|
||||||
|
"assert_equal",
|
||||||
|
"assert_array_almost_equal_nulp",
|
||||||
|
"assert_array_max_ulp",
|
||||||
|
"assert_array_less",
|
||||||
|
"assert_string_equal",
|
||||||
|
"assert_approx_equal",
|
||||||
|
)
|
||||||
if (
|
if (
|
||||||
isinstance(stmt.value.func, ast.Attribute)
|
isinstance(stmt.value.func, ast.Attribute)
|
||||||
and stmt.value.func.attr in ("assert_array_almost_equal", "assert_array_equal")
|
and stmt.value.func.attr in numpy_test_funcs
|
||||||
and stmt.value.args
|
and stmt.value.args
|
||||||
):
|
):
|
||||||
stmt.value.args[0] = ast.Name(id="codeflash_return_value", ctx=ast.Load())
|
stmt.value.args[0] = ast.Name(id="codeflash_return_value", ctx=ast.Load())
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue