Added logic / tests for codeflash capture to work if the class is instantiated at the module level, or in another function.

This commit is contained in:
Alvin Ryanputra 2025-01-22 16:17:07 -08:00
parent 9fbfc27188
commit 750a4d0a04
2 changed files with 242 additions and 22 deletions

View file

@ -20,9 +20,8 @@ def get_test_info_from_stack() -> tuple[str, str | None, str, str]:
test_module_name = ""
test_class_name = None
test_name = None
function_name = ""
line_id = "" # Note that the way this line_id is defined is from the line_id called in our usual instrumentation
function_found = False
line_id = "" # Note that the way this line_id is defined is from the line_id called in instrumentation
# Search through stack for test information
for frame in stack:
if frame.function.startswith("test_"): # May need a more robust way to find the test file
@ -32,26 +31,18 @@ def get_test_info_from_stack() -> tuple[str, str | None, str, str]:
# Check if it's a method in a class
if "self" in frame.frame.f_locals:
test_class_name = frame.frame.f_locals["self"].__class__.__name__
function_found = True
break
if not function_found: # Likely defined at module level, or as a helper test function.
for frame in stack:
# First try to get the module name directly from the frame
module = inspect.getmodule(frame[0])
if module:
test_module_name = module.__name__
line_id = str(frame.lineno)
# If it's in a function, the function name will be in frame.function
# If at module level, frame.function will be '<module>'
if frame.function != "<module>":
test_name = frame.function
# Check if it's in a class
if "self" in frame.frame.f_locals:
test_class_name = frame.frame.f_locals["self"].__class__.__name__
break
# Check if module name starts with test
module_name = frame.frame.f_globals["__name__"]
if module_name and module_name.split(".")[-1].startswith("test_"):
test_module_name = module_name
line_id = str(frame.lineno)
if frame.function != "<module>":
test_name = frame.function # Technically not a test, but save the info since there is no test function
# Check if it's in a class
if "self" in frame.frame.f_locals:
test_class_name = frame.frame.f_locals["self"].__class__.__name__
break
return test_module_name, test_class_name, test_name, line_id

View file

@ -95,6 +95,197 @@ class MyClass:
sample_code_path.unlink(missing_ok=True)
def test_get_stack_info_2() -> None:
    """Check the stack info reported when the class is instantiated at module level.

    A temporary test module creates ``MyClass()`` once at import time and then
    only reads the shared object from three tests (plain function, plain class,
    ``unittest.TestCase``).  ``MyClass.__init__`` prints the tuple returned by
    ``get_test_info_from_stack()``, so exactly one report is expected, with no
    test class and no test name, and a line id pointing at the module-level
    instantiation.
    """
    # NOTE: the asserted line id depends on the exact line layout of this
    # snippet (``obj = MyClass()`` is on line 5) - do not reflow it.
    test_code = """
from sample_code import MyClass
import unittest

obj = MyClass()

def test_example_test():
    assert obj.x == 2

class TestExampleClass:
    def test_example_test_2(self):
        assert obj.x == 2

class TestUnittestExample(unittest.TestCase):
    def test_example_test_3(self):
        self.assertEqual(obj.x, 2)
"""

    sample_code = """
from codeflash.verification.codeflash_capture import get_test_info_from_stack
class MyClass:
    def __init__(self):
        self.x = 2
        print(f"TEST_INFO_START|{get_test_info_from_stack()}|TEST_INFO_END")
"""
    test_dir = (Path(__file__).parent.parent / "code_to_optimize" / "tests" / "pytest").resolve()
    test_file_name = "test_stack_info_temp.py"
    test_path = test_dir / test_file_name
    sample_code_path = test_dir / "sample_code.py"
    try:
        test_path.write_text(test_code, encoding="utf-8")
        sample_code_path.write_text(sample_code, encoding="utf-8")

        # "-s" disables pytest's output capture so the printed markers reach stdout.
        result = execute_test_subprocess(
            cwd=test_dir, env={}, cmd_list=[f"{SAFE_SYS_EXECUTABLE}", "-m", "pytest", test_file_name, "-s"]
        )
        assert not result.stderr
        assert result.returncode == 0

        pattern = r"TEST_INFO_START\|\((.*?)\)\|TEST_INFO_END"
        # re.finditer returns an iterator, which is always truthy; materialise
        # it so the "nothing found" guard below can actually trigger.
        matches = list(re.finditer(pattern, result.stdout))
        if not matches:
            raise ValueError("Could not find test info in output")

        # Each match holds the printed tuple body; split it into its four fields.
        results = [[val.strip().strip("'") for val in match.group(1).split(",")] for match in matches]

        # Format is (test_module_name, test_class_name, test_name, line_id)
        assert len(results) == 1
        assert results[0][0] == "code_to_optimize.tests.pytest.test_stack_info_temp"  # test_module_name
        assert results[0][1].strip() == "None"  # test_class_name
        assert results[0][2].strip() == "None"  # test_name
        assert results[0][3] == "5"  # line_id
    finally:
        # Always remove the temporary files, even when an assertion fails.
        test_path.unlink(missing_ok=True)
        sample_code_path.unlink(missing_ok=True)
