fix overlappings args in codeflash wrap

This commit is contained in:
aseembits93 2025-09-16 15:29:08 -07:00
parent 9ac5d34df6
commit 1a5f1034eb
3 changed files with 60 additions and 54 deletions

View file

@ -365,15 +365,15 @@ def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR) -> ast.Fun
targets=[ast.Name(id="test_id", ctx=ast.Store())],
value=ast.JoinedStr(
values=[
ast.FormattedValue(value=ast.Name(id="test_module_name", ctx=ast.Load()), conversion=-1),
ast.FormattedValue(value=ast.Name(id="codeflash_test_module_name", ctx=ast.Load()), conversion=-1),
ast.Constant(value=":"),
ast.FormattedValue(value=ast.Name(id="test_class_name", ctx=ast.Load()), conversion=-1),
ast.FormattedValue(value=ast.Name(id="codeflash_test_class_name", ctx=ast.Load()), conversion=-1),
ast.Constant(value=":"),
ast.FormattedValue(value=ast.Name(id="test_name", ctx=ast.Load()), conversion=-1),
ast.FormattedValue(value=ast.Name(id="codeflash_test_name", ctx=ast.Load()), conversion=-1),
ast.Constant(value=":"),
ast.FormattedValue(value=ast.Name(id="line_id", ctx=ast.Load()), conversion=-1),
ast.FormattedValue(value=ast.Name(id="codeflash_line_id", ctx=ast.Load()), conversion=-1),
ast.Constant(value=":"),
ast.FormattedValue(value=ast.Name(id="loop_index", ctx=ast.Load()), conversion=-1),
ast.FormattedValue(value=ast.Name(id="codeflash_loop_index", ctx=ast.Load()), conversion=-1),
]
),
lineno=lineno + 1,
@ -453,7 +453,7 @@ def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR) -> ast.Fun
targets=[ast.Name(id="invocation_id", ctx=ast.Store())],
value=ast.JoinedStr(
values=[
ast.FormattedValue(value=ast.Name(id="line_id", ctx=ast.Load()), conversion=-1),
ast.FormattedValue(value=ast.Name(id="codeflash_line_id", ctx=ast.Load()), conversion=-1),
ast.Constant(value="_"),
ast.FormattedValue(value=ast.Name(id="codeflash_test_index", ctx=ast.Load()), conversion=-1),
]
@ -466,13 +466,15 @@ def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR) -> ast.Fun
targets=[ast.Name(id="test_stdout_tag", ctx=ast.Store())],
value=ast.JoinedStr(
values=[
ast.FormattedValue(value=ast.Name(id="test_module_name", ctx=ast.Load()), conversion=-1),
ast.FormattedValue(
value=ast.Name(id="codeflash_test_module_name", ctx=ast.Load()), conversion=-1
),
ast.Constant(value=":"),
ast.FormattedValue(
value=ast.IfExp(
test=ast.Name(id="test_class_name", ctx=ast.Load()),
test=ast.Name(id="codeflash_test_class_name", ctx=ast.Load()),
body=ast.BinOp(
left=ast.Name(id="test_class_name", ctx=ast.Load()),
left=ast.Name(id="codeflash_test_class_name", ctx=ast.Load()),
op=ast.Add(),
right=ast.Constant(value="."),
),
@ -480,11 +482,15 @@ def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR) -> ast.Fun
),
conversion=-1,
),
ast.FormattedValue(value=ast.Name(id="test_name", ctx=ast.Load()), conversion=-1),
ast.FormattedValue(value=ast.Name(id="codeflash_test_name", ctx=ast.Load()), conversion=-1),
ast.Constant(value=":"),
ast.FormattedValue(value=ast.Name(id="function_name", ctx=ast.Load()), conversion=-1),
ast.FormattedValue(
value=ast.Name(id="codeflash_function_name", ctx=ast.Load()), conversion=-1
),
ast.Constant(value=":"),
ast.FormattedValue(value=ast.Name(id="loop_index", ctx=ast.Load()), conversion=-1),
ast.FormattedValue(
value=ast.Name(id="codeflash_loop_index", ctx=ast.Load()), conversion=-1
),
ast.Constant(value=":"),
ast.FormattedValue(value=ast.Name(id="invocation_id", ctx=ast.Load()), conversion=-1),
]
@ -537,7 +543,7 @@ def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR) -> ast.Fun
ast.Assign(
targets=[ast.Name(id="return_value", ctx=ast.Store())],
value=ast.Call(
func=ast.Name(id="wrapped", ctx=ast.Load()),
func=ast.Name(id="codeflash_wrapped", ctx=ast.Load()),
args=[ast.Starred(value=ast.Name(id="args", ctx=ast.Load()), ctx=ast.Load())],
keywords=[ast.keyword(arg=None, value=ast.Name(id="kwargs", ctx=ast.Load()))],
),
@ -664,11 +670,11 @@ def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR) -> ast.Fun
ast.Constant(value="INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)"),
ast.Tuple(
elts=[
ast.Name(id="test_module_name", ctx=ast.Load()),
ast.Name(id="test_class_name", ctx=ast.Load()),
ast.Name(id="test_name", ctx=ast.Load()),
ast.Name(id="function_name", ctx=ast.Load()),
ast.Name(id="loop_index", ctx=ast.Load()),
ast.Name(id="codeflash_test_module_name", ctx=ast.Load()),
ast.Name(id="codeflash_test_class_name", ctx=ast.Load()),
ast.Name(id="codeflash_test_name", ctx=ast.Load()),
ast.Name(id="codeflash_function_name", ctx=ast.Load()),
ast.Name(id="codeflash_loop_index", ctx=ast.Load()),
ast.Name(id="invocation_id", ctx=ast.Load()),
ast.Name(id="codeflash_duration", ctx=ast.Load()),
ast.Name(id="pickled_return_value", ctx=ast.Load()),
@ -707,13 +713,13 @@ def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR) -> ast.Fun
name="codeflash_wrap",
args=ast.arguments(
args=[
ast.arg(arg="wrapped", annotation=None),
ast.arg(arg="test_module_name", annotation=None),
ast.arg(arg="test_class_name", annotation=None),
ast.arg(arg="test_name", annotation=None),
ast.arg(arg="function_name", annotation=None),
ast.arg(arg="line_id", annotation=None),
ast.arg(arg="loop_index", annotation=None),
ast.arg(arg="codeflash_wrapped", annotation=None),
ast.arg(arg="codeflash_test_module_name", annotation=None),
ast.arg(arg="codeflash_test_class_name", annotation=None),
ast.arg(arg="codeflash_test_name", annotation=None),
ast.arg(arg="codeflash_function_name", annotation=None),
ast.arg(arg="codeflash_line_id", annotation=None),
ast.arg(arg="codeflash_loop_index", annotation=None),
*([ast.arg(arg="codeflash_cur", annotation=None)] if mode == TestingMode.BEHAVIOR else []),
*([ast.arg(arg="codeflash_con", annotation=None)] if mode == TestingMode.BEHAVIOR else []),
],

View file

@ -15,8 +15,8 @@ from codeflash.verification.equivalence import compare_test_results
from codeflash.verification.instrument_codeflash_capture import instrument_codeflash_capture
# Used by cli instrumentation
codeflash_wrap_string = """def codeflash_wrap(wrapped, test_module_name, test_class_name, test_name, function_name, line_id, loop_index, codeflash_cur, codeflash_con, *args, **kwargs):
test_id = f'{{test_module_name}}:{{test_class_name}}:{{test_name}}:{{line_id}}:{{loop_index}}'
codeflash_wrap_string = """def codeflash_wrap(codeflash_wrapped, codeflash_test_module_name, codeflash_test_class_name, codeflash_test_name, codeflash_function_name, codeflash_line_id, codeflash_loop_index, codeflash_cur, codeflash_con, *args, **kwargs):
test_id = f'{{codeflash_test_module_name}}:{{codeflash_test_class_name}}:{{codeflash_test_name}}:{{codeflash_line_id}}:{{codeflash_loop_index}}'
if not hasattr(codeflash_wrap, 'index'):
codeflash_wrap.index = {{}}
if test_id in codeflash_wrap.index:
@ -24,14 +24,14 @@ codeflash_wrap_string = """def codeflash_wrap(wrapped, test_module_name, test_cl
else:
codeflash_wrap.index[test_id] = 0
codeflash_test_index = codeflash_wrap.index[test_id]
invocation_id = f'{{line_id}}_{{codeflash_test_index}}'
test_stdout_tag = f"{{test_module_name}}:{{(test_class_name + '.' if test_class_name else '')}}{{test_name}}:{{function_name}}:{{loop_index}}:{{invocation_id}}"
invocation_id = f'{{codeflash_line_id}}_{{codeflash_test_index}}'
test_stdout_tag = f"{{codeflash_test_module_name}}:{{(codeflash_test_class_name + '.' if codeflash_test_class_name else '')}}{{codeflash_test_name}}:{{codeflash_function_name}}:{{codeflash_loop_index}}:{{invocation_id}}"
print(f"!$######{{test_stdout_tag}}######$!")
exception = None
gc.disable()
try:
counter = time.perf_counter_ns()
return_value = wrapped(*args, **kwargs)
return_value = codeflash_wrapped(*args, **kwargs)
codeflash_duration = time.perf_counter_ns() - counter
except Exception as e:
codeflash_duration = time.perf_counter_ns() - counter
@ -39,7 +39,7 @@ codeflash_wrap_string = """def codeflash_wrap(wrapped, test_module_name, test_cl
gc.enable()
print(f"!######{{test_stdout_tag}}######!")
pickled_return_value = pickle.dumps(exception) if exception else pickle.dumps(return_value)
codeflash_cur.execute('INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)', (test_module_name, test_class_name, test_name, function_name, loop_index, invocation_id, codeflash_duration, pickled_return_value, 'function_call'))
codeflash_cur.execute('INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)', (codeflash_test_module_name, codeflash_test_class_name, codeflash_test_name, codeflash_function_name, codeflash_loop_index, invocation_id, codeflash_duration, pickled_return_value, 'function_call'))
codeflash_con.commit()
if exception:
raise exception

View file

@ -27,8 +27,8 @@ from codeflash.models.models import (
from codeflash.optimization.function_optimizer import FunctionOptimizer
from codeflash.verification.verification_utils import TestConfig
codeflash_wrap_string = """def codeflash_wrap(wrapped, test_module_name, test_class_name, test_name, function_name, line_id, loop_index, codeflash_cur, codeflash_con, *args, **kwargs):
test_id = f'{{test_module_name}}:{{test_class_name}}:{{test_name}}:{{line_id}}:{{loop_index}}'
codeflash_wrap_string = """def codeflash_wrap(codeflash_wrapped, codeflash_test_module_name, codeflash_test_class_name, codeflash_test_name, codeflash_function_name, codeflash_line_id, codeflash_loop_index, codeflash_cur, codeflash_con, *args, **kwargs):
test_id = f'{{codeflash_test_module_name}}:{{codeflash_test_class_name}}:{{codeflash_test_name}}:{{codeflash_line_id}}:{{codeflash_loop_index}}'
if not hasattr(codeflash_wrap, 'index'):
codeflash_wrap.index = {{}}
if test_id in codeflash_wrap.index:
@ -36,14 +36,14 @@ codeflash_wrap_string = """def codeflash_wrap(wrapped, test_module_name, test_cl
else:
codeflash_wrap.index[test_id] = 0
codeflash_test_index = codeflash_wrap.index[test_id]
invocation_id = f'{{line_id}}_{{codeflash_test_index}}'
test_stdout_tag = f"{{test_module_name}}:{{(test_class_name + '.' if test_class_name else '')}}{{test_name}}:{{function_name}}:{{loop_index}}:{{invocation_id}}"
invocation_id = f'{{codeflash_line_id}}_{{codeflash_test_index}}'
test_stdout_tag = f"{{codeflash_test_module_name}}:{{(codeflash_test_class_name + '.' if codeflash_test_class_name else '')}}{{codeflash_test_name}}:{{codeflash_function_name}}:{{codeflash_loop_index}}:{{invocation_id}}"
print(f"!$######{{test_stdout_tag}}######$!")
exception = None
gc.disable()
try:
counter = time.perf_counter_ns()
return_value = wrapped(*args, **kwargs)
return_value = codeflash_wrapped(*args, **kwargs)
codeflash_duration = time.perf_counter_ns() - counter
except Exception as e:
codeflash_duration = time.perf_counter_ns() - counter
@ -51,15 +51,15 @@ codeflash_wrap_string = """def codeflash_wrap(wrapped, test_module_name, test_cl
gc.enable()
print(f"!######{{test_stdout_tag}}######!")
pickled_return_value = pickle.dumps(exception) if exception else pickle.dumps(return_value)
codeflash_cur.execute('INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)', (test_module_name, test_class_name, test_name, function_name, loop_index, invocation_id, codeflash_duration, pickled_return_value, 'function_call'))
codeflash_cur.execute('INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)', (codeflash_test_module_name, codeflash_test_class_name, codeflash_test_name, codeflash_function_name, codeflash_loop_index, invocation_id, codeflash_duration, pickled_return_value, 'function_call'))
codeflash_con.commit()
if exception:
raise exception
return return_value
"""
codeflash_wrap_perfonly_string = """def codeflash_wrap(wrapped, test_module_name, test_class_name, test_name, function_name, line_id, loop_index, *args, **kwargs):
test_id = f'{{test_module_name}}:{{test_class_name}}:{{test_name}}:{{line_id}}:{{loop_index}}'
codeflash_wrap_perfonly_string = """def codeflash_wrap(codeflash_wrapped, codeflash_test_module_name, codeflash_test_class_name, codeflash_test_name, codeflash_function_name, codeflash_line_id, codeflash_loop_index, *args, **kwargs):
test_id = f'{{codeflash_test_module_name}}:{{codeflash_test_class_name}}:{{codeflash_test_name}}:{{codeflash_line_id}}:{{codeflash_loop_index}}'
if not hasattr(codeflash_wrap, 'index'):
codeflash_wrap.index = {{}}
if test_id in codeflash_wrap.index:
@ -67,14 +67,14 @@ codeflash_wrap_perfonly_string = """def codeflash_wrap(wrapped, test_module_name
else:
codeflash_wrap.index[test_id] = 0
codeflash_test_index = codeflash_wrap.index[test_id]
invocation_id = f'{{line_id}}_{{codeflash_test_index}}'
test_stdout_tag = f"{{test_module_name}}:{{(test_class_name + '.' if test_class_name else '')}}{{test_name}}:{{function_name}}:{{loop_index}}:{{invocation_id}}"
invocation_id = f'{{codeflash_line_id}}_{{codeflash_test_index}}'
test_stdout_tag = f"{{codeflash_test_module_name}}:{{(codeflash_test_class_name + '.' if codeflash_test_class_name else '')}}{{codeflash_test_name}}:{{codeflash_function_name}}:{{codeflash_loop_index}}:{{invocation_id}}"
print(f"!$######{{test_stdout_tag}}######$!")
exception = None
gc.disable()
try:
counter = time.perf_counter_ns()
return_value = wrapped(*args, **kwargs)
return_value = codeflash_wrapped(*args, **kwargs)
codeflash_duration = time.perf_counter_ns() - counter
except Exception as e:
codeflash_duration = time.perf_counter_ns() - counter
@ -118,8 +118,8 @@ import timeout_decorator
from code_to_optimize.bubble_sort import sorter
def codeflash_wrap(wrapped, test_module_name, test_class_name, test_name, function_name, line_id, loop_index, codeflash_cur, codeflash_con, *args, **kwargs):
test_id = f'{{test_module_name}}:{{test_class_name}}:{{test_name}}:{{line_id}}:{{loop_index}}'
def codeflash_wrap(codeflash_wrapped, codeflash_test_module_name, codeflash_test_class_name, codeflash_test_name, codeflash_function_name, codeflash_line_id, codeflash_loop_index, codeflash_cur, codeflash_con, *args, **kwargs):
test_id = f'{{codeflash_test_module_name}}:{{codeflash_test_class_name}}:{{codeflash_test_name}}:{{codeflash_line_id}}:{{codeflash_loop_index}}'
if not hasattr(codeflash_wrap, 'index'):
codeflash_wrap.index = {{}}
if test_id in codeflash_wrap.index:
@ -127,16 +127,16 @@ def codeflash_wrap(wrapped, test_module_name, test_class_name, test_name, functi
else:
codeflash_wrap.index[test_id] = 0
codeflash_test_index = codeflash_wrap.index[test_id]
invocation_id = f'{{line_id}}_{{codeflash_test_index}}'
invocation_id = f'{{codeflash_line_id}}_{{codeflash_test_index}}'
"""
expected += """test_stdout_tag = f'{{test_module_name}}:{{(test_class_name + '.' if test_class_name else '')}}{{test_name}}:{{function_name}}:{{loop_index}}:{{invocation_id}}'
expected += """test_stdout_tag = f'{{codeflash_test_module_name}}:{{(codeflash_test_class_name + '.' if codeflash_test_class_name else '')}}{{codeflash_test_name}}:{{codeflash_function_name}}:{{codeflash_loop_index}}:{{invocation_id}}'
"""
expected += """print(f'!$######{{test_stdout_tag}}######$!')
exception = None
gc.disable()
try:
counter = time.perf_counter_ns()
return_value = wrapped(*args, **kwargs)
return_value = codeflash_wrapped(*args, **kwargs)
codeflash_duration = time.perf_counter_ns() - counter
except Exception as e:
codeflash_duration = time.perf_counter_ns() - counter
@ -144,7 +144,7 @@ def codeflash_wrap(wrapped, test_module_name, test_class_name, test_name, functi
gc.enable()
print(f'!######{{test_stdout_tag}}######!')
pickled_return_value = pickle.dumps(exception) if exception else pickle.dumps(return_value)
codeflash_cur.execute('INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)', (test_module_name, test_class_name, test_name, function_name, loop_index, invocation_id, codeflash_duration, pickled_return_value, 'function_call'))
codeflash_cur.execute('INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)', (codeflash_test_module_name, codeflash_test_class_name, codeflash_test_name, codeflash_function_name, codeflash_loop_index, invocation_id, codeflash_duration, pickled_return_value, 'function_call'))
codeflash_con.commit()
if exception:
raise exception
@ -218,8 +218,8 @@ from codeflash.tracing.replay_test import get_next_arg_and_return
from codeflash.validation.equivalence import compare_results
def codeflash_wrap(wrapped, test_module_name, test_class_name, test_name, function_name, line_id, loop_index, codeflash_cur, codeflash_con, *args, **kwargs):
test_id = f'{{test_module_name}}:{{test_class_name}}:{{test_name}}:{{line_id}}:{{loop_index}}'
def codeflash_wrap(codeflash_wrapped, codeflash_test_module_name, codeflash_test_class_name, codeflash_test_name, codeflash_function_name, codeflash_line_id, codeflash_loop_index, codeflash_cur, codeflash_con, *args, **kwargs):
test_id = f'{{codeflash_test_module_name}}:{{codeflash_test_class_name}}:{{codeflash_test_name}}:{{codeflash_line_id}}:{{codeflash_loop_index}}'
if not hasattr(codeflash_wrap, 'index'):
codeflash_wrap.index = {{}}
if test_id in codeflash_wrap.index:
@ -227,16 +227,16 @@ def codeflash_wrap(wrapped, test_module_name, test_class_name, test_name, functi
else:
codeflash_wrap.index[test_id] = 0
codeflash_test_index = codeflash_wrap.index[test_id]
invocation_id = f'{{line_id}}_{{codeflash_test_index}}'
invocation_id = f'{{codeflash_line_id}}_{{codeflash_test_index}}'
"""
expected += """test_stdout_tag = f'{{test_module_name}}:{{(test_class_name + '.' if test_class_name else '')}}{{test_name}}:{{function_name}}:{{loop_index}}:{{invocation_id}}'
expected += """test_stdout_tag = f'{{codeflash_test_module_name}}:{{(codeflash_test_class_name + '.' if codeflash_test_class_name else '')}}{{codeflash_test_name}}:{{codeflash_function_name}}:{{codeflash_loop_index}}:{{invocation_id}}'
"""
expected += """print(f'!$######{{test_stdout_tag}}######$!')
exception = None
gc.disable()
try:
counter = time.perf_counter_ns()
return_value = wrapped(*args, **kwargs)
return_value = codeflash_wrapped(*args, **kwargs)
codeflash_duration = time.perf_counter_ns() - counter
except Exception as e:
codeflash_duration = time.perf_counter_ns() - counter
@ -244,7 +244,7 @@ def codeflash_wrap(wrapped, test_module_name, test_class_name, test_name, functi
gc.enable()
print(f'!######{{test_stdout_tag}}######!')
pickled_return_value = pickle.dumps(exception) if exception else pickle.dumps(return_value)
codeflash_cur.execute('INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)', (test_module_name, test_class_name, test_name, function_name, loop_index, invocation_id, codeflash_duration, pickled_return_value, 'function_call'))
codeflash_cur.execute('INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)', (codeflash_test_module_name, codeflash_test_class_name, codeflash_test_name, codeflash_function_name, codeflash_loop_index, invocation_id, codeflash_duration, pickled_return_value, 'function_call'))
codeflash_con.commit()
if exception:
raise exception