mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
refactor: simplify and deduplicate code_context_extractor
Consolidate three enricher functions (get_imported_class_definitions, get_external_base_class_inits, get_external_class_inits) into a single enrich_testgen_context that parses code context once. Extract shared helpers, unify prune_cst variants, deduplicate loop bodies, and remove dead UsedNameCollector class.
This commit is contained in:
parent
d578d9969b
commit
fa00422fea
5 changed files with 479 additions and 1304 deletions
2
.gitignore
vendored
2
.gitignore
vendored
|
|
@ -268,3 +268,5 @@ tessl.json
|
|||
|
||||
# Tessl auto-generates AGENTS.md on install; ignore to avoid cluttering git status
|
||||
AGENTS.md
|
||||
.serena/
|
||||
.codeflash/
|
||||
|
|
|
|||
|
|
@ -4,8 +4,8 @@ from enum import Enum
|
|||
from typing import Any, Union
|
||||
|
||||
MAX_TEST_RUN_ITERATIONS = 5
|
||||
OPTIMIZATION_CONTEXT_TOKEN_LIMIT = 16000
|
||||
TESTGEN_CONTEXT_TOKEN_LIMIT = 16000
|
||||
OPTIMIZATION_CONTEXT_TOKEN_LIMIT = 48000
|
||||
TESTGEN_CONTEXT_TOKEN_LIMIT = 48000
|
||||
INDIVIDUAL_TESTCASE_TIMEOUT = 15
|
||||
MAX_FUNCTION_TEST_SECONDS = 60
|
||||
MIN_IMPROVEMENT_THRESHOLD = 0.05
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -34,7 +34,7 @@ if TYPE_CHECKING:
|
|||
from codeflash.languages.base import LanguageSupport
|
||||
|
||||
# Module-level singleton for the current language
|
||||
_current_language: Language | None = None
|
||||
_current_language: Language = Language.PYTHON
|
||||
|
||||
|
||||
def current_language() -> Language:
|
||||
|
|
|
|||
|
|
@ -12,12 +12,10 @@ from codeflash.code_utils.code_extractor import GlobalAssignmentCollector, add_g
|
|||
from codeflash.code_utils.code_replacer import replace_functions_and_add_imports
|
||||
from codeflash.context.code_context_extractor import (
|
||||
collect_names_from_annotation,
|
||||
enrich_testgen_context,
|
||||
extract_classes_from_type_hint,
|
||||
extract_imports_for_class,
|
||||
get_code_optimization_context,
|
||||
get_external_base_class_inits,
|
||||
get_external_class_inits,
|
||||
get_imported_class_definitions,
|
||||
resolve_transitive_type_deps,
|
||||
)
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
|
|
@ -769,199 +767,6 @@ class HelperClass:
|
|||
assert hashing_context.strip() == expected_hashing_context.strip()
|
||||
|
||||
|
||||
def test_example_class_token_limit_1(tmp_path: Path) -> None:
|
||||
docstring_filler = " ".join(
|
||||
["This is a long docstring that will be used to fill up the token limit." for _ in range(1000)]
|
||||
)
|
||||
code = f"""
|
||||
class MyClass:
|
||||
\"\"\"A class with a helper method.
|
||||
{docstring_filler}\"\"\"
|
||||
def __init__(self):
|
||||
self.x = 1
|
||||
def target_method(self):
|
||||
\"\"\"Docstring for target method\"\"\"
|
||||
y = HelperClass().helper_method()
|
||||
|
||||
class HelperClass:
|
||||
\"\"\"A helper class for MyClass.\"\"\"
|
||||
def __init__(self):
|
||||
\"\"\"Initialize the HelperClass.\"\"\"
|
||||
self.x = 1
|
||||
def __repr__(self):
|
||||
\"\"\"Return a string representation of the HelperClass.\"\"\"
|
||||
return "HelperClass" + str(self.x)
|
||||
def helper_method(self):
|
||||
return self.x
|
||||
"""
|
||||
# Create a temporary Python file using pytest's tmp_path fixture
|
||||
file_path = tmp_path / "test_code.py"
|
||||
file_path.write_text(code, encoding="utf-8")
|
||||
opt = Optimizer(
|
||||
Namespace(
|
||||
project_root=file_path.parent.resolve(),
|
||||
disable_telemetry=True,
|
||||
tests_root="tests",
|
||||
test_framework="pytest",
|
||||
pytest_cmd="pytest",
|
||||
experiment_id=None,
|
||||
test_project_root=Path().resolve(),
|
||||
)
|
||||
)
|
||||
function_to_optimize = FunctionToOptimize(
|
||||
function_name="target_method",
|
||||
file_path=file_path,
|
||||
parents=[FunctionParent(name="MyClass", type="ClassDef")],
|
||||
starting_line=None,
|
||||
ending_line=None,
|
||||
)
|
||||
|
||||
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
|
||||
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
|
||||
hashing_context = code_ctx.hashing_code_context
|
||||
# In this scenario, the read-only code context is too long, so the read-only docstrings are removed.
|
||||
expected_read_write_context = f"""
|
||||
```python:{file_path.relative_to(opt.args.project_root)}
|
||||
class MyClass:
|
||||
def __init__(self):
|
||||
self.x = 1
|
||||
def target_method(self):
|
||||
\"\"\"Docstring for target method\"\"\"
|
||||
y = HelperClass().helper_method()
|
||||
|
||||
class HelperClass:
|
||||
def __init__(self):
|
||||
\"\"\"Initialize the HelperClass.\"\"\"
|
||||
self.x = 1
|
||||
def helper_method(self):
|
||||
return self.x
|
||||
```
|
||||
"""
|
||||
expected_read_only_context = f"""
|
||||
```python:{file_path.relative_to(opt.args.project_root)}
|
||||
class MyClass:
|
||||
pass
|
||||
|
||||
class HelperClass:
|
||||
def __repr__(self):
|
||||
return "HelperClass" + str(self.x)
|
||||
```
|
||||
"""
|
||||
expected_hashing_context = f"""
|
||||
```python:{file_path.relative_to(opt.args.project_root)}
|
||||
class MyClass:
|
||||
|
||||
def target_method(self):
|
||||
y = HelperClass().helper_method()
|
||||
|
||||
class HelperClass:
|
||||
|
||||
def helper_method(self):
|
||||
return self.x
|
||||
```
|
||||
"""
|
||||
assert read_write_context.markdown.strip() == expected_read_write_context.strip()
|
||||
assert read_only_context.strip() == expected_read_only_context.strip()
|
||||
assert hashing_context.strip() == expected_hashing_context.strip()
|
||||
|
||||
|
||||
def test_example_class_token_limit_2(tmp_path: Path) -> None:
|
||||
string_filler = " ".join(
|
||||
["This is a long string that will be used to fill up the token limit." for _ in range(1000)]
|
||||
)
|
||||
code = f"""
|
||||
class MyClass:
|
||||
\"\"\"A class with a helper method. \"\"\"
|
||||
def __init__(self):
|
||||
self.x = 1
|
||||
def target_method(self):
|
||||
\"\"\"Docstring for target method\"\"\"
|
||||
y = HelperClass().helper_method()
|
||||
x = '{string_filler}'
|
||||
|
||||
class HelperClass:
|
||||
\"\"\"A helper class for MyClass.\"\"\"
|
||||
def __init__(self):
|
||||
\"\"\"Initialize the HelperClass.\"\"\"
|
||||
self.x = 1
|
||||
def __repr__(self):
|
||||
\"\"\"Return a string representation of the HelperClass.\"\"\"
|
||||
return "HelperClass" + str(self.x)
|
||||
def helper_method(self):
|
||||
return self.x
|
||||
"""
|
||||
# Create a temporary Python file using pytest's tmp_path fixture
|
||||
file_path = tmp_path / "test_code.py"
|
||||
file_path.write_text(code, encoding="utf-8")
|
||||
opt = Optimizer(
|
||||
Namespace(
|
||||
project_root=file_path.parent.resolve(),
|
||||
disable_telemetry=True,
|
||||
tests_root="tests",
|
||||
test_framework="pytest",
|
||||
pytest_cmd="pytest",
|
||||
experiment_id=None,
|
||||
test_project_root=Path().resolve(),
|
||||
)
|
||||
)
|
||||
function_to_optimize = FunctionToOptimize(
|
||||
function_name="target_method",
|
||||
file_path=file_path,
|
||||
parents=[FunctionParent(name="MyClass", type="ClassDef")],
|
||||
starting_line=None,
|
||||
ending_line=None,
|
||||
)
|
||||
|
||||
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root, 8000, 100000)
|
||||
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
|
||||
hashing_context = code_ctx.hashing_code_context
|
||||
# In this scenario, the read-only code context is too long even after removing docstrings, hence we remove it completely.
|
||||
expected_read_write_context = f"""
|
||||
```python:{file_path.relative_to(opt.args.project_root)}
|
||||
class MyClass:
|
||||
def __init__(self):
|
||||
self.x = 1
|
||||
def target_method(self):
|
||||
\"\"\"Docstring for target method\"\"\"
|
||||
y = HelperClass().helper_method()
|
||||
|
||||
class HelperClass:
|
||||
def __init__(self):
|
||||
\"\"\"Initialize the HelperClass.\"\"\"
|
||||
self.x = 1
|
||||
def helper_method(self):
|
||||
return self.x
|
||||
```
|
||||
"""
|
||||
expected_read_only_context = f'''```python:{file_path.relative_to(opt.args.project_root)}
|
||||
class MyClass:
|
||||
"""A class with a helper method. """
|
||||
|
||||
class HelperClass:
|
||||
"""A helper class for MyClass."""
|
||||
def __repr__(self):
|
||||
"""Return a string representation of the HelperClass."""
|
||||
return "HelperClass" + str(self.x)
|
||||
```
|
||||
'''
|
||||
expected_hashing_context = f"""
|
||||
```python:{file_path.relative_to(opt.args.project_root)}
|
||||
class MyClass:
|
||||
|
||||
def target_method(self):
|
||||
y = HelperClass().helper_method()
|
||||
|
||||
class HelperClass:
|
||||
|
||||
def helper_method(self):
|
||||
return self.x
|
||||
```
|
||||
"""
|
||||
assert read_write_context.markdown.strip() == expected_read_write_context.strip()
|
||||
assert read_only_context.strip() == expected_read_only_context.strip()
|
||||
assert hashing_context.strip() == expected_hashing_context.strip()
|
||||
|
||||
|
||||
def test_example_class_token_limit_3(tmp_path: Path) -> None:
|
||||
string_filler = " ".join(
|
||||
["This is a long string that will be used to fill up the token limit." for _ in range(1000)]
|
||||
|
|
@ -1009,7 +814,7 @@ class HelperClass:
|
|||
)
|
||||
# In this scenario, the read-writable code is too long, so we abort.
|
||||
with pytest.raises(ValueError, match="Read-writable code has exceeded token limit, cannot proceed"):
|
||||
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
|
||||
get_code_optimization_context(function_to_optimize, opt.args.project_root, optim_token_limit=8000)
|
||||
|
||||
|
||||
def test_example_class_token_limit_4(tmp_path: Path) -> None:
|
||||
|
|
@ -1062,7 +867,7 @@ class HelperClass:
|
|||
|
||||
# In this scenario, the read-writable code context becomes too large because the __init__ function is referencing the global x variable instead of the class attribute self.x, so we abort.
|
||||
with pytest.raises(ValueError, match="Read-writable code has exceeded token limit, cannot proceed"):
|
||||
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
|
||||
get_code_optimization_context(function_to_optimize, opt.args.project_root, optim_token_limit=8000)
|
||||
|
||||
|
||||
def test_example_class_token_limit_5(tmp_path: Path) -> None:
|
||||
|
|
@ -2422,7 +2227,7 @@ class OuterClass:
|
|||
assert "__init__" not in hashing_context # Should not contain __init__ methods
|
||||
|
||||
# Verify nested classes are excluded from the hashing context
|
||||
# The prune_cst_for_code_hashing function should not recurse into nested classes
|
||||
# The prune_cst function in hashing mode should not recurse into nested classes
|
||||
assert "class NestedClass:" not in hashing_context # Nested class definition should not be present
|
||||
|
||||
# The target method will reference NestedClass, but the actual nested class definition should not be included
|
||||
|
|
@ -3275,8 +3080,8 @@ def dump_layout(layout_type, layout):
|
|||
assert testgen_context.count("def __init__") >= 2, "Both __init__ methods should be in testgen context"
|
||||
|
||||
|
||||
def test_get_imported_class_definitions_extracts_project_classes(tmp_path: Path) -> None:
|
||||
"""Test that get_imported_class_definitions extracts class definitions from project modules."""
|
||||
def test_enrich_testgen_context_extracts_project_classes(tmp_path: Path) -> None:
|
||||
"""Test that enrich_testgen_context extracts class definitions from project modules."""
|
||||
# Create a package structure with two modules
|
||||
package_dir = tmp_path / "mypackage"
|
||||
package_dir.mkdir()
|
||||
|
|
@ -3325,8 +3130,8 @@ class Accumulator:
|
|||
# Create CodeStringsMarkdown from the chunking module (simulating testgen context)
|
||||
context = CodeStringsMarkdown(code_strings=[CodeString(code=chunking_code, file_path=chunking_path)])
|
||||
|
||||
# Call get_imported_class_definitions
|
||||
result = get_imported_class_definitions(context, tmp_path)
|
||||
# Call enrich_testgen_context
|
||||
result = enrich_testgen_context(context, tmp_path)
|
||||
|
||||
# Verify Element class was extracted
|
||||
assert len(result.code_strings) == 1, "Should extract exactly one class (Element)"
|
||||
|
|
@ -3339,8 +3144,8 @@ class Accumulator:
|
|||
assert "import abc" in extracted_code, "Should include necessary imports for base class"
|
||||
|
||||
|
||||
def test_get_imported_class_definitions_skips_existing_definitions(tmp_path: Path) -> None:
|
||||
"""Test that get_imported_class_definitions skips classes already defined in context."""
|
||||
def test_enrich_testgen_context_skips_existing_definitions(tmp_path: Path) -> None:
|
||||
"""Test that enrich_testgen_context skips classes already defined in context."""
|
||||
# Create a package structure
|
||||
package_dir = tmp_path / "mypackage"
|
||||
package_dir.mkdir()
|
||||
|
|
@ -3373,15 +3178,15 @@ class User:
|
|||
|
||||
context = CodeStringsMarkdown(code_strings=[CodeString(code=code_with_local_def, file_path=code_path)])
|
||||
|
||||
# Call get_imported_class_definitions
|
||||
result = get_imported_class_definitions(context, tmp_path)
|
||||
# Call enrich_testgen_context
|
||||
result = enrich_testgen_context(context, tmp_path)
|
||||
|
||||
# Should NOT extract Element since it's already defined locally
|
||||
assert len(result.code_strings) == 0, "Should not extract classes already defined in context"
|
||||
|
||||
|
||||
def test_get_imported_class_definitions_skips_third_party(tmp_path: Path) -> None:
|
||||
"""Test that get_imported_class_definitions skips third-party/stdlib imports."""
|
||||
def test_enrich_testgen_context_skips_third_party(tmp_path: Path) -> None:
|
||||
"""Test that enrich_testgen_context skips third-party/stdlib imports."""
|
||||
# Create a simple package
|
||||
package_dir = tmp_path / "mypackage"
|
||||
package_dir.mkdir()
|
||||
|
|
@ -3402,15 +3207,15 @@ class MyClass:
|
|||
|
||||
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
|
||||
|
||||
# Call get_imported_class_definitions
|
||||
result = get_imported_class_definitions(context, tmp_path)
|
||||
# Call enrich_testgen_context
|
||||
result = enrich_testgen_context(context, tmp_path)
|
||||
|
||||
# Should not extract any classes (Path, Optional, dataclass are stdlib/third-party)
|
||||
assert len(result.code_strings) == 0, "Should not extract stdlib/third-party classes"
|
||||
|
||||
|
||||
def test_get_imported_class_definitions_handles_multiple_imports(tmp_path: Path) -> None:
|
||||
"""Test that get_imported_class_definitions handles multiple class imports."""
|
||||
def test_enrich_testgen_context_handles_multiple_imports(tmp_path: Path) -> None:
|
||||
"""Test that enrich_testgen_context handles multiple class imports."""
|
||||
# Create a package structure
|
||||
package_dir = tmp_path / "mypackage"
|
||||
package_dir.mkdir()
|
||||
|
|
@ -3446,8 +3251,8 @@ class Processor:
|
|||
|
||||
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
|
||||
|
||||
# Call get_imported_class_definitions
|
||||
result = get_imported_class_definitions(context, tmp_path)
|
||||
# Call enrich_testgen_context
|
||||
result = enrich_testgen_context(context, tmp_path)
|
||||
|
||||
# Should extract both TypeA and TypeB (but not TypeC since it's not imported)
|
||||
assert len(result.code_strings) == 2, "Should extract exactly two classes (TypeA, TypeB)"
|
||||
|
|
@ -3458,8 +3263,8 @@ class Processor:
|
|||
assert "class TypeC" not in all_extracted_code, "Should NOT contain TypeC (not imported)"
|
||||
|
||||
|
||||
def test_get_imported_class_definitions_includes_dataclass_decorators(tmp_path: Path) -> None:
|
||||
"""Test that get_imported_class_definitions includes decorators when extracting dataclasses."""
|
||||
def test_enrich_testgen_context_includes_dataclass_decorators(tmp_path: Path) -> None:
|
||||
"""Test that enrich_testgen_context includes decorators when extracting dataclasses."""
|
||||
# Create a package structure
|
||||
package_dir = tmp_path / "mypackage"
|
||||
package_dir.mkdir()
|
||||
|
|
@ -3496,8 +3301,8 @@ class ConfigRegistry:
|
|||
|
||||
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
|
||||
|
||||
# Call get_imported_class_definitions
|
||||
result = get_imported_class_definitions(context, tmp_path)
|
||||
# Call enrich_testgen_context
|
||||
result = enrich_testgen_context(context, tmp_path)
|
||||
|
||||
# Should extract both LLMConfigBase (base class) and LLMConfig
|
||||
assert len(result.code_strings) == 2, "Should extract both LLMConfig and its base class LLMConfigBase"
|
||||
|
|
@ -3521,7 +3326,7 @@ class ConfigRegistry:
|
|||
assert "from dataclasses import" in all_extracted_code, "Should include dataclasses import"
|
||||
|
||||
|
||||
def test_get_imported_class_definitions_extracts_imports_for_decorated_classes(tmp_path: Path) -> None:
|
||||
def test_enrich_testgen_context_extracts_imports_for_decorated_classes(tmp_path: Path) -> None:
|
||||
"""Test that extract_imports_for_class includes decorator and type annotation imports."""
|
||||
# Create a package structure
|
||||
package_dir = tmp_path / "mypackage"
|
||||
|
|
@ -3552,7 +3357,7 @@ def create_config() -> Config:
|
|||
|
||||
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
|
||||
|
||||
result = get_imported_class_definitions(context, tmp_path)
|
||||
result = enrich_testgen_context(context, tmp_path)
|
||||
|
||||
assert len(result.code_strings) == 1, "Should extract Config class"
|
||||
extracted_code = result.code_strings[0].code
|
||||
|
|
@ -3724,7 +3529,7 @@ class MyClass:
|
|||
assert result.count("from typing import Optional") == 1
|
||||
|
||||
|
||||
def test_get_imported_class_definitions_multiple_decorators(tmp_path: Path) -> None:
|
||||
def test_enrich_testgen_context_multiple_decorators(tmp_path: Path) -> None:
|
||||
"""Test that classes with multiple decorators are extracted correctly."""
|
||||
package_dir = tmp_path / "mypackage"
|
||||
package_dir.mkdir()
|
||||
|
|
@ -3755,7 +3560,7 @@ def sort_configs(configs: list[OrderedConfig]) -> list[OrderedConfig]:
|
|||
|
||||
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
|
||||
|
||||
result = get_imported_class_definitions(context, tmp_path)
|
||||
result = enrich_testgen_context(context, tmp_path)
|
||||
|
||||
assert len(result.code_strings) == 1
|
||||
extracted_code = result.code_strings[0].code
|
||||
|
|
@ -3766,7 +3571,7 @@ def sort_configs(configs: list[OrderedConfig]) -> list[OrderedConfig]:
|
|||
assert "class OrderedConfig" in extracted_code
|
||||
|
||||
|
||||
def test_get_imported_class_definitions_extracts_multilevel_inheritance(tmp_path: Path) -> None:
|
||||
def test_enrich_testgen_context_extracts_multilevel_inheritance(tmp_path: Path) -> None:
|
||||
"""Test that base classes are recursively extracted for multi-level inheritance.
|
||||
|
||||
This is critical for understanding dataclass constructor signatures, as fields
|
||||
|
|
@ -3826,8 +3631,8 @@ class ConfigRegistry:
|
|||
|
||||
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
|
||||
|
||||
# Call get_imported_class_definitions
|
||||
result = get_imported_class_definitions(context, tmp_path)
|
||||
# Call enrich_testgen_context
|
||||
result = enrich_testgen_context(context, tmp_path)
|
||||
|
||||
# Should extract 4 classes: GrandParentConfig, ParentConfig, ChildConfig, RouterConfig
|
||||
# (all classes needed to understand the full inheritance hierarchy)
|
||||
|
|
@ -3862,7 +3667,7 @@ class ConfigRegistry:
|
|||
assert "model_list: list" in all_extracted_code, "Should include model_list field from Router"
|
||||
|
||||
|
||||
def test_get_external_base_class_inits_extracts_userdict(tmp_path: Path) -> None:
|
||||
def test_enrich_testgen_context_extracts_userdict(tmp_path: Path) -> None:
|
||||
"""Extracts __init__ from collections.UserDict when a class inherits from it."""
|
||||
code = """from collections import UserDict
|
||||
|
||||
|
|
@ -3873,7 +3678,7 @@ class MyCustomDict(UserDict):
|
|||
code_path.write_text(code, encoding="utf-8")
|
||||
|
||||
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
|
||||
result = get_external_base_class_inits(context, tmp_path)
|
||||
result = enrich_testgen_context(context, tmp_path)
|
||||
|
||||
assert len(result.code_strings) == 1
|
||||
code_string = result.code_strings[0]
|
||||
|
|
@ -3891,8 +3696,8 @@ class UserDict:
|
|||
assert code_string.file_path.as_posix().endswith("collections/__init__.py")
|
||||
|
||||
|
||||
def test_get_external_base_class_inits_skips_project_classes(tmp_path: Path) -> None:
|
||||
"""Returns empty when base class is from the project, not external."""
|
||||
def test_enrich_testgen_context_skips_unresolvable_base_classes(tmp_path: Path) -> None:
|
||||
"""Returns empty when base class module cannot be resolved."""
|
||||
child_code = """from base import ProjectBase
|
||||
|
||||
class Child(ProjectBase):
|
||||
|
|
@ -3902,12 +3707,12 @@ class Child(ProjectBase):
|
|||
child_path.write_text(child_code, encoding="utf-8")
|
||||
|
||||
context = CodeStringsMarkdown(code_strings=[CodeString(code=child_code, file_path=child_path)])
|
||||
result = get_external_base_class_inits(context, tmp_path)
|
||||
result = enrich_testgen_context(context, tmp_path)
|
||||
|
||||
assert result.code_strings == []
|
||||
|
||||
|
||||
def test_get_external_base_class_inits_skips_builtins(tmp_path: Path) -> None:
|
||||
def test_enrich_testgen_context_skips_builtin_base_classes(tmp_path: Path) -> None:
|
||||
"""Returns empty for builtin classes like list that have no inspectable source."""
|
||||
code = """class MyList(list):
|
||||
pass
|
||||
|
|
@ -3916,12 +3721,12 @@ def test_get_external_base_class_inits_skips_builtins(tmp_path: Path) -> None:
|
|||
code_path.write_text(code, encoding="utf-8")
|
||||
|
||||
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
|
||||
result = get_external_base_class_inits(context, tmp_path)
|
||||
result = enrich_testgen_context(context, tmp_path)
|
||||
|
||||
assert result.code_strings == []
|
||||
|
||||
|
||||
def test_get_external_base_class_inits_deduplicates(tmp_path: Path) -> None:
|
||||
def test_enrich_testgen_context_deduplicates(tmp_path: Path) -> None:
|
||||
"""Extracts the same external base class only once even when inherited multiple times."""
|
||||
code = """from collections import UserDict
|
||||
|
||||
|
|
@ -3935,7 +3740,7 @@ class MyDict2(UserDict):
|
|||
code_path.write_text(code, encoding="utf-8")
|
||||
|
||||
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
|
||||
result = get_external_base_class_inits(context, tmp_path)
|
||||
result = enrich_testgen_context(context, tmp_path)
|
||||
|
||||
assert len(result.code_strings) == 1
|
||||
expected_code = """\
|
||||
|
|
@ -3950,7 +3755,7 @@ class UserDict:
|
|||
assert result.code_strings[0].code == expected_code
|
||||
|
||||
|
||||
def test_get_external_base_class_inits_empty_when_no_inheritance(tmp_path: Path) -> None:
|
||||
def test_enrich_testgen_context_empty_when_no_inheritance(tmp_path: Path) -> None:
|
||||
"""Returns empty when there are no external base classes."""
|
||||
code = """class SimpleClass:
|
||||
pass
|
||||
|
|
@ -3959,7 +3764,7 @@ def test_get_external_base_class_inits_empty_when_no_inheritance(tmp_path: Path)
|
|||
code_path.write_text(code, encoding="utf-8")
|
||||
|
||||
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
|
||||
result = get_external_base_class_inits(context, tmp_path)
|
||||
result = enrich_testgen_context(context, tmp_path)
|
||||
|
||||
assert result.code_strings == []
|
||||
|
||||
|
|
@ -4103,127 +3908,8 @@ class MyCustomDict(UserDict):
|
|||
assert "self.data = {}" in testgen_context, "UserDict __init__ body should be included"
|
||||
|
||||
|
||||
def test_read_only_code_removed_when_exceeds_limit(tmp_path: Path) -> None:
|
||||
"""Test read-only code is completely removed when it exceeds token limit even without docstrings.
|
||||
|
||||
This covers lines 152-153 in code_context_extractor.py where read_only_context_code is set
|
||||
to empty string when it still exceeds the token limit after docstring removal.
|
||||
"""
|
||||
# Create a second-degree helper with large implementation that has no docstrings
|
||||
# Second-degree helpers go into read-only context
|
||||
long_lines = [" x = 0"]
|
||||
for i in range(150):
|
||||
long_lines.append(f" x = x + {i}")
|
||||
long_lines.append(" return x")
|
||||
long_body = "\n".join(long_lines)
|
||||
|
||||
code = f"""
|
||||
class MyClass:
|
||||
def __init__(self):
|
||||
self.x = 1
|
||||
|
||||
def target_method(self):
|
||||
return first_helper()
|
||||
|
||||
|
||||
def first_helper():
|
||||
# First degree helper - calls second degree
|
||||
return second_helper()
|
||||
|
||||
|
||||
def second_helper():
|
||||
# Second degree helper - goes into read-only context
|
||||
{long_body}
|
||||
"""
|
||||
file_path = tmp_path / "test_code.py"
|
||||
file_path.write_text(code, encoding="utf-8")
|
||||
|
||||
func_to_optimize = FunctionToOptimize(
|
||||
function_name="target_method", file_path=file_path, parents=[FunctionParent(name="MyClass", type="ClassDef")]
|
||||
)
|
||||
|
||||
# Use a small optim_token_limit that allows read-writable but not read-only
|
||||
# Read-writable is ~48 tokens, read-only is ~600 tokens
|
||||
code_ctx = get_code_optimization_context(
|
||||
function_to_optimize=func_to_optimize,
|
||||
project_root_path=tmp_path,
|
||||
optim_token_limit=100, # Small limit to trigger read-only removal
|
||||
)
|
||||
|
||||
# The read-only context should be empty because it exceeded the limit
|
||||
assert code_ctx.read_only_context_code == "", "Read-only code should be removed when exceeding token limit"
|
||||
|
||||
|
||||
def test_testgen_removes_imported_classes_on_overflow(tmp_path: Path) -> None:
|
||||
"""Test testgen context removes imported class definitions when exceeding token limit.
|
||||
|
||||
This covers lines 176-186 in code_context_extractor.py where:
|
||||
- Testgen context exceeds limit (line 175)
|
||||
- Removing docstrings still exceeds (line 175 again)
|
||||
- Removing imported classes succeeds (line 177-183)
|
||||
"""
|
||||
# Create a package structure with a large type class used only in type annotations
|
||||
# This ensures get_imported_class_definitions extracts the full class
|
||||
package_dir = tmp_path / "mypackage"
|
||||
package_dir.mkdir()
|
||||
(package_dir / "__init__.py").write_text("", encoding="utf-8")
|
||||
|
||||
# Create a large class with methods that will be extracted via get_imported_class_definitions
|
||||
# Use methods WITHOUT docstrings so removing docstrings won't help much
|
||||
many_methods = "\n".join([f" def method_{i}(self):\n return {i}" for i in range(100)])
|
||||
type_class_code = f'''
|
||||
class TypeClass:
|
||||
"""A type class for annotations."""
|
||||
|
||||
def __init__(self, value: int):
|
||||
self.value = value
|
||||
|
||||
{many_methods}
|
||||
'''
|
||||
type_class_path = package_dir / "types.py"
|
||||
type_class_path.write_text(type_class_code, encoding="utf-8")
|
||||
|
||||
# Main module uses TypeClass only in annotation (not instantiated)
|
||||
# This triggers get_imported_class_definitions to extract the full class
|
||||
main_code = """
|
||||
from mypackage.types import TypeClass
|
||||
|
||||
def target_function(obj: TypeClass) -> int:
|
||||
return obj.value
|
||||
"""
|
||||
main_path = package_dir / "main.py"
|
||||
main_path.write_text(main_code, encoding="utf-8")
|
||||
|
||||
func_to_optimize = FunctionToOptimize(function_name="target_function", file_path=main_path, parents=[])
|
||||
|
||||
# Use a testgen_token_limit that:
|
||||
# - Is exceeded by full context with imported class (~1500 tokens)
|
||||
# - Is exceeded even after removing docstrings
|
||||
# - But fits when imported class is removed (~40 tokens)
|
||||
code_ctx = get_code_optimization_context(
|
||||
function_to_optimize=func_to_optimize,
|
||||
project_root_path=tmp_path,
|
||||
testgen_token_limit=200, # Small limit to trigger imported class removal
|
||||
)
|
||||
|
||||
# The testgen context should exist (didn't raise ValueError)
|
||||
testgen_context = code_ctx.testgen_context.markdown
|
||||
assert testgen_context, "Testgen context should not be empty"
|
||||
|
||||
# The target function should still be there
|
||||
assert "def target_function" in testgen_context, "Target function should be in testgen context"
|
||||
|
||||
# The large imported class should NOT be included (removed due to token limit)
|
||||
assert "class TypeClass" not in testgen_context, (
|
||||
"TypeClass should be removed from testgen context when exceeding token limit"
|
||||
)
|
||||
|
||||
|
||||
def test_testgen_raises_when_all_fallbacks_fail(tmp_path: Path) -> None:
|
||||
"""Test that ValueError is raised when testgen context exceeds limit even after all fallbacks.
|
||||
|
||||
This covers line 186 in code_context_extractor.py.
|
||||
"""
|
||||
def test_testgen_raises_when_exceeds_limit(tmp_path: Path) -> None:
|
||||
"""Test that ValueError is raised when testgen context exceeds token limit."""
|
||||
# Create a function with a very long body that exceeds limits even without imports/docstrings
|
||||
long_lines = [" x = 0"]
|
||||
for i in range(200):
|
||||
|
|
@ -4249,7 +3935,7 @@ def target_function():
|
|||
)
|
||||
|
||||
|
||||
def test_get_external_base_class_inits_attribute_base(tmp_path: Path) -> None:
|
||||
def test_enrich_testgen_context_attribute_base(tmp_path: Path) -> None:
|
||||
"""Test handling of base class accessed as module.ClassName (ast.Attribute).
|
||||
|
||||
This covers line 616 in code_context_extractor.py.
|
||||
|
|
@ -4265,7 +3951,7 @@ class MyDict(UserDict):
|
|||
code_path.write_text(code, encoding="utf-8")
|
||||
|
||||
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
|
||||
result = get_external_base_class_inits(context, tmp_path)
|
||||
result = enrich_testgen_context(context, tmp_path)
|
||||
|
||||
# Should extract UserDict __init__
|
||||
assert len(result.code_strings) == 1
|
||||
|
|
@ -4273,7 +3959,7 @@ class MyDict(UserDict):
|
|||
assert "def __init__" in result.code_strings[0].code
|
||||
|
||||
|
||||
def test_get_external_base_class_inits_no_init_method(tmp_path: Path) -> None:
|
||||
def test_enrich_testgen_context_no_init_method(tmp_path: Path) -> None:
|
||||
"""Test handling when base class has no __init__ method.
|
||||
|
||||
This covers line 641 in code_context_extractor.py.
|
||||
|
|
@ -4288,7 +3974,7 @@ class MyProtocol(Protocol):
|
|||
code_path.write_text(code, encoding="utf-8")
|
||||
|
||||
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
|
||||
result = get_external_base_class_inits(context, tmp_path)
|
||||
result = enrich_testgen_context(context, tmp_path)
|
||||
|
||||
# Protocol's __init__ can't be easily inspected, should handle gracefully
|
||||
# Result may be empty or contain Protocol based on implementation
|
||||
|
|
@ -4377,7 +4063,7 @@ class MyClass:
|
|||
|
||||
|
||||
def test_imported_class_definitions_module_path_none(tmp_path: Path) -> None:
|
||||
"""Test handling when module_path is None in get_imported_class_definitions.
|
||||
"""Test handling when module_path is None in enrich_testgen_context.
|
||||
|
||||
This covers line 560 in code_context_extractor.py.
|
||||
"""
|
||||
|
|
@ -4393,123 +4079,12 @@ class MyClass:
|
|||
code_path.write_text(code, encoding="utf-8")
|
||||
|
||||
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
|
||||
result = get_imported_class_definitions(context, tmp_path)
|
||||
result = enrich_testgen_context(context, tmp_path)
|
||||
|
||||
# Should handle gracefully and return empty or partial results
|
||||
assert isinstance(result.code_strings, list)
|
||||
|
||||
|
||||
def test_get_imported_names_import_star(tmp_path: Path) -> None:
|
||||
"""Test get_imported_names handles import * correctly.
|
||||
|
||||
This covers lines 808-809 and 824-825 in code_context_extractor.py.
|
||||
"""
|
||||
import libcst as cst
|
||||
|
||||
# Test regular import *
|
||||
# Note: "import *" is not valid Python, but "from x import *" is
|
||||
from_import_star = cst.parse_statement("from os import *")
|
||||
assert isinstance(from_import_star, cst.SimpleStatementLine)
|
||||
import_node = from_import_star.body[0]
|
||||
assert isinstance(import_node, cst.ImportFrom)
|
||||
|
||||
from codeflash.context.code_context_extractor import get_imported_names
|
||||
|
||||
result = get_imported_names(import_node)
|
||||
assert result == {"*"}
|
||||
|
||||
|
||||
def test_get_imported_names_aliased_import(tmp_path: Path) -> None:
|
||||
"""Test get_imported_names handles aliased imports correctly.
|
||||
|
||||
This covers lines 812-813 and 828-829 in code_context_extractor.py.
|
||||
"""
|
||||
import libcst as cst
|
||||
|
||||
from codeflash.context.code_context_extractor import get_imported_names
|
||||
|
||||
# Test import with alias
|
||||
import_stmt = cst.parse_statement("import numpy as np")
|
||||
assert isinstance(import_stmt, cst.SimpleStatementLine)
|
||||
import_node = import_stmt.body[0]
|
||||
assert isinstance(import_node, cst.Import)
|
||||
|
||||
result = get_imported_names(import_node)
|
||||
assert "np" in result
|
||||
|
||||
# Test from import with alias
|
||||
from_import_stmt = cst.parse_statement("from os import path as ospath")
|
||||
assert isinstance(from_import_stmt, cst.SimpleStatementLine)
|
||||
from_import_node = from_import_stmt.body[0]
|
||||
assert isinstance(from_import_node, cst.ImportFrom)
|
||||
|
||||
result2 = get_imported_names(from_import_node)
|
||||
assert "ospath" in result2
|
||||
|
||||
|
||||
def test_get_imported_names_dotted_import(tmp_path: Path) -> None:
|
||||
"""Test get_imported_names handles dotted imports correctly.
|
||||
|
||||
This covers lines 816-822 in code_context_extractor.py.
|
||||
"""
|
||||
import libcst as cst
|
||||
|
||||
from codeflash.context.code_context_extractor import get_imported_names
|
||||
|
||||
# Test dotted import like "import os.path"
|
||||
import_stmt = cst.parse_statement("import os.path")
|
||||
assert isinstance(import_stmt, cst.SimpleStatementLine)
|
||||
import_node = import_stmt.body[0]
|
||||
assert isinstance(import_node, cst.Import)
|
||||
|
||||
result = get_imported_names(import_node)
|
||||
assert "os" in result
|
||||
|
||||
|
||||
def test_used_name_collector_comprehensive(tmp_path: Path) -> None:
|
||||
"""Test UsedNameCollector handles various node types.
|
||||
|
||||
This covers lines 767-801 in code_context_extractor.py.
|
||||
"""
|
||||
import libcst as cst
|
||||
|
||||
from codeflash.context.code_context_extractor import UsedNameCollector
|
||||
|
||||
code = """
|
||||
import os
|
||||
from typing import List
|
||||
|
||||
x: int = 1
|
||||
y = os.path.join("a", "b")
|
||||
|
||||
class MyClass:
|
||||
z = 10
|
||||
|
||||
def my_func():
|
||||
pass
|
||||
"""
|
||||
module = cst.parse_module(code)
|
||||
collector = UsedNameCollector()
|
||||
# In libcst, the walker traverses the module
|
||||
cst.MetadataWrapper(module).visit(collector)
|
||||
|
||||
# Check used names
|
||||
assert "os" in collector.used_names
|
||||
assert "int" in collector.used_names
|
||||
assert "List" in collector.used_names
|
||||
|
||||
# Check defined names
|
||||
assert "x" in collector.defined_names
|
||||
assert "y" in collector.defined_names
|
||||
assert "MyClass" in collector.defined_names
|
||||
assert "my_func" in collector.defined_names
|
||||
|
||||
# Check external names (used but not defined)
|
||||
external = collector.get_external_names()
|
||||
assert "os" in external
|
||||
assert "x" not in external # x is defined
|
||||
|
||||
|
||||
def test_imported_class_with_base_in_same_module(tmp_path: Path) -> None:
|
||||
"""Test that imported classes with bases in the same module are extracted correctly.
|
||||
|
||||
|
|
@ -4549,52 +4124,13 @@ def target_function(obj: DerivedClass) -> bool:
|
|||
main_path.write_text(main_code, encoding="utf-8")
|
||||
|
||||
context = CodeStringsMarkdown(code_strings=[CodeString(code=main_code, file_path=main_path)])
|
||||
result = get_imported_class_definitions(context, tmp_path)
|
||||
result = enrich_testgen_context(context, tmp_path)
|
||||
|
||||
# Should extract the inheritance chain
|
||||
all_code = "\n".join(cs.code for cs in result.code_strings)
|
||||
assert "class BaseClass" in all_code or "class DerivedClass" in all_code
|
||||
|
||||
|
||||
def test_get_imported_names_from_import_without_alias(tmp_path: Path) -> None:
|
||||
"""Test get_imported_names handles from imports without aliases.
|
||||
|
||||
This covers lines 830-831 in code_context_extractor.py.
|
||||
"""
|
||||
import libcst as cst
|
||||
|
||||
from codeflash.context.code_context_extractor import get_imported_names
|
||||
|
||||
# Test from import without alias
|
||||
from_import_stmt = cst.parse_statement("from os import path, getcwd")
|
||||
assert isinstance(from_import_stmt, cst.SimpleStatementLine)
|
||||
from_import_node = from_import_stmt.body[0]
|
||||
assert isinstance(from_import_node, cst.ImportFrom)
|
||||
|
||||
result = get_imported_names(from_import_node)
|
||||
assert "path" in result
|
||||
assert "getcwd" in result
|
||||
|
||||
|
||||
def test_get_imported_names_regular_import(tmp_path: Path) -> None:
|
||||
"""Test get_imported_names handles regular imports.
|
||||
|
||||
This covers lines 814-815 in code_context_extractor.py.
|
||||
"""
|
||||
import libcst as cst
|
||||
|
||||
from codeflash.context.code_context_extractor import get_imported_names
|
||||
|
||||
# Test regular import without alias
|
||||
import_stmt = cst.parse_statement("import json")
|
||||
assert isinstance(import_stmt, cst.SimpleStatementLine)
|
||||
import_node = import_stmt.body[0]
|
||||
assert isinstance(import_node, cst.Import)
|
||||
|
||||
result = get_imported_names(import_node)
|
||||
assert "json" in result
|
||||
|
||||
|
||||
def test_augmented_assignment_not_in_context(tmp_path: Path) -> None:
|
||||
"""Test that augmented assignments are handled but not included unless used.
|
||||
|
||||
|
|
@ -4625,7 +4161,7 @@ class MyClass:
|
|||
assert "counter" in read_writable
|
||||
|
||||
|
||||
def test_get_external_class_inits_extracts_click_option(tmp_path: Path) -> None:
|
||||
def test_enrich_testgen_context_extracts_click_option(tmp_path: Path) -> None:
|
||||
"""Extracts __init__ from click.Option when directly imported."""
|
||||
code = """from click import Option
|
||||
|
||||
|
|
@ -4636,7 +4172,7 @@ def my_func(opt: Option) -> None:
|
|||
code_path.write_text(code, encoding="utf-8")
|
||||
|
||||
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
|
||||
result = get_external_class_inits(context, tmp_path)
|
||||
result = enrich_testgen_context(context, tmp_path)
|
||||
|
||||
assert len(result.code_strings) == 1
|
||||
code_string = result.code_strings[0]
|
||||
|
|
@ -4645,8 +4181,8 @@ def my_func(opt: Option) -> None:
|
|||
assert code_string.file_path is not None and "click" in code_string.file_path.as_posix()
|
||||
|
||||
|
||||
def test_get_external_class_inits_skips_project_classes(tmp_path: Path) -> None:
|
||||
"""Returns empty when imported class is from the project, not external."""
|
||||
def test_enrich_testgen_context_extracts_project_class_defs(tmp_path: Path) -> None:
|
||||
"""Extracts project class definitions via jedi resolution."""
|
||||
# Create a project module with a class
|
||||
(tmp_path / "mymodule.py").write_text("class ProjectClass:\n pass\n", encoding="utf-8")
|
||||
|
||||
|
|
@ -4659,12 +4195,13 @@ def my_func(obj: ProjectClass) -> None:
|
|||
code_path.write_text(code, encoding="utf-8")
|
||||
|
||||
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
|
||||
result = get_external_class_inits(context, tmp_path)
|
||||
result = enrich_testgen_context(context, tmp_path)
|
||||
|
||||
assert result.code_strings == []
|
||||
assert len(result.code_strings) == 1
|
||||
assert "class ProjectClass" in result.code_strings[0].code
|
||||
|
||||
|
||||
def test_get_external_class_inits_skips_non_classes(tmp_path: Path) -> None:
|
||||
def test_enrich_testgen_context_skips_non_classes(tmp_path: Path) -> None:
|
||||
"""Returns empty when imported name is a function, not a class."""
|
||||
code = """from collections import OrderedDict
|
||||
from os.path import join
|
||||
|
|
@ -4676,7 +4213,7 @@ def my_func() -> None:
|
|||
code_path.write_text(code, encoding="utf-8")
|
||||
|
||||
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
|
||||
result = get_external_class_inits(context, tmp_path)
|
||||
result = enrich_testgen_context(context, tmp_path)
|
||||
|
||||
# join is a function, not a class — should be skipped
|
||||
# OrderedDict is a class and should be included
|
||||
|
|
@ -4684,8 +4221,8 @@ def my_func() -> None:
|
|||
assert not any("join" in name for name in class_names)
|
||||
|
||||
|
||||
def test_get_external_class_inits_skips_already_defined_classes(tmp_path: Path) -> None:
|
||||
"""Skips classes already defined in the context (e.g., added by get_imported_class_definitions)."""
|
||||
def test_enrich_testgen_context_skips_already_defined_classes(tmp_path: Path) -> None:
|
||||
"""Skips classes already defined in the context (e.g., added by enrich_testgen_context)."""
|
||||
code = """from collections import UserDict
|
||||
|
||||
class UserDict:
|
||||
|
|
@ -4699,14 +4236,14 @@ def my_func(d: UserDict) -> None:
|
|||
code_path.write_text(code, encoding="utf-8")
|
||||
|
||||
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
|
||||
result = get_external_class_inits(context, tmp_path)
|
||||
result = enrich_testgen_context(context, tmp_path)
|
||||
|
||||
# UserDict is already defined in the context, so it should be skipped
|
||||
assert result.code_strings == []
|
||||
|
||||
|
||||
def test_get_external_class_inits_skips_builtins(tmp_path: Path) -> None:
|
||||
"""Returns empty for builtin classes like list/dict that have no inspectable source."""
|
||||
def test_enrich_testgen_context_skips_builtin_annotations(tmp_path: Path) -> None:
|
||||
"""Returns empty for builtin type annotations like list/dict that are not imported."""
|
||||
code = """x: list = []
|
||||
y: dict = {}
|
||||
|
||||
|
|
@ -4717,12 +4254,12 @@ def my_func() -> None:
|
|||
code_path.write_text(code, encoding="utf-8")
|
||||
|
||||
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
|
||||
result = get_external_class_inits(context, tmp_path)
|
||||
result = enrich_testgen_context(context, tmp_path)
|
||||
|
||||
assert result.code_strings == []
|
||||
|
||||
|
||||
def test_get_external_class_inits_skips_object_init(tmp_path: Path) -> None:
|
||||
def test_enrich_testgen_context_skips_object_init(tmp_path: Path) -> None:
|
||||
"""Skips classes whose __init__ is just object.__init__ (trivial)."""
|
||||
# enum.Enum has a metaclass-based __init__, but individual enum members
|
||||
# effectively use object.__init__. Use a class we know has object.__init__.
|
||||
|
|
@ -4735,14 +4272,14 @@ def my_func(q: QName) -> None:
|
|||
code_path.write_text(code, encoding="utf-8")
|
||||
|
||||
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
|
||||
result = get_external_class_inits(context, tmp_path)
|
||||
result = enrich_testgen_context(context, tmp_path)
|
||||
|
||||
# QName has its own __init__, so it should be included if it's in site-packages.
|
||||
# But since it's stdlib (not site-packages), it should be skipped.
|
||||
assert result.code_strings == []
|
||||
|
||||
|
||||
def test_get_external_class_inits_empty_when_no_imports(tmp_path: Path) -> None:
|
||||
def test_enrich_testgen_context_empty_when_no_imports(tmp_path: Path) -> None:
|
||||
"""Returns empty when there are no from-imports."""
|
||||
code = """def my_func() -> None:
|
||||
pass
|
||||
|
|
@ -4751,7 +4288,7 @@ def test_get_external_class_inits_empty_when_no_imports(tmp_path: Path) -> None:
|
|||
code_path.write_text(code, encoding="utf-8")
|
||||
|
||||
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
|
||||
result = get_external_class_inits(context, tmp_path)
|
||||
result = enrich_testgen_context(context, tmp_path)
|
||||
|
||||
assert result.code_strings == []
|
||||
|
||||
|
|
@ -4840,17 +4377,17 @@ def test_resolve_transitive_type_deps_handles_failure_gracefully() -> None:
|
|||
"""Returns empty list for a class where get_type_hints fails."""
|
||||
|
||||
class BadClass:
|
||||
def __init__(self, x: "NonexistentType") -> None: # type: ignore[name-defined] # noqa: F821
|
||||
def __init__(self, x: NonexistentType) -> None: # type: ignore[name-defined] # noqa: F821
|
||||
pass
|
||||
|
||||
result = resolve_transitive_type_deps(BadClass)
|
||||
assert result == []
|
||||
|
||||
|
||||
# --- Integration tests for transitive resolution in get_external_class_inits ---
|
||||
# --- Integration tests for transitive resolution in enrich_testgen_context ---
|
||||
|
||||
|
||||
def test_get_external_class_inits_transitive_deps(tmp_path: Path) -> None:
|
||||
def test_enrich_testgen_context_transitive_deps(tmp_path: Path) -> None:
|
||||
"""Extracts transitive type dependencies from __init__ annotations."""
|
||||
code = """from click import Context
|
||||
|
||||
|
|
@ -4861,7 +4398,7 @@ def my_func(ctx: Context) -> None:
|
|||
code_path.write_text(code, encoding="utf-8")
|
||||
|
||||
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
|
||||
result = get_external_class_inits(context, tmp_path)
|
||||
result = enrich_testgen_context(context, tmp_path)
|
||||
|
||||
class_names = {cs.code.split("\n")[0].replace("class ", "").rstrip(":") for cs in result.code_strings}
|
||||
assert "Context" in class_names
|
||||
|
|
@ -4869,7 +4406,7 @@ def my_func(ctx: Context) -> None:
|
|||
assert "Command" in class_names
|
||||
|
||||
|
||||
def test_get_external_class_inits_no_infinite_loops(tmp_path: Path) -> None:
|
||||
def test_enrich_testgen_context_no_infinite_loops(tmp_path: Path) -> None:
|
||||
"""Handles classes with circular type references without infinite loops."""
|
||||
# click.Context references Command, and Command references Context back
|
||||
# This should terminate without issues due to the processed_classes set
|
||||
|
|
@ -4882,13 +4419,13 @@ def my_func(ctx: Context) -> None:
|
|||
code_path.write_text(code, encoding="utf-8")
|
||||
|
||||
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
|
||||
result = get_external_class_inits(context, tmp_path)
|
||||
result = enrich_testgen_context(context, tmp_path)
|
||||
|
||||
# Should complete without hanging; just verify we got results
|
||||
assert len(result.code_strings) >= 1
|
||||
|
||||
|
||||
def test_get_external_class_inits_no_duplicate_stubs(tmp_path: Path) -> None:
|
||||
def test_enrich_testgen_context_no_duplicate_stubs(tmp_path: Path) -> None:
|
||||
"""Does not emit duplicate stubs for the same class name."""
|
||||
code = """from click import Context
|
||||
|
||||
|
|
@ -4899,7 +4436,7 @@ def my_func(ctx: Context) -> None:
|
|||
code_path.write_text(code, encoding="utf-8")
|
||||
|
||||
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
|
||||
result = get_external_class_inits(context, tmp_path)
|
||||
result = enrich_testgen_context(context, tmp_path)
|
||||
|
||||
class_names = [cs.code.split("\n")[0].replace("class ", "").rstrip(":") for cs in result.code_strings]
|
||||
assert len(class_names) == len(set(class_names)), f"Duplicate class stubs found: {class_names}"
|
||||
|
|
|
|||
Loading…
Reference in a new issue