def test_get_stack_info_3() -> None:
    """Check the stack info reported when the class is instantiated in a helper function.

    The temporary test module builds ``MyClass()`` inside ``get_obj()``, which
    is called from three tests (plain function, plain class,
    ``unittest.TestCase``).  ``MyClass.__init__`` prints the tuple returned by
    ``get_test_info_from_stack()``.  Three identical reports are expected, each
    naming the helper ``get_obj`` and the line of its ``return MyClass()``
    statement rather than the calling test.
    """
    # NOTE: the asserted line id depends on the exact line layout of this
    # snippet (``return MyClass()`` is on line 6) - do not reflow it.
    test_code = """
from sample_code import MyClass
import unittest

def get_obj():
    return MyClass()

def test_example_test():
    result = get_obj().x
    assert result == 2

class TestExampleClass:
    def test_example_test_2(self):
        result = get_obj().x
        assert result == 2

class TestUnittestExample(unittest.TestCase):
    def test_example_test_3(self):
        result = get_obj().x
        self.assertEqual(result, 2)
"""

    sample_code = """
from codeflash.verification.codeflash_capture import get_test_info_from_stack
class MyClass:
    def __init__(self):
        self.x = 2
        print(f"TEST_INFO_START|{get_test_info_from_stack()}|TEST_INFO_END")
"""
    test_dir = (Path(__file__).parent.parent / "code_to_optimize" / "tests" / "pytest").resolve()
    test_file_name = "test_stack_info_temp.py"
    test_path = test_dir / test_file_name
    sample_code_path = test_dir / "sample_code.py"
    try:
        test_path.write_text(test_code, encoding="utf-8")
        sample_code_path.write_text(sample_code, encoding="utf-8")

        # "-s" disables pytest's output capture so the printed markers reach stdout.
        result = execute_test_subprocess(
            cwd=test_dir, env={}, cmd_list=[f"{SAFE_SYS_EXECUTABLE}", "-m", "pytest", test_file_name, "-s"]
        )
        assert not result.stderr
        assert result.returncode == 0

        pattern = r"TEST_INFO_START\|\((.*?)\)\|TEST_INFO_END"
        # re.finditer returns an iterator, which is always truthy; materialise
        # it so the "nothing found" guard below can actually trigger.
        matches = list(re.finditer(pattern, result.stdout))
        if not matches:
            raise ValueError("Could not find test info in output")

        # Each match holds the printed tuple body; split it into its four fields.
        results = [[val.strip().strip("'") for val in match.group(1).split(",")] for match in matches]

        # Format is (test_module_name, test_class_name, test_name, line_id)
        assert len(results) == 3
        assert results[0][0] == "code_to_optimize.tests.pytest.test_stack_info_temp"  # test_module_name
        assert results[0][1].strip() == "None"  # test_class_name
        assert results[0][2].strip() == "get_obj"  # test_name
        assert results[0][3] == "6"  # line_id
        # All three tests go through the same helper, so every report must match.
        assert results[0] == results[1]
        assert results[1] == results[2]
    finally:
        # Always remove the temporary files, even when an assertion fails.
        test_path.unlink(missing_ok=True)
        sample_code_path.unlink(missing_ok=True)
