mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
Added logic / tests for codeflash capture to work if the class is instantiated at the module level, or in another function.
This commit is contained in:
parent
9fbfc27188
commit
750a4d0a04
2 changed files with 242 additions and 22 deletions
|
|
@ -20,9 +20,8 @@ def get_test_info_from_stack() -> tuple[str, str | None, str, str]:
|
|||
test_module_name = ""
|
||||
test_class_name = None
|
||||
test_name = None
|
||||
function_name = ""
|
||||
line_id = "" # Note that the way this line_id is defined is from the line_id called in our usual instrumentation
|
||||
function_found = False
|
||||
line_id = "" # Note that the way this line_id is defined is from the line_id called in instrumentation
|
||||
|
||||
# Search through stack for test information
|
||||
for frame in stack:
|
||||
if frame.function.startswith("test_"): # May need a more robust way to find the test file
|
||||
|
|
@ -32,26 +31,18 @@ def get_test_info_from_stack() -> tuple[str, str | None, str, str]:
|
|||
# Check if it's a method in a class
|
||||
if "self" in frame.frame.f_locals:
|
||||
test_class_name = frame.frame.f_locals["self"].__class__.__name__
|
||||
function_found = True
|
||||
break
|
||||
|
||||
if not function_found: # Likely defined at module level, or as a helper test function.
|
||||
for frame in stack:
|
||||
# First try to get the module name directly from the frame
|
||||
module = inspect.getmodule(frame[0])
|
||||
if module:
|
||||
test_module_name = module.__name__
|
||||
line_id = str(frame.lineno)
|
||||
|
||||
# If it's in a function, the function name will be in frame.function
|
||||
# If at module level, frame.function will be '<module>'
|
||||
if frame.function != "<module>":
|
||||
test_name = frame.function
|
||||
|
||||
# Check if it's in a class
|
||||
if "self" in frame.frame.f_locals:
|
||||
test_class_name = frame.frame.f_locals["self"].__class__.__name__
|
||||
break
|
||||
# Check if module name starts with test
|
||||
module_name = frame.frame.f_globals["__name__"]
|
||||
if module_name and module_name.split(".")[-1].startswith("test_"):
|
||||
test_module_name = module_name
|
||||
line_id = str(frame.lineno)
|
||||
if frame.function != "<module>":
|
||||
test_name = frame.function # Technically not a test, but save the info since there is no test function
|
||||
# Check if it's in a class
|
||||
if "self" in frame.frame.f_locals:
|
||||
test_class_name = frame.frame.f_locals["self"].__class__.__name__
|
||||
break
|
||||
|
||||
return test_module_name, test_class_name, test_name, line_id
|
||||
|
||||
|
|
|
|||
|
|
@ -95,6 +95,197 @@ class MyClass:
|
|||
sample_code_path.unlink(missing_ok=True)
|
||||
|
||||
|
||||
def test_get_stack_info_2() -> None:
|
||||
test_code = """
|
||||
from sample_code import MyClass
|
||||
import unittest
|
||||
|
||||
obj = MyClass()
|
||||
def test_example_test():
|
||||
assert obj.x == 2
|
||||
|
||||
class TestExampleClass:
|
||||
def test_example_test_2(self):
|
||||
assert obj.x == 2
|
||||
|
||||
class TestUnittestExample(unittest.TestCase):
|
||||
def test_example_test_3(self):
|
||||
self.assertEqual(obj.x, 2)
|
||||
"""
|
||||
sample_code = """
|
||||
from codeflash.verification.codeflash_capture import get_test_info_from_stack
|
||||
class MyClass:
|
||||
def __init__(self):
|
||||
self.x = 2
|
||||
print(f"TEST_INFO_START|{get_test_info_from_stack()}|TEST_INFO_END")
|
||||
"""
|
||||
test_dir = (Path(__file__).parent.parent / "code_to_optimize" / "tests" / "pytest").resolve()
|
||||
test_file_name = "test_stack_info_temp.py"
|
||||
|
||||
test_path = test_dir / test_file_name
|
||||
sample_code_path = test_dir / "sample_code.py"
|
||||
try:
|
||||
with test_path.open("w") as f:
|
||||
f.write(test_code)
|
||||
with sample_code_path.open("w") as f:
|
||||
f.write(sample_code)
|
||||
result = execute_test_subprocess(
|
||||
cwd=test_dir, env={}, cmd_list=[f"{SAFE_SYS_EXECUTABLE}", "-m", "pytest", test_file_name, "-s"]
|
||||
)
|
||||
assert not result.stderr
|
||||
assert result.returncode == 0
|
||||
pattern = r"TEST_INFO_START\|\((.*?)\)\|TEST_INFO_END"
|
||||
matches = re.finditer(pattern, result.stdout)
|
||||
if not matches:
|
||||
raise ValueError("Could not find test info in output")
|
||||
results = []
|
||||
for match in matches:
|
||||
values = [val.strip().strip("'") for val in match.group(1).split(",")]
|
||||
results.append(values)
|
||||
# Format is (test_module_name, test_class_name, test_name, line_id)
|
||||
assert len(results) == 1
|
||||
assert results[0][0] == "code_to_optimize.tests.pytest.test_stack_info_temp" # test_module_name
|
||||
assert results[0][1].strip() == "None" # test_class_name
|
||||
assert results[0][2].strip() == "None" # test_name
|
||||
assert results[0][3] == "5" # line_id
|
||||
|
||||
finally:
|
||||
test_path.unlink(missing_ok=True)
|
||||
sample_code_path.unlink(missing_ok=True)
|
||||
|
||||
|
||||
def test_get_stack_info_3() -> None:
|
||||
test_code = """
|
||||
from sample_code import MyClass
|
||||
import unittest
|
||||
|
||||
def get_obj():
|
||||
return MyClass()
|
||||
|
||||
def test_example_test():
|
||||
result = get_obj().x
|
||||
assert result == 2
|
||||
|
||||
class TestExampleClass:
|
||||
def test_example_test_2(self):
|
||||
result = get_obj().x
|
||||
assert result == 2
|
||||
|
||||
class TestUnittestExample(unittest.TestCase):
|
||||
def test_example_test_3(self):
|
||||
result = get_obj().x
|
||||
self.assertEqual(result, 2)
|
||||
"""
|
||||
sample_code = """
|
||||
from codeflash.verification.codeflash_capture import get_test_info_from_stack
|
||||
class MyClass:
|
||||
def __init__(self):
|
||||
self.x = 2
|
||||
print(f"TEST_INFO_START|{get_test_info_from_stack()}|TEST_INFO_END")
|
||||
"""
|
||||
test_dir = (Path(__file__).parent.parent / "code_to_optimize" / "tests" / "pytest").resolve()
|
||||
test_file_name = "test_stack_info_temp.py"
|
||||
|
||||
test_path = test_dir / test_file_name
|
||||
sample_code_path = test_dir / "sample_code.py"
|
||||
try:
|
||||
with test_path.open("w") as f:
|
||||
f.write(test_code)
|
||||
with sample_code_path.open("w") as f:
|
||||
f.write(sample_code)
|
||||
result = execute_test_subprocess(
|
||||
cwd=test_dir, env={}, cmd_list=[f"{SAFE_SYS_EXECUTABLE}", "-m", "pytest", test_file_name, "-s"]
|
||||
)
|
||||
assert not result.stderr
|
||||
assert result.returncode == 0
|
||||
pattern = r"TEST_INFO_START\|\((.*?)\)\|TEST_INFO_END"
|
||||
matches = re.finditer(pattern, result.stdout)
|
||||
if not matches:
|
||||
raise ValueError("Could not find test info in output")
|
||||
results = []
|
||||
for match in matches:
|
||||
values = [val.strip().strip("'") for val in match.group(1).split(",")]
|
||||
results.append(values)
|
||||
# Format is (test_module_name, test_class_name, test_name, line_id)
|
||||
assert len(results) == 3
|
||||
assert results[0][0] == "code_to_optimize.tests.pytest.test_stack_info_temp" # test_module_name
|
||||
assert results[0][1].strip() == "None" # test_class_name
|
||||
assert results[0][2].strip() == "get_obj" # test_name
|
||||
assert results[0][3] == "6" # line_id
|
||||
assert results[0] == results[1]
|
||||
assert results[1] == results[2]
|
||||
|
||||
finally:
|
||||
test_path.unlink(missing_ok=True)
|
||||
sample_code_path.unlink(missing_ok=True)
|
||||
|
||||
|
||||
def test_get_stack_info_mixed() -> None:
|
||||
test_code = """
|
||||
from sample_code import MyClass
|
||||
import unittest
|
||||
|
||||
obj = MyClass()
|
||||
|
||||
def get_diff_obj():
|
||||
return MyClass()
|
||||
|
||||
def test_example_test():
|
||||
this_obj = MyClass()
|
||||
assert this_obj.x == get_diff_obj().x
|
||||
"""
|
||||
sample_code = """
|
||||
from codeflash.verification.codeflash_capture import get_test_info_from_stack
|
||||
class MyClass:
|
||||
def __init__(self):
|
||||
self.x = 2
|
||||
print(f"TEST_INFO_START|{get_test_info_from_stack()}|TEST_INFO_END")
|
||||
"""
|
||||
test_dir = (Path(__file__).parent.parent / "code_to_optimize" / "tests" / "pytest").resolve()
|
||||
test_file_name = "test_stack_info_temp.py"
|
||||
|
||||
test_path = test_dir / test_file_name
|
||||
sample_code_path = test_dir / "sample_code.py"
|
||||
try:
|
||||
with test_path.open("w") as f:
|
||||
f.write(test_code)
|
||||
with sample_code_path.open("w") as f:
|
||||
f.write(sample_code)
|
||||
result = execute_test_subprocess(
|
||||
cwd=test_dir, env={}, cmd_list=[f"{SAFE_SYS_EXECUTABLE}", "-m", "pytest", test_file_name, "-s"]
|
||||
)
|
||||
assert not result.stderr
|
||||
assert result.returncode == 0
|
||||
pattern = r"TEST_INFO_START\|\((.*?)\)\|TEST_INFO_END"
|
||||
matches = re.finditer(pattern, result.stdout)
|
||||
if not matches:
|
||||
raise ValueError("Could not find test info in output")
|
||||
results = []
|
||||
for match in matches:
|
||||
values = [val.strip().strip("'") for val in match.group(1).split(",")]
|
||||
results.append(values)
|
||||
# Format is (test_module_name, test_class_name, test_name, line_id)
|
||||
|
||||
assert results[0][0] == "code_to_optimize.tests.pytest.test_stack_info_temp" # test_module_name
|
||||
assert results[0][1].strip() == "None" # test_class_name
|
||||
assert results[0][2].strip() == "None" # test_name
|
||||
assert results[0][3] == "5" # line_id
|
||||
|
||||
assert results[1][0] == "code_to_optimize.tests.pytest.test_stack_info_temp" # test_module_name
|
||||
assert results[1][1].strip() == "None" # test_class_name
|
||||
assert results[1][2].strip() == "test_example_test" # test_name
|
||||
assert results[1][3] == "11" # line_id
|
||||
|
||||
assert results[2][0] == "code_to_optimize.tests.pytest.test_stack_info_temp" # test_module_name
|
||||
assert results[2][1].strip() == "None" # test_class_name
|
||||
assert results[2][2].strip() == "get_diff_obj" # test_name
|
||||
assert results[2][3] == "8" # line_id
|
||||
|
||||
finally:
|
||||
test_path.unlink(missing_ok=True)
|
||||
sample_code_path.unlink(missing_ok=True)
|
||||
|
||||
|
||||
def test_codeflash_capture_basic() -> None:
|
||||
test_code = """
|
||||
from code_to_optimize.tests.pytest.sample_code import MyClass
|
||||
|
|
@ -674,6 +865,44 @@ class MyClass:
|
|||
opt.write_code_and_helpers(candidate_fto_code, candidate_helper_code, fto.file_path)
|
||||
assert not compare_test_results(test_results, mutated_test_results)
|
||||
|
||||
# This fto code stopped using a helper class. it should still pass
|
||||
no_helper1_fto_code = """
|
||||
from code_to_optimize.tests.pytest.helper_file_2 import HelperClass2, AnotherHelperClass
|
||||
|
||||
class MyClass:
|
||||
def __init__(self):
|
||||
self.x = 1
|
||||
|
||||
def target_function(self):
|
||||
helper2 = HelperClass2().helper2()
|
||||
another = AnotherHelperClass().another_helper()
|
||||
return helper2 + another
|
||||
"""
|
||||
with fto_file_path.open("w") as f:
|
||||
f.write(no_helper1_fto_code)
|
||||
# Instrument codeflash capture
|
||||
candidate_fto_code = Path(fto.file_path).read_text("utf-8")
|
||||
candidate_helper_code = {}
|
||||
for file_path in file_path_to_helper_class:
|
||||
candidate_helper_code[file_path] = Path(file_path).read_text("utf-8")
|
||||
file_path_to_helper_classes = {
|
||||
Path(helper_path_1): {"HelperClass1"},
|
||||
Path(helper_path_2): {"HelperClass2", "AnotherHelperClass"},
|
||||
}
|
||||
instrument_code(fto, file_path_to_helper_classes)
|
||||
no_helper1_test_results, coverage_data = opt.run_and_parse_tests(
|
||||
testing_type=TestingMode.BEHAVIOR,
|
||||
test_env=test_env,
|
||||
test_files=test_files,
|
||||
optimization_iteration=0,
|
||||
pytest_min_loops=1,
|
||||
pytest_max_loops=1,
|
||||
testing_time=0.1,
|
||||
)
|
||||
# Remove instrumentation
|
||||
opt.write_code_and_helpers(candidate_fto_code, candidate_helper_code, fto.file_path)
|
||||
assert compare_test_results(test_results, no_helper1_test_results)
|
||||
|
||||
finally:
|
||||
test_path.unlink(missing_ok=True)
|
||||
fto_file_path.unlink(missing_ok=True)
|
||||
|
|
|
|||
Loading…
Reference in a new issue