feat: extend testgen type context to include function body references

Extract types referenced in the function body (constructor calls, attribute
access, isinstance/issubclass args) in addition to parameter annotations.
Use full class extraction instead of init-stub-only, with instance resolution
fallback and project/site-packages filtering.
This commit is contained in:
Kevin Turcios 2026-02-18 13:41:50 -05:00
parent 2364096cc1
commit 2966e15775
4 changed files with 628 additions and 31 deletions

View file

@ -1,2 +1,3 @@
# Managed by Tessl
tessl:*
tessl__*

View file

@ -1,2 +1,3 @@
# Managed by Tessl
tessl:*
tessl__*

View file

@ -68,15 +68,26 @@ def build_testgen_context(
if enrichment.code_strings:
testgen_context = CodeStringsMarkdown(code_strings=testgen_context.code_strings + enrichment.code_strings)
type_context_strings: list[CodeString] = []
if function_to_optimize is not None:
result = _parse_and_collect_imports(testgen_context)
existing_classes = collect_existing_class_names(result[0]) if result else set()
constructor_stubs = extract_parameter_type_constructors(
function_to_optimize, project_root_path, existing_classes
)
if constructor_stubs.code_strings:
type_context = extract_type_context_for_testgen(function_to_optimize, project_root_path, existing_classes)
if type_context.code_strings:
testgen_context = CodeStringsMarkdown(
code_strings=testgen_context.code_strings + constructor_stubs.code_strings
code_strings=testgen_context.code_strings + type_context.code_strings
)
type_context_strings = type_context.code_strings
# Enrich field types from all newly extracted classes (enrichment + type context)
new_classes = CodeStringsMarkdown(code_strings=enrichment.code_strings + type_context_strings)
if new_classes.code_strings:
updated_result = _parse_and_collect_imports(testgen_context)
updated_existing = collect_existing_class_names(updated_result[0]) if updated_result else set()
field_type_enrichment = enrich_type_context_classes(new_classes, updated_existing, project_root_path)
if field_type_enrichment.code_strings:
testgen_context = CodeStringsMarkdown(
code_strings=testgen_context.code_strings + field_type_enrichment.code_strings
)
return testgen_context
@ -744,6 +755,48 @@ def collect_type_names_from_annotation(node: ast.expr | None) -> set[str]:
return set()
def collect_names_from_function_body(func_node: ast.FunctionDef | ast.AsyncFunctionDef) -> set[str]:
names: set[str] = set()
for node in ast.walk(func_node):
if isinstance(node, ast.Call):
if isinstance(node.func, ast.Name):
names.add(node.func.id)
elif isinstance(node.func, ast.Attribute) and isinstance(node.func.value, ast.Name):
names.add(node.func.value.id)
if isinstance(node.func, ast.Name) and node.func.id in ("isinstance", "issubclass") and node.args:
second_arg = node.args[1] if len(node.args) > 1 else None
if isinstance(second_arg, ast.Name):
names.add(second_arg.id)
elif isinstance(second_arg, ast.Tuple):
for elt in second_arg.elts:
if isinstance(elt, ast.Name):
names.add(elt.id)
elif isinstance(node, ast.Attribute) and isinstance(node.value, ast.Name):
names.add(node.value.id)
return names
def extract_full_class_from_module(class_name: str, module_source: str, module_tree: ast.Module) -> str | None:
class_node = None
# Use a deque-based BFS to find the first matching ClassDef (preserves ast.walk order)
q: deque[ast.AST] = deque([module_tree])
while q:
candidate = q.popleft()
if isinstance(candidate, ast.ClassDef) and candidate.name == class_name:
class_node = candidate
break
q.extend(ast.iter_child_nodes(candidate))
if class_node is None:
return None
lines = module_source.split("\n")
start_line = class_node.lineno
if class_node.decorator_list:
start_line = min(d.lineno for d in class_node.decorator_list)
return "\n".join(lines[start_line - 1 : class_node.end_lineno])
def extract_init_stub_from_class(class_name: str, module_source: str, module_tree: ast.Module) -> str | None:
class_node = None
# Use a deque-based BFS to find the first matching ClassDef (preserves ast.walk order)
@ -793,7 +846,7 @@ def extract_init_stub_from_class(class_name: str, module_source: str, module_tre
return f"class {class_name}:\n" + "\n".join(snippets)
def extract_parameter_type_constructors(
def extract_type_context_for_testgen(
function_to_optimize: FunctionToOptimize, project_root_path: Path, existing_class_names: set[str]
) -> CodeStringsMarkdown:
import jedi
@ -825,6 +878,8 @@ def extract_parameter_type_constructors(
if func_node.args.kwarg:
type_names |= collect_type_names_from_annotation(func_node.args.kwarg.annotation)
type_names |= collect_names_from_function_body(func_node)
type_names -= BUILTIN_AND_TYPING_NAMES
type_names -= existing_class_names
if not type_names:
@ -855,6 +910,13 @@ def extract_parameter_type_constructors(
if not module_path:
continue
resolved_module = module_path.resolve()
module_str = str(resolved_module)
is_project = module_str.startswith(str(project_root_path.resolve()))
is_third_party = "site-packages" in module_str
if not is_project and not is_third_party:
continue
if module_path in module_cache:
mod_source, mod_tree = module_cache[module_path]
else:
@ -862,11 +924,15 @@ def extract_parameter_type_constructors(
mod_tree = ast.parse(mod_source)
module_cache[module_path] = (mod_source, mod_tree)
stub = extract_init_stub_from_class(type_name, mod_source, mod_tree)
if stub:
code_strings.append(CodeString(code=stub, file_path=module_path))
class_source = extract_full_class_from_module(type_name, mod_source, mod_tree)
if class_source is None:
resolved_class = resolve_instance_class_name(type_name, mod_tree)
if resolved_class and resolved_class not in existing_class_names:
class_source = extract_full_class_from_module(resolved_class, mod_source, mod_tree)
if class_source:
code_strings.append(CodeString(code=class_source, file_path=module_path))
except Exception:
logger.debug(f"Error extracting constructor stub for {type_name} from {module_name}")
logger.debug(f"Error extracting type context for {type_name} from {module_name}")
continue
return CodeStringsMarkdown(code_strings=code_strings)
@ -912,6 +978,7 @@ def enrich_testgen_context(code_context: CodeStringsMarkdown, project_root_path:
# --- Step 1: Project class definitions (jedi resolution + recursive base extraction) ---
extracted_classes: set[tuple[Path, str]] = set()
module_cache: dict[Path, tuple[str, ast.Module]] = {}
module_import_maps: dict[Path, dict[str, str]] = {}
def get_module_source_and_tree(module_path: Path) -> tuple[str, ast.Module] | None:
if module_path in module_cache:
@ -925,10 +992,22 @@ def enrich_testgen_context(code_context: CodeStringsMarkdown, project_root_path:
module_cache[module_path] = (module_source, module_tree)
return module_source, module_tree
def get_module_import_map(module_path: Path, module_tree: ast.Module) -> dict[str, str]:
    """Map names bound by ``from X import ...`` statements to their source module.

    Results are memoized per module path in the enclosing ``module_import_maps``.
    """
    cached = module_import_maps.get(module_path)
    if cached is not None:
        return cached
    mapping: dict[str, str] = {}
    for node in ast.walk(module_tree):
        if not (isinstance(node, ast.ImportFrom) and node.module):
            continue
        for alias in node.names:
            # An alias shadows the original name at the import site.
            mapping[alias.asname or alias.name] = node.module
    module_import_maps[module_path] = mapping
    return mapping
def extract_class_and_bases(
class_name: str, module_path: Path, module_source: str, module_tree: ast.Module
class_name: str, module_path: Path, module_source: str, module_tree: ast.Module, depth: int = 0
) -> None:
if (module_path, class_name) in extracted_classes:
if depth >= 3 or (module_path, class_name) in extracted_classes:
return
class_node = None
@ -947,8 +1026,40 @@ def enrich_testgen_context(code_context: CodeStringsMarkdown, project_root_path:
elif isinstance(base, ast.Attribute):
continue
if base_name and base_name not in existing_classes:
extract_class_and_bases(base_name, module_path, module_source, module_tree)
if not base_name or base_name in existing_classes:
continue
# Check if the base class is defined locally in this module
local_base = any(isinstance(n, ast.ClassDef) and n.name == base_name for n in ast.walk(module_tree))
if local_base:
extract_class_and_bases(base_name, module_path, module_source, module_tree, depth + 1)
else:
# Resolve cross-module base class via imports
import_map = get_module_import_map(module_path, module_tree)
base_module_name = import_map.get(base_name)
if not base_module_name:
continue
try:
script_code = f"from {base_module_name} import {base_name}"
script = jedi.Script(script_code, project=jedi.Project(path=project_root_path))
definitions = script.goto(
1, len(f"from {base_module_name} import ") + len(base_name), follow_imports=True
)
if not definitions or not definitions[0].module_path:
continue
base_module_path = definitions[0].module_path
resolved_str = str(base_module_path.resolve())
is_project = resolved_str.startswith(str(project_root_path.resolve()) + os.sep)
is_third_party = "site-packages" in resolved_str
if not is_project and not is_third_party:
continue
base_mod = get_module_source_and_tree(base_module_path)
if base_mod is None:
continue
extract_class_and_bases(base_name, base_module_path, base_mod[0], base_mod[1], depth + 1)
except Exception:
logger.debug(f"Error resolving cross-module base class {base_name} from {base_module_name}")
continue
if (module_path, class_name) in extracted_classes:
return
@ -1006,6 +1117,96 @@ def enrich_testgen_context(code_context: CodeStringsMarkdown, project_root_path:
return CodeStringsMarkdown(code_strings=code_strings)
def enrich_type_context_classes(
    type_context: CodeStringsMarkdown, existing_class_names: set[str], project_root_path: Path
) -> CodeStringsMarkdown:
    """Extract class definitions for types used as field annotations in type-context snippets.

    For each class snippet in ``type_context``, collects the type names that appear
    in its ``AnnAssign`` field annotations, resolves each name via the defining
    module's ``from X import Y`` statements plus jedi, and extracts the full class
    source. Names that are builtins/typing constructs, already present in
    ``existing_class_names``, or already emitted are skipped. Only classes inside
    the project root or site-packages are extracted.
    """
    import jedi

    code_strings: list[CodeString] = []
    # Tracks class names already emitted across all snippets, to avoid duplicates.
    emitted: set[str] = set()
    # Cache of (source text, parsed tree) per module file read during resolution.
    module_cache: dict[Path, tuple[str, ast.Module]] = {}
    for cs in type_context.code_strings:
        try:
            snippet_tree = ast.parse(cs.code)
        except SyntaxError:
            continue
        # Collect type names from field annotations of extracted classes
        type_names: set[str] = set()
        for node in ast.walk(snippet_tree):
            if isinstance(node, ast.ClassDef):
                for item in node.body:
                    if isinstance(item, ast.AnnAssign) and item.annotation:
                        type_names |= collect_type_names_from_annotation(item.annotation)
        type_names -= BUILTIN_AND_TYPING_NAMES
        type_names -= existing_class_names
        type_names -= emitted
        if not type_names:
            continue
        # Build import map from the source file, not the snippet
        # (the snippet is just the extracted class body and carries no imports).
        source_path = cs.file_path
        if not source_path:
            continue
        import_map: dict[str, str] = {}
        if source_path in module_cache:
            source_code, source_tree = module_cache[source_path]
        else:
            try:
                source_code = source_path.read_text(encoding="utf-8")
                source_tree = ast.parse(source_code)
                module_cache[source_path] = (source_code, source_tree)
            except Exception:
                # Unreadable/unparseable source file: skip this snippet entirely.
                continue
        for snode in ast.walk(source_tree):
            if isinstance(snode, ast.ImportFrom) and snode.module:
                for alias in snode.names:
                    # An alias shadows the original name at the import site.
                    name = alias.asname if alias.asname else alias.name
                    import_map[name] = snode.module
        for type_name in sorted(type_names):
            module_name = import_map.get(type_name)
            if not module_name:
                # Name not imported in the source module — cannot resolve it.
                continue
            try:
                # Resolve the import target with jedi by synthesizing an import line.
                script_code = f"from {module_name} import {type_name}"
                script = jedi.Script(script_code, project=jedi.Project(path=project_root_path))
                definitions = script.goto(1, len(f"from {module_name} import ") + len(type_name), follow_imports=True)
                if not definitions or not definitions[0].module_path:
                    continue
                mod_path = definitions[0].module_path
                resolved_str = str(mod_path.resolve())
                # Restrict extraction to project code and installed third-party code.
                is_project = resolved_str.startswith(str(project_root_path.resolve()))
                is_third_party = "site-packages" in resolved_str
                if not is_project and not is_third_party:
                    continue
                if mod_path in module_cache:
                    mod_source, mod_tree = module_cache[mod_path]
                else:
                    mod_source = mod_path.read_text(encoding="utf-8")
                    mod_tree = ast.parse(mod_source)
                    module_cache[mod_path] = (mod_source, mod_tree)
                class_source = extract_full_class_from_module(type_name, mod_source, mod_tree)
                if class_source is None:
                    # The name may be a module-level instance; fall back to its class.
                    resolved_class = resolve_instance_class_name(type_name, mod_tree)
                    if resolved_class and resolved_class not in existing_class_names and resolved_class not in emitted:
                        class_source = extract_full_class_from_module(resolved_class, mod_source, mod_tree)
                        # Rebind so dedup below records the class actually emitted.
                        type_name = resolved_class
                if class_source:
                    code_strings.append(CodeString(code=class_source, file_path=mod_path))
                    emitted.add(type_name)
            except Exception:
                # Best-effort enrichment: any resolution failure is logged and skipped.
                logger.debug(f"Error extracting type context class for {type_name} from {module_name}")
                continue
    return CodeStringsMarkdown(code_strings=code_strings)
def remove_docstring_from_body(indented_block: cst.IndentedBlock) -> cst.CSTNode:
"""Removes the docstring from an indented block if it exists."""
if not isinstance(indented_block.body[0], cst.SimpleStatementLine):

View file

@ -13,10 +13,13 @@ from codeflash.code_utils.code_extractor import GlobalAssignmentCollector, add_g
from codeflash.code_utils.code_replacer import replace_functions_and_add_imports
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.python.context.code_context_extractor import (
collect_names_from_function_body,
collect_type_names_from_annotation,
enrich_testgen_context,
enrich_type_context_classes,
extract_full_class_from_module,
extract_init_stub_from_class,
extract_parameter_type_constructors,
extract_type_context_for_testgen,
get_code_optimization_context,
resolve_instance_class_name,
)
@ -3553,8 +3556,7 @@ def test_enrich_testgen_context_deduplicates(tmp_path: Path) -> None:
package_dir.mkdir()
(package_dir / "__init__.py").write_text("", encoding="utf-8")
(package_dir / "base.py").write_text(
"class Base:\n def __init__(self, x: int):\n self.x = x\n",
encoding="utf-8",
"class Base:\n def __init__(self, x: int):\n self.x = x\n", encoding="utf-8"
)
code = "from mypkg.base import Base\n\nclass A(Base):\n pass\n\nclass B(Base):\n pass\n"
@ -3697,8 +3699,7 @@ def test_testgen_context_includes_external_base_inits(tmp_path: Path) -> None:
package_dir.mkdir()
(package_dir / "__init__.py").write_text("", encoding="utf-8")
(package_dir / "base.py").write_text(
"class BaseDict:\n def __init__(self, data=None):\n self.data = data or {}\n",
encoding="utf-8",
"class BaseDict:\n def __init__(self, data=None):\n self.data = data or {}\n", encoding="utf-8"
)
code = "from mypkg.base import BaseDict\n\nclass MyCustomDict(BaseDict):\n def target_method(self):\n return self.data\n"
@ -3752,8 +3753,7 @@ def test_enrich_testgen_context_attribute_base(tmp_path: Path) -> None:
package_dir.mkdir()
(package_dir / "__init__.py").write_text("", encoding="utf-8")
(package_dir / "base.py").write_text(
"class CustomDict:\n def __init__(self, data=None):\n self.data = data or {}\n",
encoding="utf-8",
"class CustomDict:\n def __init__(self, data=None):\n self.data = data or {}\n", encoding="utf-8"
)
code = "from mypkg.base import CustomDict\n\nclass MyDict(CustomDict):\n def custom_method(self):\n return self.data\n"
@ -4208,10 +4208,10 @@ class Other:
assert stub is None
# --- Tests for extract_parameter_type_constructors ---
# --- Tests for extract_type_context_for_testgen ---
def test_extract_parameter_type_constructors_project_type(tmp_path: Path) -> None:
def test_extract_type_context_for_testgen_project_type(tmp_path: Path) -> None:
# Create a module with a class
pkg = tmp_path / "mypkg"
pkg.mkdir()
@ -4239,7 +4239,7 @@ def process(w: Widget) -> str:
fto = FunctionToOptimize(
function_name="process", file_path=(pkg / "processor.py").resolve(), starting_line=3, ending_line=4
)
result = extract_parameter_type_constructors(fto, tmp_path.resolve(), set())
result = extract_type_context_for_testgen(fto, tmp_path.resolve(), set())
assert len(result.code_strings) == 1
code = result.code_strings[0].code
assert "class Widget:" in code
@ -4247,7 +4247,7 @@ def process(w: Widget) -> str:
assert "size" in code
def test_extract_parameter_type_constructors_excludes_builtins(tmp_path: Path) -> None:
def test_extract_type_context_for_testgen_excludes_builtins(tmp_path: Path) -> None:
pkg = tmp_path / "mypkg"
pkg.mkdir()
(pkg / "__init__.py").write_text("", encoding="utf-8")
@ -4262,11 +4262,11 @@ def my_func(x: int, y: str, z: list) -> None:
fto = FunctionToOptimize(
function_name="my_func", file_path=(pkg / "func.py").resolve(), starting_line=2, ending_line=3
)
result = extract_parameter_type_constructors(fto, tmp_path.resolve(), set())
result = extract_type_context_for_testgen(fto, tmp_path.resolve(), set())
assert len(result.code_strings) == 0
def test_extract_parameter_type_constructors_skips_existing_classes(tmp_path: Path) -> None:
def test_extract_type_context_for_testgen_skips_existing_classes(tmp_path: Path) -> None:
pkg = tmp_path / "mypkg"
pkg.mkdir()
(pkg / "__init__.py").write_text("", encoding="utf-8")
@ -4291,11 +4291,11 @@ def process(w: Widget) -> str:
function_name="process", file_path=(pkg / "processor.py").resolve(), starting_line=3, ending_line=4
)
# Widget is already in the context — should not be duplicated
result = extract_parameter_type_constructors(fto, tmp_path.resolve(), {"Widget"})
result = extract_type_context_for_testgen(fto, tmp_path.resolve(), {"Widget"})
assert len(result.code_strings) == 0
def test_extract_parameter_type_constructors_no_init(tmp_path: Path) -> None:
def test_extract_type_context_for_testgen_no_init(tmp_path: Path) -> None:
pkg = tmp_path / "mypkg"
pkg.mkdir()
(pkg / "__init__.py").write_text("", encoding="utf-8")
@ -4318,8 +4318,11 @@ def process(c: Config) -> str:
fto = FunctionToOptimize(
function_name="process", file_path=(pkg / "processor.py").resolve(), starting_line=3, ending_line=4
)
result = extract_parameter_type_constructors(fto, tmp_path.resolve(), set())
assert len(result.code_strings) == 0
result = extract_type_context_for_testgen(fto, tmp_path.resolve(), set())
assert len(result.code_strings) == 1
code = result.code_strings[0].code
assert "class Config:" in code
assert "x = 10" in code
# --- Tests for resolve_instance_class_name ---
@ -4444,3 +4447,394 @@ def get_log_level() -> str:
combined = "\n".join(cs.code for cs in result.code_strings)
assert "class AppConfig:" in combined
assert "@property" in combined
# --- Tests for collect_names_from_function_body ---
def test_collect_names_from_function_body_constructor_calls() -> None:
    """Direct constructor-style calls in the body are collected by name."""
    source = "def f():\n    x = MyClass(1, 2)\n    y = OtherClass()\n"
    func_node = ast.parse(source).body[0]
    names = collect_names_from_function_body(func_node)
    assert {"MyClass", "OtherClass"} <= names
def test_collect_names_from_function_body_attribute_access() -> None:
    """Attribute access on a bare name yields the base identifier."""
    source = "def f():\n    val = config.PROP\n    x = MyEnum.VALUE\n"
    func_node = ast.parse(source).body[0]
    names = collect_names_from_function_body(func_node)
    assert {"config", "MyEnum"} <= names
def test_collect_names_from_function_body_isinstance() -> None:
    """Second arguments of isinstance/issubclass (including tuples) are collected."""
    source = (
        "def f(x):\n"
        "    if isinstance(x, SomeClass):\n"
        "        pass\n"
        "    if issubclass(type(x), (A, B)):\n"
        "        pass\n"
    )
    func_node = ast.parse(source).body[0]
    names = collect_names_from_function_body(func_node)
    assert {"SomeClass", "A", "B"} <= names
def test_collect_names_from_function_body_class_method_call() -> None:
    """Classmethod-style calls (``Builder.create``) yield the class name."""
    source = "def f():\n    obj = Builder.create(42)\n"
    func_node = ast.parse(source).body[0]
    assert "Builder" in collect_names_from_function_body(func_node)
# --- Tests for extract_full_class_from_module ---
def test_extract_full_class_from_module_basic() -> None:
    """The entire class body, including methods, is extracted."""
    source = "class Foo:\n    x = 10\n    def bar(self):\n        return self.x\n"
    tree = ast.parse(source)
    result = extract_full_class_from_module("Foo", source, tree)
    assert result is not None
    for fragment in ("class Foo:", "x = 10", "def bar(self):"):
        assert fragment in result
def test_extract_full_class_from_module_missing_class() -> None:
    """Requesting a class that is not defined returns None."""
    source = "class Foo:\n    pass\n"
    tree = ast.parse(source)
    assert extract_full_class_from_module("Bar", source, tree) is None
# --- Integration tests for extract_type_context_for_testgen with body references ---
def test_extract_type_context_for_testgen_body_enum_reference(tmp_path: Path) -> None:
    """An enum referenced only in the function body (not in annotations) is extracted."""
    pkg = tmp_path / "mypkg"
    pkg.mkdir()
    (pkg / "__init__.py").write_text("", encoding="utf-8")
    # Color never appears in an annotation — only via Color.RED inside the body.
    (pkg / "enums.py").write_text(
        """\
import enum
class Color(enum.Enum):
    RED = "red"
    BLUE = "blue"
""",
        encoding="utf-8",
    )
    (pkg / "processor.py").write_text(
        """\
from mypkg.enums import Color

def paint(name: str) -> str:
    return name + Color.RED.value
""",
        encoding="utf-8",
    )
    # starting/ending lines point at the `paint` definition in processor.py.
    fto = FunctionToOptimize(
        function_name="paint", file_path=(pkg / "processor.py").resolve(), starting_line=3, ending_line=4
    )
    result = extract_type_context_for_testgen(fto, tmp_path.resolve(), set())
    assert len(result.code_strings) == 1
    code = result.code_strings[0].code
    # Full class body (not just an __init__ stub) should be present.
    assert "class Color" in code
    assert "RED" in code
def test_extract_type_context_for_testgen_body_config_class(tmp_path: Path) -> None:
    """A config class read only via class attributes in the body is extracted."""
    pkg = tmp_path / "mypkg"
    pkg.mkdir()
    (pkg / "__init__.py").write_text("", encoding="utf-8")
    (pkg / "config.py").write_text(
        """\
class AppConfig:
    DEBUG = True
    LOG_LEVEL = "INFO"
""",
        encoding="utf-8",
    )
    # AppConfig is only touched via attribute access — no annotation anywhere.
    (pkg / "service.py").write_text(
        """\
from mypkg.config import AppConfig

def run() -> str:
    return AppConfig.LOG_LEVEL
""",
        encoding="utf-8",
    )
    fto = FunctionToOptimize(
        function_name="run", file_path=(pkg / "service.py").resolve(), starting_line=3, ending_line=4
    )
    result = extract_type_context_for_testgen(fto, tmp_path.resolve(), set())
    assert len(result.code_strings) == 1
    code = result.code_strings[0].code
    assert "class AppConfig:" in code
    assert "LOG_LEVEL" in code
def test_extract_type_context_for_testgen_instance_resolution(tmp_path: Path) -> None:
    """A module-level instance name is resolved back to its class definition."""
    pkg = tmp_path / "mypkg"
    pkg.mkdir()
    (pkg / "__init__.py").write_text("", encoding="utf-8")
    # app_config is an *instance*; the extractor must resolve it to AppConfig.
    (pkg / "config.py").write_text(
        """\
class AppConfig:
    DEBUG = True
    LOG_LEVEL = "INFO"
app_config = AppConfig()
""",
        encoding="utf-8",
    )
    (pkg / "service.py").write_text(
        """\
from mypkg.config import app_config

def run() -> str:
    return app_config.LOG_LEVEL
""",
        encoding="utf-8",
    )
    fto = FunctionToOptimize(
        function_name="run", file_path=(pkg / "service.py").resolve(), starting_line=3, ending_line=4
    )
    result = extract_type_context_for_testgen(fto, tmp_path.resolve(), set())
    assert len(result.code_strings) == 1
    code = result.code_strings[0].code
    # The class behind the instance, not the instance itself, is extracted.
    assert "class AppConfig:" in code
def test_extract_type_context_for_testgen_non_imported_names_filtered(tmp_path: Path) -> None:
    """Names with no matching import (undefined/local) produce no type context."""
    pkg = tmp_path / "mypkg"
    pkg.mkdir()
    (pkg / "__init__.py").write_text("", encoding="utf-8")
    # SomeLocal is not imported anywhere, so it cannot be resolved.
    (pkg / "processor.py").write_text(
        """\
def process(x: int) -> int:
    result = SomeLocal(x)
    return result
""",
        encoding="utf-8",
    )
    fto = FunctionToOptimize(
        function_name="process", file_path=(pkg / "processor.py").resolve(), starting_line=1, ending_line=3
    )
    result = extract_type_context_for_testgen(fto, tmp_path.resolve(), set())
    assert len(result.code_strings) == 0
def test_enrich_testgen_context_cross_module_base_class(tmp_path: Path) -> None:
    """Base classes from other modules are resolved and extracted."""
    pkg = tmp_path / "mypkg"
    pkg.mkdir()
    (pkg / "__init__.py").write_text("", encoding="utf-8")
    (pkg / "base.py").write_text(
        """\
class Base:
    def __init__(self, x: int):
        self.x = x
    def do_stuff(self) -> int:
        return self.x * 2
""",
        encoding="utf-8",
    )
    # Child inherits from Base across a module boundary.
    (pkg / "child.py").write_text(
        """\
from mypkg.base import Base
class Child(Base):
    def __init__(self, x: int, y: int):
        super().__init__(x)
        self.y = y
""",
        encoding="utf-8",
    )
    context = CodeStringsMarkdown(
        code_strings=[
            CodeString(
                code="""\
from mypkg.child import Child
def process(c: Child) -> int:
    return c.do_stuff()
""",
                file_path=pkg / "user.py",
            )
        ]
    )
    result = enrich_testgen_context(context, tmp_path)
    all_code = "\n".join(cs.code for cs in result.code_strings)
    # Both the referenced class and its cross-module base must be present.
    assert "class Child(Base)" in all_code
    assert "class Base:" in all_code
def test_enrich_testgen_context_cross_module_base_depth_limit(tmp_path: Path) -> None:
    """Deep inheritance chains across modules stop at the depth limit without crashing."""
    pkg = tmp_path / "mypkg"
    pkg.mkdir()
    (pkg / "__init__.py").write_text("", encoding="utf-8")
    # Create a chain of 6 modules: Level0 -> Level1 -> ... -> Level5
    for i in range(6):
        if i == 0:
            code = f"class Level{i}:\n    pass\n"
        else:
            code = f"from mypkg.level{i - 1} import Level{i - 1}\n\nclass Level{i}(Level{i - 1}):\n    pass\n"
        (pkg / f"level{i}.py").write_text(code, encoding="utf-8")
    context = CodeStringsMarkdown(
        code_strings=[
            CodeString(
                code="from mypkg.level5 import Level5\n\ndef process(obj: Level5) -> None:\n    pass\n",
                file_path=pkg / "user.py",
            )
        ]
    )
    # Should not crash and should extract Level5 plus some ancestors (limited by depth=3)
    result = enrich_testgen_context(context, tmp_path)
    all_code = "\n".join(cs.code for cs in result.code_strings)
    assert "class Level5" in all_code
    # Depth limit prevents extracting the full chain — Level0 through Level2 may not appear
    # but Level3+ should be present
    assert "class Level3" in all_code or "class Level4" in all_code
def test_enrich_type_context_classes_extracts_field_types(tmp_path: Path) -> None:
    """Classes referenced as field types in type-context classes are extracted."""
    pkg = tmp_path / "mypkg"
    pkg.mkdir()
    (pkg / "__init__.py").write_text("", encoding="utf-8")
    (pkg / "source.py").write_text(
        """\
class Source:
    def __init__(self, name: str):
        self.name = name
""",
        encoding="utf-8",
    )
    # regions.py imports Source and uses it as a field annotation.
    (pkg / "regions.py").write_text(
        """\
from mypkg.source import Source
class TextRegions:
    sources: list[Source]
    label: str
    def __init__(self, sources: list[Source], label: str):
        self.sources = sources
        self.label = label
""",
        encoding="utf-8",
    )
    # Snippet mimics extractor output: class body only, file_path at the
    # defining module so imports can be resolved from the real source file.
    type_context = CodeStringsMarkdown(
        code_strings=[
            CodeString(
                code="""\
class TextRegions:
    sources: list[Source]
    label: str
    def __init__(self, sources: list[Source], label: str):
        self.sources = sources
        self.label = label
""",
                file_path=pkg / "regions.py",
            )
        ]
    )
    result = enrich_type_context_classes(type_context, {"TextRegions"}, tmp_path)
    assert len(result.code_strings) == 1
    assert "class Source:" in result.code_strings[0].code
def test_enrich_type_context_classes_skips_existing_and_builtins(tmp_path: Path) -> None:
    """Builtins and already-existing class names are not re-extracted."""
    pkg = tmp_path / "mypkg"
    pkg.mkdir()
    (pkg / "__init__.py").write_text("", encoding="utf-8")
    # Fields mix builtins (int, str) with one project type (Other).
    (pkg / "models.py").write_text(
        """\
from mypkg.other import Other
class MyClass:
    x: int
    y: str
    z: Other
""",
        encoding="utf-8",
    )
    (pkg / "other.py").write_text(
        """\
class Other:
    def __init__(self):
        pass
""",
        encoding="utf-8",
    )
    type_context = CodeStringsMarkdown(
        code_strings=[
            CodeString(
                code="""\
class MyClass:
    x: int
    y: str
    z: Other
""",
                file_path=pkg / "models.py",
            )
        ]
    )
    # Other is already existing — should not be extracted
    result = enrich_type_context_classes(type_context, {"MyClass", "Other"}, tmp_path)
    assert len(result.code_strings) == 0
    # Without Other in existing_class_names, it should be extracted
    result = enrich_type_context_classes(type_context, {"MyClass"}, tmp_path)
    assert len(result.code_strings) == 1
    assert "class Other:" in result.code_strings[0].code