def test_get_stack_info_mixed() -> None:
    """Check stack info when one module instantiates the class in three different places.

    The temporary test module creates ``MyClass()`` at module level, directly
    inside a test function, and inside a helper called from that test.
    ``MyClass.__init__`` prints the tuple returned by
    ``get_test_info_from_stack()``, and the three reports are checked in the
    order the instantiations run: module level (no test name), the test
    function itself, then the helper ``get_diff_obj``.
    """
    # NOTE: the asserted line ids depend on the exact line layout of this
    # snippet (instantiations on lines 5, 11 and 8) - do not reflow it.
    test_code = """
from sample_code import MyClass
import unittest

obj = MyClass()

def get_diff_obj():
    return MyClass()

def test_example_test():
    this_obj = MyClass()
    assert this_obj.x == get_diff_obj().x
"""

    sample_code = """
from codeflash.verification.codeflash_capture import get_test_info_from_stack
class MyClass:
    def __init__(self):
        self.x = 2
        print(f"TEST_INFO_START|{get_test_info_from_stack()}|TEST_INFO_END")
"""
    test_dir = (Path(__file__).parent.parent / "code_to_optimize" / "tests" / "pytest").resolve()
    test_file_name = "test_stack_info_temp.py"
    test_path = test_dir / test_file_name
    sample_code_path = test_dir / "sample_code.py"
    try:
        test_path.write_text(test_code, encoding="utf-8")
        sample_code_path.write_text(sample_code, encoding="utf-8")

        # "-s" disables pytest's output capture so the printed markers reach stdout.
        result = execute_test_subprocess(
            cwd=test_dir, env={}, cmd_list=[f"{SAFE_SYS_EXECUTABLE}", "-m", "pytest", test_file_name, "-s"]
        )
        assert not result.stderr
        assert result.returncode == 0

        pattern = r"TEST_INFO_START\|\((.*?)\)\|TEST_INFO_END"
        # re.finditer returns an iterator, which is always truthy; materialise
        # it so the "nothing found" guard below can actually trigger.
        matches = list(re.finditer(pattern, result.stdout))
        if not matches:
            raise ValueError("Could not find test info in output")

        # Each match holds the printed tuple body; split it into its four fields.
        results = [[val.strip().strip("'") for val in match.group(1).split(",")] for match in matches]

        # Format is (test_module_name, test_class_name, test_name, line_id)
        # One report per instantiation; fail clearly instead of with an IndexError.
        assert len(results) == 3

        # Module-level instantiation: no enclosing function, so no test name.
        assert results[0][0] == "code_to_optimize.tests.pytest.test_stack_info_temp"  # test_module_name
        assert results[0][1].strip() == "None"  # test_class_name
        assert results[0][2].strip() == "None"  # test_name
        assert results[0][3] == "5"  # line_id

        # Instantiation directly inside the test function.
        assert results[1][0] == "code_to_optimize.tests.pytest.test_stack_info_temp"  # test_module_name
        assert results[1][1].strip() == "None"  # test_class_name
        assert results[1][2].strip() == "test_example_test"  # test_name
        assert results[1][3] == "11"  # line_id

        # Instantiation inside the helper called from the test.
        assert results[2][0] == "code_to_optimize.tests.pytest.test_stack_info_temp"  # test_module_name
        assert results[2][1].strip() == "None"  # test_class_name
        assert results[2][2].strip() == "get_diff_obj"  # test_name
        assert results[2][3] == "8"  # line_id
    finally:
        # Always remove the temporary files, even when an assertion fails.
        test_path.unlink(missing_ok=True)
        sample_code_path.unlink(missing_ok=True)
def test_codeflash_capture_basic() -> None:
test_code = """
from code_to_optimize.tests.pytest.sample_code import MyClass
@ -674,6 +865,44 @@ class MyClass:
opt.write_code_and_helpers(candidate_fto_code, candidate_helper_code, fto.file_path)
assert not compare_test_results(test_results, mutated_test_results)
# This fto code stopped using a helper class. it should still pass
no_helper1_fto_code = """
from code_to_optimize.tests.pytest.helper_file_2 import HelperClass2, AnotherHelperClass
class MyClass:
def __init__(self):
self.x = 1
def target_function(self):
helper2 = HelperClass2().helper2()
another = AnotherHelperClass().another_helper()
return helper2 + another
"""
with fto_file_path.open("w") as f:
f.write(no_helper1_fto_code)
# Instrument codeflash capture
candidate_fto_code = Path(fto.file_path).read_text("utf-8")
candidate_helper_code = {}
for file_path in file_path_to_helper_class:
candidate_helper_code[file_path] = Path(file_path).read_text("utf-8")
file_path_to_helper_classes = {
Path(helper_path_1): {"HelperClass1"},
Path(helper_path_2): {"HelperClass2", "AnotherHelperClass"},
}
instrument_code(fto, file_path_to_helper_classes)
no_helper1_test_results, coverage_data = opt.run_and_parse_tests(
testing_type=TestingMode.BEHAVIOR,
test_env=test_env,
test_files=test_files,
optimization_iteration=0,
pytest_min_loops=1,
pytest_max_loops=1,
testing_time=0.1,
)
# Remove instrumentation
opt.write_code_and_helpers(candidate_fto_code, candidate_helper_code, fto.file_path)
assert compare_test_results(test_results, no_helper1_test_results)
finally:
test_path.unlink(missing_ok=True)
fto_file_path.unlink(missing_ok=True)