context extraction imporvements

This commit is contained in:
Kevin Turcios 2026-02-18 09:09:36 -05:00
parent 7f5e163e38
commit b269212edd
2 changed files with 49 additions and 55 deletions

View file

@ -950,7 +950,7 @@ def enrich_testgen_context(code_context: CodeStringsMarkdown, project_root_path:
emitted_class_names.add(class_name)
for name, module_name in imported_names.items():
if name in existing_classes:
if name in existing_classes or module_name == "__future__":
continue
try:
test_code = f"import {module_name}"
@ -964,6 +964,13 @@ def enrich_testgen_context(code_context: CodeStringsMarkdown, project_root_path:
if not module_path:
continue
resolved_module = module_path.resolve()
module_str = str(resolved_module)
is_project = module_str.startswith(str(project_root_path.resolve()))
is_third_party = "site-packages" in module_str
if not is_project and not is_third_party:
continue
mod_result = get_module_source_and_tree(module_path)
if mod_result is None:
continue

View file

@ -3501,8 +3501,8 @@ class ConfigRegistry:
assert "model_list: list" in all_extracted_code, "Should include model_list field from Router"
def test_enrich_testgen_context_extracts_userdict(tmp_path: Path) -> None:
"""Extracts __init__ from collections.UserDict when a class inherits from it."""
def test_enrich_testgen_context_skips_stdlib_userdict(tmp_path: Path) -> None:
"""Skips stdlib classes like collections.UserDict."""
code = """from collections import UserDict
class MyCustomDict(UserDict):
@ -3514,14 +3514,7 @@ class MyCustomDict(UserDict):
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
result = enrich_testgen_context(context, tmp_path)
assert len(result.code_strings) == 1
code_string = result.code_strings[0]
assert "class UserDict" in code_string.code
assert "def __init__" in code_string.code
assert "self.data = {}" in code_string.code
assert code_string.file_path is not None
assert code_string.file_path.as_posix().endswith("collections/__init__.py")
assert len(result.code_strings) == 0, "Should not extract stdlib classes"
def test_enrich_testgen_context_skips_unresolvable_base_classes(tmp_path: Path) -> None:
@ -3555,24 +3548,24 @@ def test_enrich_testgen_context_skips_builtin_base_classes(tmp_path: Path) -> No
def test_enrich_testgen_context_deduplicates(tmp_path: Path) -> None:
"""Extracts the same external base class only once even when inherited multiple times."""
code = """from collections import UserDict
"""Extracts the same project class only once even when imported multiple times."""
package_dir = tmp_path / "mypkg"
package_dir.mkdir()
(package_dir / "__init__.py").write_text("", encoding="utf-8")
(package_dir / "base.py").write_text(
"class Base:\n def __init__(self, x: int):\n self.x = x\n",
encoding="utf-8",
)
class MyDict1(UserDict):
pass
class MyDict2(UserDict):
pass
"""
code_path = tmp_path / "mydicts.py"
code = "from mypkg.base import Base\n\nclass A(Base):\n pass\n\nclass B(Base):\n pass\n"
code_path = package_dir / "children.py"
code_path.write_text(code, encoding="utf-8")
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
result = enrich_testgen_context(context, tmp_path)
assert len(result.code_strings) == 1
assert "class UserDict" in result.code_strings[0].code
assert "def __init__" in result.code_strings[0].code
assert "class Base" in result.code_strings[0].code
def test_enrich_testgen_context_empty_when_no_inheritance(tmp_path: Path) -> None:
@ -3699,18 +3692,17 @@ def reify_channel_message(data: dict) -> MessageIn:
def test_testgen_context_includes_external_base_inits(tmp_path: Path) -> None:
"""Test that external base class __init__ methods are included in testgen context.
"""Test that base class definitions from project modules are included in testgen context."""
package_dir = tmp_path / "mypkg"
package_dir.mkdir()
(package_dir / "__init__.py").write_text("", encoding="utf-8")
(package_dir / "base.py").write_text(
"class BaseDict:\n def __init__(self, data=None):\n self.data = data or {}\n",
encoding="utf-8",
)
This covers line 65 in code_context_extractor.py where external_base_inits.code_strings
are appended to the testgen context when a class inherits from an external library.
"""
code = """from collections import UserDict
class MyCustomDict(UserDict):
def target_method(self):
return self.data
"""
file_path = tmp_path / "test_code.py"
code = "from mypkg.base import BaseDict\n\nclass MyCustomDict(BaseDict):\n def target_method(self):\n return self.data\n"
file_path = package_dir / "test_code.py"
file_path.write_text(code, encoding="utf-8")
func_to_optimize = FunctionToOptimize(
@ -3721,11 +3713,10 @@ class MyCustomDict(UserDict):
code_ctx = get_code_optimization_context(function_to_optimize=func_to_optimize, project_root_path=tmp_path)
# The testgen context should include the UserDict __init__ method
testgen_context = code_ctx.testgen_context.markdown
assert "class UserDict" in testgen_context, "UserDict class should be in testgen context"
assert "def __init__" in testgen_context, "UserDict __init__ should be in testgen context"
assert "self.data = {}" in testgen_context, "UserDict __init__ body should be included"
assert "class BaseDict" in testgen_context, "BaseDict class should be in testgen context"
assert "def __init__" in testgen_context, "BaseDict __init__ should be in testgen context"
assert "self.data" in testgen_context, "BaseDict __init__ body should be included"
def test_testgen_raises_when_exceeds_limit(tmp_path: Path) -> None:
@ -3756,26 +3747,24 @@ def target_function():
def test_enrich_testgen_context_attribute_base(tmp_path: Path) -> None:
"""Test handling of base class accessed as module.ClassName (ast.Attribute).
"""Test handling of base class in a project module."""
package_dir = tmp_path / "mypkg"
package_dir.mkdir()
(package_dir / "__init__.py").write_text("", encoding="utf-8")
(package_dir / "base.py").write_text(
"class CustomDict:\n def __init__(self, data=None):\n self.data = data or {}\n",
encoding="utf-8",
)
This covers line 616 in code_context_extractor.py.
"""
# Use the standard import style which the code actually handles
code = """from collections import UserDict
class MyDict(UserDict):
def custom_method(self):
return self.data
"""
code_path = tmp_path / "mydict.py"
code = "from mypkg.base import CustomDict\n\nclass MyDict(CustomDict):\n def custom_method(self):\n return self.data\n"
code_path = package_dir / "mydict.py"
code_path.write_text(code, encoding="utf-8")
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
result = enrich_testgen_context(context, tmp_path)
# Should extract UserDict
assert len(result.code_strings) == 1
assert "class UserDict" in result.code_strings[0].code
assert "class CustomDict" in result.code_strings[0].code
assert "def __init__" in result.code_strings[0].code
@ -4026,8 +4015,8 @@ def my_func() -> None:
assert result.code_strings == []
def test_enrich_testgen_context_skips_object_init(tmp_path: Path) -> None:
"""QName has a real class definition in stdlib source, so it gets extracted."""
def test_enrich_testgen_context_skips_stdlib(tmp_path: Path) -> None:
"""Skips stdlib classes like QName."""
code = """from xml.etree.ElementTree import QName
def my_func(q: QName) -> None:
@ -4039,9 +4028,7 @@ def my_func(q: QName) -> None:
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
result = enrich_testgen_context(context, tmp_path)
# QName has its own class definition in ElementTree source
assert len(result.code_strings) == 1
assert "class QName" in result.code_strings[0].code
assert result.code_strings == [], "Should not extract stdlib classes"
def test_enrich_testgen_context_empty_when_no_imports(tmp_path: Path) -> None: