mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
context extraction imporvements
This commit is contained in:
parent
7f5e163e38
commit
b269212edd
2 changed files with 49 additions and 55 deletions
|
|
@ -950,7 +950,7 @@ def enrich_testgen_context(code_context: CodeStringsMarkdown, project_root_path:
|
|||
emitted_class_names.add(class_name)
|
||||
|
||||
for name, module_name in imported_names.items():
|
||||
if name in existing_classes:
|
||||
if name in existing_classes or module_name == "__future__":
|
||||
continue
|
||||
try:
|
||||
test_code = f"import {module_name}"
|
||||
|
|
@ -964,6 +964,13 @@ def enrich_testgen_context(code_context: CodeStringsMarkdown, project_root_path:
|
|||
if not module_path:
|
||||
continue
|
||||
|
||||
resolved_module = module_path.resolve()
|
||||
module_str = str(resolved_module)
|
||||
is_project = module_str.startswith(str(project_root_path.resolve()))
|
||||
is_third_party = "site-packages" in module_str
|
||||
if not is_project and not is_third_party:
|
||||
continue
|
||||
|
||||
mod_result = get_module_source_and_tree(module_path)
|
||||
if mod_result is None:
|
||||
continue
|
||||
|
|
|
|||
|
|
@ -3501,8 +3501,8 @@ class ConfigRegistry:
|
|||
assert "model_list: list" in all_extracted_code, "Should include model_list field from Router"
|
||||
|
||||
|
||||
def test_enrich_testgen_context_extracts_userdict(tmp_path: Path) -> None:
|
||||
"""Extracts __init__ from collections.UserDict when a class inherits from it."""
|
||||
def test_enrich_testgen_context_skips_stdlib_userdict(tmp_path: Path) -> None:
|
||||
"""Skips stdlib classes like collections.UserDict."""
|
||||
code = """from collections import UserDict
|
||||
|
||||
class MyCustomDict(UserDict):
|
||||
|
|
@ -3514,14 +3514,7 @@ class MyCustomDict(UserDict):
|
|||
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
|
||||
result = enrich_testgen_context(context, tmp_path)
|
||||
|
||||
assert len(result.code_strings) == 1
|
||||
code_string = result.code_strings[0]
|
||||
|
||||
assert "class UserDict" in code_string.code
|
||||
assert "def __init__" in code_string.code
|
||||
assert "self.data = {}" in code_string.code
|
||||
assert code_string.file_path is not None
|
||||
assert code_string.file_path.as_posix().endswith("collections/__init__.py")
|
||||
assert len(result.code_strings) == 0, "Should not extract stdlib classes"
|
||||
|
||||
|
||||
def test_enrich_testgen_context_skips_unresolvable_base_classes(tmp_path: Path) -> None:
|
||||
|
|
@ -3555,24 +3548,24 @@ def test_enrich_testgen_context_skips_builtin_base_classes(tmp_path: Path) -> No
|
|||
|
||||
|
||||
def test_enrich_testgen_context_deduplicates(tmp_path: Path) -> None:
|
||||
"""Extracts the same external base class only once even when inherited multiple times."""
|
||||
code = """from collections import UserDict
|
||||
"""Extracts the same project class only once even when imported multiple times."""
|
||||
package_dir = tmp_path / "mypkg"
|
||||
package_dir.mkdir()
|
||||
(package_dir / "__init__.py").write_text("", encoding="utf-8")
|
||||
(package_dir / "base.py").write_text(
|
||||
"class Base:\n def __init__(self, x: int):\n self.x = x\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
class MyDict1(UserDict):
|
||||
pass
|
||||
|
||||
class MyDict2(UserDict):
|
||||
pass
|
||||
"""
|
||||
code_path = tmp_path / "mydicts.py"
|
||||
code = "from mypkg.base import Base\n\nclass A(Base):\n pass\n\nclass B(Base):\n pass\n"
|
||||
code_path = package_dir / "children.py"
|
||||
code_path.write_text(code, encoding="utf-8")
|
||||
|
||||
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
|
||||
result = enrich_testgen_context(context, tmp_path)
|
||||
|
||||
assert len(result.code_strings) == 1
|
||||
assert "class UserDict" in result.code_strings[0].code
|
||||
assert "def __init__" in result.code_strings[0].code
|
||||
assert "class Base" in result.code_strings[0].code
|
||||
|
||||
|
||||
def test_enrich_testgen_context_empty_when_no_inheritance(tmp_path: Path) -> None:
|
||||
|
|
@ -3699,18 +3692,17 @@ def reify_channel_message(data: dict) -> MessageIn:
|
|||
|
||||
|
||||
def test_testgen_context_includes_external_base_inits(tmp_path: Path) -> None:
|
||||
"""Test that external base class __init__ methods are included in testgen context.
|
||||
"""Test that base class definitions from project modules are included in testgen context."""
|
||||
package_dir = tmp_path / "mypkg"
|
||||
package_dir.mkdir()
|
||||
(package_dir / "__init__.py").write_text("", encoding="utf-8")
|
||||
(package_dir / "base.py").write_text(
|
||||
"class BaseDict:\n def __init__(self, data=None):\n self.data = data or {}\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
This covers line 65 in code_context_extractor.py where external_base_inits.code_strings
|
||||
are appended to the testgen context when a class inherits from an external library.
|
||||
"""
|
||||
code = """from collections import UserDict
|
||||
|
||||
class MyCustomDict(UserDict):
|
||||
def target_method(self):
|
||||
return self.data
|
||||
"""
|
||||
file_path = tmp_path / "test_code.py"
|
||||
code = "from mypkg.base import BaseDict\n\nclass MyCustomDict(BaseDict):\n def target_method(self):\n return self.data\n"
|
||||
file_path = package_dir / "test_code.py"
|
||||
file_path.write_text(code, encoding="utf-8")
|
||||
|
||||
func_to_optimize = FunctionToOptimize(
|
||||
|
|
@ -3721,11 +3713,10 @@ class MyCustomDict(UserDict):
|
|||
|
||||
code_ctx = get_code_optimization_context(function_to_optimize=func_to_optimize, project_root_path=tmp_path)
|
||||
|
||||
# The testgen context should include the UserDict __init__ method
|
||||
testgen_context = code_ctx.testgen_context.markdown
|
||||
assert "class UserDict" in testgen_context, "UserDict class should be in testgen context"
|
||||
assert "def __init__" in testgen_context, "UserDict __init__ should be in testgen context"
|
||||
assert "self.data = {}" in testgen_context, "UserDict __init__ body should be included"
|
||||
assert "class BaseDict" in testgen_context, "BaseDict class should be in testgen context"
|
||||
assert "def __init__" in testgen_context, "BaseDict __init__ should be in testgen context"
|
||||
assert "self.data" in testgen_context, "BaseDict __init__ body should be included"
|
||||
|
||||
|
||||
def test_testgen_raises_when_exceeds_limit(tmp_path: Path) -> None:
|
||||
|
|
@ -3756,26 +3747,24 @@ def target_function():
|
|||
|
||||
|
||||
def test_enrich_testgen_context_attribute_base(tmp_path: Path) -> None:
|
||||
"""Test handling of base class accessed as module.ClassName (ast.Attribute).
|
||||
"""Test handling of base class in a project module."""
|
||||
package_dir = tmp_path / "mypkg"
|
||||
package_dir.mkdir()
|
||||
(package_dir / "__init__.py").write_text("", encoding="utf-8")
|
||||
(package_dir / "base.py").write_text(
|
||||
"class CustomDict:\n def __init__(self, data=None):\n self.data = data or {}\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
This covers line 616 in code_context_extractor.py.
|
||||
"""
|
||||
# Use the standard import style which the code actually handles
|
||||
code = """from collections import UserDict
|
||||
|
||||
class MyDict(UserDict):
|
||||
def custom_method(self):
|
||||
return self.data
|
||||
"""
|
||||
code_path = tmp_path / "mydict.py"
|
||||
code = "from mypkg.base import CustomDict\n\nclass MyDict(CustomDict):\n def custom_method(self):\n return self.data\n"
|
||||
code_path = package_dir / "mydict.py"
|
||||
code_path.write_text(code, encoding="utf-8")
|
||||
|
||||
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
|
||||
result = enrich_testgen_context(context, tmp_path)
|
||||
|
||||
# Should extract UserDict
|
||||
assert len(result.code_strings) == 1
|
||||
assert "class UserDict" in result.code_strings[0].code
|
||||
assert "class CustomDict" in result.code_strings[0].code
|
||||
assert "def __init__" in result.code_strings[0].code
|
||||
|
||||
|
||||
|
|
@ -4026,8 +4015,8 @@ def my_func() -> None:
|
|||
assert result.code_strings == []
|
||||
|
||||
|
||||
def test_enrich_testgen_context_skips_object_init(tmp_path: Path) -> None:
|
||||
"""QName has a real class definition in stdlib source, so it gets extracted."""
|
||||
def test_enrich_testgen_context_skips_stdlib(tmp_path: Path) -> None:
|
||||
"""Skips stdlib classes like QName."""
|
||||
code = """from xml.etree.ElementTree import QName
|
||||
|
||||
def my_func(q: QName) -> None:
|
||||
|
|
@ -4039,9 +4028,7 @@ def my_func(q: QName) -> None:
|
|||
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
|
||||
result = enrich_testgen_context(context, tmp_path)
|
||||
|
||||
# QName has its own class definition in ElementTree source
|
||||
assert len(result.code_strings) == 1
|
||||
assert "class QName" in result.code_strings[0].code
|
||||
assert result.code_strings == [], "Should not extract stdlib classes"
|
||||
|
||||
|
||||
def test_enrich_testgen_context_empty_when_no_imports(tmp_path: Path) -> None:
|
||||
|
|
|
|||
Loading…
Reference in a new issue