Merge branch 'main' into call-graphee
This commit is contained in:
commit
0bcc483a95
22 changed files with 660 additions and 1342 deletions
2
.gitignore
vendored
2
.gitignore
vendored
|
|
@ -268,3 +268,5 @@ tessl.json
|
|||
|
||||
# Tessl auto-generates AGENTS.md on install; ignore to avoid cluttering git status
|
||||
AGENTS.md
|
||||
.serena/
|
||||
.codeflash/
|
||||
|
|
|
|||
|
|
@ -557,15 +557,6 @@ class LanguageSupport(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
def get_comment_prefix(self) -> str:
|
||||
"""Get the comment prefix for this language.
|
||||
|
||||
Returns:
|
||||
Comment prefix (e.g., "//" for JS, "#" for Python).
|
||||
|
||||
"""
|
||||
...
|
||||
|
||||
def find_test_root(self, project_root: Path) -> Path | None:
|
||||
"""Find the test root directory for a project.
|
||||
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@ if TYPE_CHECKING:
|
|||
from codeflash.languages.base import LanguageSupport
|
||||
|
||||
# Module-level singleton for the current language
|
||||
_current_language: Language | None = None
|
||||
_current_language: Language = Language.PYTHON
|
||||
|
||||
|
||||
def current_language() -> Language:
|
||||
|
|
|
|||
|
|
@ -1805,15 +1805,6 @@ class JavaScriptSupport:
|
|||
"""
|
||||
return ".test.js"
|
||||
|
||||
def get_comment_prefix(self) -> str:
|
||||
"""Get the comment prefix for JavaScript.
|
||||
|
||||
Returns:
|
||||
JavaScript single-line comment prefix.
|
||||
|
||||
"""
|
||||
return "//"
|
||||
|
||||
def find_test_root(self, project_root: Path) -> Path | None:
|
||||
"""Find the test root directory for a JavaScript project.
|
||||
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -15,6 +15,8 @@ from codeflash.languages import is_javascript
|
|||
from codeflash.models.models import CodeString, CodeStringsMarkdown
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.models.models import CodeOptimizationContext, FunctionSource
|
||||
|
||||
|
|
@ -49,6 +51,69 @@ def extract_names_from_targets(target: cst.CSTNode) -> list[str]:
|
|||
return names
|
||||
|
||||
|
||||
def is_assignment_used(node: cst.CSTNode, definitions: dict[str, UsageInfo], name_prefix: str = "") -> bool:
|
||||
if isinstance(node, cst.Assign):
|
||||
for target in node.targets:
|
||||
names = extract_names_from_targets(target.target)
|
||||
for name in names:
|
||||
lookup = f"{name_prefix}{name}" if name_prefix else name
|
||||
if lookup in definitions and definitions[lookup].used_by_qualified_function:
|
||||
return True
|
||||
return False
|
||||
if isinstance(node, (cst.AnnAssign, cst.AugAssign)):
|
||||
names = extract_names_from_targets(node.target)
|
||||
for name in names:
|
||||
lookup = f"{name_prefix}{name}" if name_prefix else name
|
||||
if lookup in definitions and definitions[lookup].used_by_qualified_function:
|
||||
return True
|
||||
return False
|
||||
return False
|
||||
|
||||
|
||||
def recurse_sections(
|
||||
node: cst.CSTNode,
|
||||
section_names: list[str],
|
||||
prune_fn: Callable[[cst.CSTNode], tuple[cst.CSTNode | None, bool]],
|
||||
keep_non_target_children: bool = False,
|
||||
) -> tuple[cst.CSTNode | None, bool]:
|
||||
updates: dict[str, list[cst.CSTNode] | cst.CSTNode] = {}
|
||||
found_any_target = False
|
||||
for section in section_names:
|
||||
original_content = getattr(node, section, None)
|
||||
if isinstance(original_content, (list, tuple)):
|
||||
new_children = []
|
||||
section_found_target = False
|
||||
for child in original_content:
|
||||
filtered, found_target = prune_fn(child)
|
||||
if filtered:
|
||||
new_children.append(filtered)
|
||||
section_found_target |= found_target
|
||||
if keep_non_target_children:
|
||||
if section_found_target or new_children:
|
||||
found_any_target |= section_found_target
|
||||
updates[section] = new_children
|
||||
elif section_found_target:
|
||||
found_any_target = True
|
||||
updates[section] = new_children
|
||||
elif original_content is not None:
|
||||
filtered, found_target = prune_fn(original_content)
|
||||
if keep_non_target_children:
|
||||
found_any_target |= found_target
|
||||
if filtered:
|
||||
updates[section] = filtered
|
||||
elif found_target:
|
||||
found_any_target = True
|
||||
if filtered:
|
||||
updates[section] = filtered
|
||||
if keep_non_target_children:
|
||||
if updates:
|
||||
return node.with_changes(**updates), found_any_target
|
||||
return None, False
|
||||
if not found_any_target:
|
||||
return None, False
|
||||
return (node.with_changes(**updates) if updates else node), True
|
||||
|
||||
|
||||
def collect_top_level_definitions(
|
||||
node: cst.CSTNode, definitions: Optional[dict[str, UsageInfo]] = None
|
||||
) -> dict[str, UsageInfo]:
|
||||
|
|
@ -423,27 +488,9 @@ def remove_unused_definitions_recursively(
|
|||
elif isinstance(statement, (cst.Assign, cst.AnnAssign, cst.AugAssign)):
|
||||
var_used = False
|
||||
|
||||
# Check if any variable in this assignment is used
|
||||
if isinstance(statement, cst.Assign):
|
||||
for target in statement.targets:
|
||||
names = extract_names_from_targets(target.target)
|
||||
for name in names:
|
||||
class_var_name = f"{class_name}.{name}"
|
||||
if (
|
||||
class_var_name in definitions
|
||||
and definitions[class_var_name].used_by_qualified_function
|
||||
):
|
||||
var_used = True
|
||||
method_or_var_used = True
|
||||
break
|
||||
elif isinstance(statement, (cst.AnnAssign, cst.AugAssign)):
|
||||
names = extract_names_from_targets(statement.target)
|
||||
for name in names:
|
||||
class_var_name = f"{class_name}.{name}"
|
||||
if class_var_name in definitions and definitions[class_var_name].used_by_qualified_function:
|
||||
var_used = True
|
||||
method_or_var_used = True
|
||||
break
|
||||
if is_assignment_used(statement, definitions, name_prefix=f"{class_name}."):
|
||||
var_used = True
|
||||
method_or_var_used = True
|
||||
|
||||
if var_used or class_has_dependencies:
|
||||
new_statements.append(statement)
|
||||
|
|
@ -459,56 +506,19 @@ def remove_unused_definitions_recursively(
|
|||
|
||||
return node, method_or_var_used or class_has_dependencies
|
||||
|
||||
# Handle assignments (Assign and AnnAssign)
|
||||
if isinstance(node, cst.Assign):
|
||||
for target in node.targets:
|
||||
names = extract_names_from_targets(target.target)
|
||||
for name in names:
|
||||
if name in definitions and definitions[name].used_by_qualified_function:
|
||||
return node, True
|
||||
return None, False
|
||||
|
||||
if isinstance(node, (cst.AnnAssign, cst.AugAssign)):
|
||||
names = extract_names_from_targets(node.target)
|
||||
for name in names:
|
||||
if name in definitions and definitions[name].used_by_qualified_function:
|
||||
return node, True
|
||||
# Handle assignments (Assign, AnnAssign, AugAssign)
|
||||
if isinstance(node, (cst.Assign, cst.AnnAssign, cst.AugAssign)):
|
||||
if is_assignment_used(node, definitions):
|
||||
return node, True
|
||||
return None, False
|
||||
|
||||
# For other nodes, recursively process children
|
||||
section_names = get_section_names(node)
|
||||
if not section_names:
|
||||
return node, False
|
||||
|
||||
updates = {}
|
||||
found_used = False
|
||||
|
||||
for section in section_names:
|
||||
original_content = getattr(node, section, None)
|
||||
if isinstance(original_content, (list, tuple)):
|
||||
new_children = []
|
||||
section_found_used = False
|
||||
|
||||
for child in original_content:
|
||||
filtered, used = remove_unused_definitions_recursively(child, definitions)
|
||||
if filtered:
|
||||
new_children.append(filtered)
|
||||
section_found_used |= used
|
||||
|
||||
if new_children or section_found_used:
|
||||
found_used |= section_found_used
|
||||
updates[section] = new_children
|
||||
elif original_content is not None:
|
||||
filtered, used = remove_unused_definitions_recursively(original_content, definitions)
|
||||
found_used |= used
|
||||
if filtered:
|
||||
updates[section] = filtered
|
||||
if not found_used:
|
||||
return None, False
|
||||
if updates:
|
||||
return node.with_changes(**updates), found_used
|
||||
|
||||
return node, False
|
||||
return recurse_sections(
|
||||
node, section_names, lambda child: remove_unused_definitions_recursively(child, definitions)
|
||||
)
|
||||
|
||||
|
||||
def collect_top_level_defs_with_usages(
|
||||
|
|
@ -22,10 +22,25 @@ if TYPE_CHECKING:
|
|||
from collections.abc import Sequence
|
||||
|
||||
from codeflash.languages.base import DependencyResolver
|
||||
from codeflash.models.models import FunctionSource
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def function_sources_to_helpers(sources: list[FunctionSource]) -> list[HelperFunction]:
|
||||
return [
|
||||
HelperFunction(
|
||||
name=fs.only_function_name,
|
||||
qualified_name=fs.qualified_name,
|
||||
file_path=fs.file_path,
|
||||
source_code=fs.source_code,
|
||||
start_line=fs.jedi_definition.line if fs.jedi_definition else 1,
|
||||
end_line=fs.jedi_definition.line if fs.jedi_definition else 1,
|
||||
)
|
||||
for fs in sources
|
||||
]
|
||||
|
||||
|
||||
@register_language
|
||||
class PythonSupport:
|
||||
"""Python language support implementation.
|
||||
|
|
@ -173,127 +188,39 @@ class PythonSupport:
|
|||
# === Code Analysis ===
|
||||
|
||||
def extract_code_context(self, function: FunctionToOptimize, project_root: Path, module_root: Path) -> CodeContext:
|
||||
"""Extract function code and its dependencies.
|
||||
"""Extract function code and its dependencies via the canonical context pipeline."""
|
||||
from codeflash.languages.python.context.code_context_extractor import get_code_optimization_context
|
||||
|
||||
Uses jedi and libcst for Python code analysis.
|
||||
|
||||
Args:
|
||||
function: The function to extract context for.
|
||||
project_root: Root of the project.
|
||||
module_root: Root of the module containing the function.
|
||||
|
||||
Returns:
|
||||
CodeContext with target code and dependencies.
|
||||
|
||||
"""
|
||||
try:
|
||||
source = function.file_path.read_text()
|
||||
result = get_code_optimization_context(function, project_root)
|
||||
except Exception as e:
|
||||
logger.exception("Failed to read %s: %s", function.file_path, e)
|
||||
logger.warning("Failed to extract code context for %s: %s", function.function_name, e)
|
||||
return CodeContext(target_code="", target_file=function.file_path, language=Language.PYTHON)
|
||||
|
||||
# Extract the function source
|
||||
lines = source.splitlines(keepends=True)
|
||||
if function.starting_line and function.ending_line:
|
||||
target_lines = lines[function.starting_line - 1 : function.ending_line]
|
||||
target_code = "".join(target_lines)
|
||||
else:
|
||||
target_code = ""
|
||||
|
||||
# Find helper functions
|
||||
helpers = self.find_helper_functions(function, project_root)
|
||||
|
||||
# Extract imports
|
||||
import_lines = []
|
||||
for line in lines:
|
||||
stripped = line.strip()
|
||||
if stripped.startswith(("import ", "from ")):
|
||||
import_lines.append(stripped)
|
||||
elif stripped and not stripped.startswith("#"):
|
||||
# Stop at first non-import, non-comment line
|
||||
break
|
||||
helpers = function_sources_to_helpers(result.helper_functions)
|
||||
|
||||
return CodeContext(
|
||||
target_code=target_code,
|
||||
target_code=result.read_writable_code.markdown,
|
||||
target_file=function.file_path,
|
||||
helper_functions=helpers,
|
||||
read_only_context="",
|
||||
imports=import_lines,
|
||||
read_only_context=result.read_only_context_code,
|
||||
imports=[],
|
||||
language=Language.PYTHON,
|
||||
)
|
||||
|
||||
def find_helper_functions(self, function: FunctionToOptimize, project_root: Path) -> list[HelperFunction]:
|
||||
"""Find helper functions called by the target function.
|
||||
|
||||
Uses jedi for Python code analysis.
|
||||
|
||||
Args:
|
||||
function: The target function to analyze.
|
||||
project_root: Root of the project.
|
||||
|
||||
Returns:
|
||||
List of HelperFunction objects.
|
||||
|
||||
"""
|
||||
helpers: list[HelperFunction] = []
|
||||
"""Find helper functions called by the target function via the canonical jedi pipeline."""
|
||||
from codeflash.languages.python.context.code_context_extractor import get_function_sources_from_jedi
|
||||
|
||||
try:
|
||||
import jedi
|
||||
|
||||
from codeflash.code_utils.code_utils import get_qualified_name, path_belongs_to_site_packages
|
||||
from codeflash.optimization.function_context import belongs_to_function_qualified
|
||||
|
||||
script = jedi.Script(path=function.file_path, project=jedi.Project(path=project_root))
|
||||
file_refs = script.get_names(all_scopes=True, definitions=False, references=True)
|
||||
|
||||
qualified_name = function.qualified_name
|
||||
|
||||
for ref in file_refs:
|
||||
if not ref.full_name or not belongs_to_function_qualified(ref, qualified_name):
|
||||
continue
|
||||
|
||||
try:
|
||||
definitions = ref.goto(follow_imports=True, follow_builtin_imports=False)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
for definition in definitions:
|
||||
definition_path = definition.module_path
|
||||
if definition_path is None:
|
||||
continue
|
||||
|
||||
# Check if it's a valid helper (in project, not in target function)
|
||||
is_valid = (
|
||||
str(definition_path).startswith(str(project_root))
|
||||
and not path_belongs_to_site_packages(definition_path)
|
||||
and definition.full_name
|
||||
and not belongs_to_function_qualified(definition, qualified_name)
|
||||
and definition.type == "function"
|
||||
)
|
||||
|
||||
if is_valid:
|
||||
helper_qualified_name = get_qualified_name(definition.module_name, definition.full_name)
|
||||
# Get source code
|
||||
try:
|
||||
helper_source = definition.get_line_code()
|
||||
except Exception:
|
||||
helper_source = ""
|
||||
|
||||
helpers.append(
|
||||
HelperFunction(
|
||||
name=definition.name,
|
||||
qualified_name=helper_qualified_name,
|
||||
file_path=definition_path,
|
||||
source_code=helper_source,
|
||||
start_line=definition.line or 1,
|
||||
end_line=definition.line or 1,
|
||||
)
|
||||
)
|
||||
|
||||
_dict, sources = get_function_sources_from_jedi(
|
||||
{function.file_path: {function.qualified_name}}, project_root
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to find helpers for %s: %s", function.function_name, e)
|
||||
return []
|
||||
|
||||
return helpers
|
||||
return function_sources_to_helpers(sources)
|
||||
|
||||
def find_references(
|
||||
self, function: FunctionToOptimize, project_root: Path, tests_root: Path | None = None, max_files: int = 500
|
||||
|
|
@ -730,15 +657,6 @@ class PythonSupport:
|
|||
"""
|
||||
return ".py"
|
||||
|
||||
def get_comment_prefix(self) -> str:
|
||||
"""Get the comment prefix for Python.
|
||||
|
||||
Returns:
|
||||
Python single-line comment prefix.
|
||||
|
||||
"""
|
||||
return "#"
|
||||
|
||||
def find_test_root(self, project_root: Path) -> Path | None:
|
||||
"""Find the test root directory for a Python project.
|
||||
|
||||
|
|
|
|||
|
|
@ -72,8 +72,6 @@ from codeflash.code_utils.line_profile_utils import add_decorator_imports, conta
|
|||
from codeflash.code_utils.shell_utils import make_env_with_project_root
|
||||
from codeflash.code_utils.static_analysis import get_first_top_level_function_or_method_ast
|
||||
from codeflash.code_utils.time_utils import humanize_runtime
|
||||
from codeflash.context import code_context_extractor
|
||||
from codeflash.context.unused_definition_remover import detect_unused_helper_functions, revert_unused_helper_functions
|
||||
from codeflash.discovery.functions_to_optimize import was_function_previously_optimized
|
||||
from codeflash.either import Failure, Success, is_successful
|
||||
from codeflash.languages import is_python
|
||||
|
|
@ -81,6 +79,11 @@ from codeflash.languages.base import Language
|
|||
from codeflash.languages.current import current_language_support, is_typescript
|
||||
from codeflash.languages.javascript.module_system import detect_module_system
|
||||
from codeflash.languages.javascript.test_runner import clear_created_config_files, get_created_config_files
|
||||
from codeflash.languages.python.context import code_context_extractor
|
||||
from codeflash.languages.python.context.unused_definition_remover import (
|
||||
detect_unused_helper_functions,
|
||||
revert_unused_helper_functions,
|
||||
)
|
||||
from codeflash.lsp.helpers import is_LSP_enabled, report_to_markdown_table, tree_to_markdown
|
||||
from codeflash.lsp.lsp_message import LspCodeMessage, LspMarkdownMessage, LSPMessageId
|
||||
from codeflash.models.ExperimentMetadata import ExperimentMetadata
|
||||
|
|
|
|||
|
|
@ -6,8 +6,8 @@ codeflash/result/explanation.py
|
|||
codeflash/result/critic.py
|
||||
codeflash/version.py
|
||||
codeflash/optimization/__init__.py
|
||||
codeflash/context/__init__.py
|
||||
codeflash/context/code_context_extractor.py
|
||||
codeflash/languages/python/context/__init__.py
|
||||
codeflash/languages/python/context/code_context_extractor.py
|
||||
codeflash/discovery/__init__.py
|
||||
codeflash/__init__.py
|
||||
codeflash/models/ExperimentMetadata.py
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
from argparse import Namespace
|
||||
from pathlib import Path
|
||||
|
||||
from codeflash.context.code_context_extractor import get_code_optimization_context
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.languages.python.context.code_context_extractor import get_code_optimization_context
|
||||
from codeflash.models.models import FunctionParent
|
||||
from codeflash.optimization.optimizer import Optimizer
|
||||
|
||||
|
|
|
|||
|
|
@ -10,17 +10,15 @@ import pytest
|
|||
|
||||
from codeflash.code_utils.code_extractor import GlobalAssignmentCollector, add_global_assignments
|
||||
from codeflash.code_utils.code_replacer import replace_functions_and_add_imports
|
||||
from codeflash.context.code_context_extractor import (
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.languages.python.context.code_context_extractor import (
|
||||
collect_names_from_annotation,
|
||||
enrich_testgen_context,
|
||||
extract_classes_from_type_hint,
|
||||
extract_imports_for_class,
|
||||
get_code_optimization_context,
|
||||
get_external_base_class_inits,
|
||||
get_external_class_inits,
|
||||
get_imported_class_definitions,
|
||||
resolve_transitive_type_deps,
|
||||
)
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.models.models import CodeString, CodeStringsMarkdown, FunctionParent
|
||||
from codeflash.optimization.optimizer import Optimizer
|
||||
|
||||
|
|
@ -1009,7 +1007,7 @@ class HelperClass:
|
|||
)
|
||||
# In this scenario, the read-writable code is too long, so we abort.
|
||||
with pytest.raises(ValueError, match="Read-writable code has exceeded token limit, cannot proceed"):
|
||||
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
|
||||
get_code_optimization_context(function_to_optimize, opt.args.project_root, optim_token_limit=8000)
|
||||
|
||||
|
||||
def test_example_class_token_limit_4(tmp_path: Path) -> None:
|
||||
|
|
@ -1062,7 +1060,7 @@ class HelperClass:
|
|||
|
||||
# In this scenario, the read-writable code context becomes too large because the __init__ function is referencing the global x variable instead of the class attribute self.x, so we abort.
|
||||
with pytest.raises(ValueError, match="Read-writable code has exceeded token limit, cannot proceed"):
|
||||
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
|
||||
get_code_optimization_context(function_to_optimize, opt.args.project_root, optim_token_limit=8000)
|
||||
|
||||
|
||||
def test_example_class_token_limit_5(tmp_path: Path) -> None:
|
||||
|
|
@ -2422,7 +2420,7 @@ class OuterClass:
|
|||
assert "__init__" not in hashing_context # Should not contain __init__ methods
|
||||
|
||||
# Verify nested classes are excluded from the hashing context
|
||||
# The prune_cst_for_code_hashing function should not recurse into nested classes
|
||||
# The prune_cst function in hashing mode should not recurse into nested classes
|
||||
assert "class NestedClass:" not in hashing_context # Nested class definition should not be present
|
||||
|
||||
# The target method will reference NestedClass, but the actual nested class definition should not be included
|
||||
|
|
@ -3275,8 +3273,8 @@ def dump_layout(layout_type, layout):
|
|||
assert testgen_context.count("def __init__") >= 2, "Both __init__ methods should be in testgen context"
|
||||
|
||||
|
||||
def test_get_imported_class_definitions_extracts_project_classes(tmp_path: Path) -> None:
|
||||
"""Test that get_imported_class_definitions extracts class definitions from project modules."""
|
||||
def test_enrich_testgen_context_extracts_project_classes(tmp_path: Path) -> None:
|
||||
"""Test that enrich_testgen_context extracts class definitions from project modules."""
|
||||
# Create a package structure with two modules
|
||||
package_dir = tmp_path / "mypackage"
|
||||
package_dir.mkdir()
|
||||
|
|
@ -3325,8 +3323,8 @@ class Accumulator:
|
|||
# Create CodeStringsMarkdown from the chunking module (simulating testgen context)
|
||||
context = CodeStringsMarkdown(code_strings=[CodeString(code=chunking_code, file_path=chunking_path)])
|
||||
|
||||
# Call get_imported_class_definitions
|
||||
result = get_imported_class_definitions(context, tmp_path)
|
||||
# Call enrich_testgen_context
|
||||
result = enrich_testgen_context(context, tmp_path)
|
||||
|
||||
# Verify Element class was extracted
|
||||
assert len(result.code_strings) == 1, "Should extract exactly one class (Element)"
|
||||
|
|
@ -3339,8 +3337,8 @@ class Accumulator:
|
|||
assert "import abc" in extracted_code, "Should include necessary imports for base class"
|
||||
|
||||
|
||||
def test_get_imported_class_definitions_skips_existing_definitions(tmp_path: Path) -> None:
|
||||
"""Test that get_imported_class_definitions skips classes already defined in context."""
|
||||
def test_enrich_testgen_context_skips_existing_definitions(tmp_path: Path) -> None:
|
||||
"""Test that enrich_testgen_context skips classes already defined in context."""
|
||||
# Create a package structure
|
||||
package_dir = tmp_path / "mypackage"
|
||||
package_dir.mkdir()
|
||||
|
|
@ -3373,15 +3371,15 @@ class User:
|
|||
|
||||
context = CodeStringsMarkdown(code_strings=[CodeString(code=code_with_local_def, file_path=code_path)])
|
||||
|
||||
# Call get_imported_class_definitions
|
||||
result = get_imported_class_definitions(context, tmp_path)
|
||||
# Call enrich_testgen_context
|
||||
result = enrich_testgen_context(context, tmp_path)
|
||||
|
||||
# Should NOT extract Element since it's already defined locally
|
||||
assert len(result.code_strings) == 0, "Should not extract classes already defined in context"
|
||||
|
||||
|
||||
def test_get_imported_class_definitions_skips_third_party(tmp_path: Path) -> None:
|
||||
"""Test that get_imported_class_definitions skips third-party/stdlib imports."""
|
||||
def test_enrich_testgen_context_skips_third_party(tmp_path: Path) -> None:
|
||||
"""Test that enrich_testgen_context skips third-party/stdlib imports."""
|
||||
# Create a simple package
|
||||
package_dir = tmp_path / "mypackage"
|
||||
package_dir.mkdir()
|
||||
|
|
@ -3402,15 +3400,15 @@ class MyClass:
|
|||
|
||||
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
|
||||
|
||||
# Call get_imported_class_definitions
|
||||
result = get_imported_class_definitions(context, tmp_path)
|
||||
# Call enrich_testgen_context
|
||||
result = enrich_testgen_context(context, tmp_path)
|
||||
|
||||
# Should not extract any classes (Path, Optional, dataclass are stdlib/third-party)
|
||||
assert len(result.code_strings) == 0, "Should not extract stdlib/third-party classes"
|
||||
|
||||
|
||||
def test_get_imported_class_definitions_handles_multiple_imports(tmp_path: Path) -> None:
|
||||
"""Test that get_imported_class_definitions handles multiple class imports."""
|
||||
def test_enrich_testgen_context_handles_multiple_imports(tmp_path: Path) -> None:
|
||||
"""Test that enrich_testgen_context handles multiple class imports."""
|
||||
# Create a package structure
|
||||
package_dir = tmp_path / "mypackage"
|
||||
package_dir.mkdir()
|
||||
|
|
@ -3446,8 +3444,8 @@ class Processor:
|
|||
|
||||
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
|
||||
|
||||
# Call get_imported_class_definitions
|
||||
result = get_imported_class_definitions(context, tmp_path)
|
||||
# Call enrich_testgen_context
|
||||
result = enrich_testgen_context(context, tmp_path)
|
||||
|
||||
# Should extract both TypeA and TypeB (but not TypeC since it's not imported)
|
||||
assert len(result.code_strings) == 2, "Should extract exactly two classes (TypeA, TypeB)"
|
||||
|
|
@ -3458,8 +3456,8 @@ class Processor:
|
|||
assert "class TypeC" not in all_extracted_code, "Should NOT contain TypeC (not imported)"
|
||||
|
||||
|
||||
def test_get_imported_class_definitions_includes_dataclass_decorators(tmp_path: Path) -> None:
|
||||
"""Test that get_imported_class_definitions includes decorators when extracting dataclasses."""
|
||||
def test_enrich_testgen_context_includes_dataclass_decorators(tmp_path: Path) -> None:
|
||||
"""Test that enrich_testgen_context includes decorators when extracting dataclasses."""
|
||||
# Create a package structure
|
||||
package_dir = tmp_path / "mypackage"
|
||||
package_dir.mkdir()
|
||||
|
|
@ -3496,8 +3494,8 @@ class ConfigRegistry:
|
|||
|
||||
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
|
||||
|
||||
# Call get_imported_class_definitions
|
||||
result = get_imported_class_definitions(context, tmp_path)
|
||||
# Call enrich_testgen_context
|
||||
result = enrich_testgen_context(context, tmp_path)
|
||||
|
||||
# Should extract both LLMConfigBase (base class) and LLMConfig
|
||||
assert len(result.code_strings) == 2, "Should extract both LLMConfig and its base class LLMConfigBase"
|
||||
|
|
@ -3521,7 +3519,7 @@ class ConfigRegistry:
|
|||
assert "from dataclasses import" in all_extracted_code, "Should include dataclasses import"
|
||||
|
||||
|
||||
def test_get_imported_class_definitions_extracts_imports_for_decorated_classes(tmp_path: Path) -> None:
|
||||
def test_enrich_testgen_context_extracts_imports_for_decorated_classes(tmp_path: Path) -> None:
|
||||
"""Test that extract_imports_for_class includes decorator and type annotation imports."""
|
||||
# Create a package structure
|
||||
package_dir = tmp_path / "mypackage"
|
||||
|
|
@ -3552,7 +3550,7 @@ def create_config() -> Config:
|
|||
|
||||
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
|
||||
|
||||
result = get_imported_class_definitions(context, tmp_path)
|
||||
result = enrich_testgen_context(context, tmp_path)
|
||||
|
||||
assert len(result.code_strings) == 1, "Should extract Config class"
|
||||
extracted_code = result.code_strings[0].code
|
||||
|
|
@ -3724,7 +3722,7 @@ class MyClass:
|
|||
assert result.count("from typing import Optional") == 1
|
||||
|
||||
|
||||
def test_get_imported_class_definitions_multiple_decorators(tmp_path: Path) -> None:
|
||||
def test_enrich_testgen_context_multiple_decorators(tmp_path: Path) -> None:
|
||||
"""Test that classes with multiple decorators are extracted correctly."""
|
||||
package_dir = tmp_path / "mypackage"
|
||||
package_dir.mkdir()
|
||||
|
|
@ -3755,7 +3753,7 @@ def sort_configs(configs: list[OrderedConfig]) -> list[OrderedConfig]:
|
|||
|
||||
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
|
||||
|
||||
result = get_imported_class_definitions(context, tmp_path)
|
||||
result = enrich_testgen_context(context, tmp_path)
|
||||
|
||||
assert len(result.code_strings) == 1
|
||||
extracted_code = result.code_strings[0].code
|
||||
|
|
@ -3766,7 +3764,7 @@ def sort_configs(configs: list[OrderedConfig]) -> list[OrderedConfig]:
|
|||
assert "class OrderedConfig" in extracted_code
|
||||
|
||||
|
||||
def test_get_imported_class_definitions_extracts_multilevel_inheritance(tmp_path: Path) -> None:
|
||||
def test_enrich_testgen_context_extracts_multilevel_inheritance(tmp_path: Path) -> None:
|
||||
"""Test that base classes are recursively extracted for multi-level inheritance.
|
||||
|
||||
This is critical for understanding dataclass constructor signatures, as fields
|
||||
|
|
@ -3826,8 +3824,8 @@ class ConfigRegistry:
|
|||
|
||||
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
|
||||
|
||||
# Call get_imported_class_definitions
|
||||
result = get_imported_class_definitions(context, tmp_path)
|
||||
# Call enrich_testgen_context
|
||||
result = enrich_testgen_context(context, tmp_path)
|
||||
|
||||
# Should extract 4 classes: GrandParentConfig, ParentConfig, ChildConfig, RouterConfig
|
||||
# (all classes needed to understand the full inheritance hierarchy)
|
||||
|
|
@ -3862,7 +3860,7 @@ class ConfigRegistry:
|
|||
assert "model_list: list" in all_extracted_code, "Should include model_list field from Router"
|
||||
|
||||
|
||||
def test_get_external_base_class_inits_extracts_userdict(tmp_path: Path) -> None:
|
||||
def test_enrich_testgen_context_extracts_userdict(tmp_path: Path) -> None:
|
||||
"""Extracts __init__ from collections.UserDict when a class inherits from it."""
|
||||
code = """from collections import UserDict
|
||||
|
||||
|
|
@ -3873,7 +3871,7 @@ class MyCustomDict(UserDict):
|
|||
code_path.write_text(code, encoding="utf-8")
|
||||
|
||||
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
|
||||
result = get_external_base_class_inits(context, tmp_path)
|
||||
result = enrich_testgen_context(context, tmp_path)
|
||||
|
||||
assert len(result.code_strings) == 1
|
||||
code_string = result.code_strings[0]
|
||||
|
|
@ -3891,8 +3889,8 @@ class UserDict:
|
|||
assert code_string.file_path.as_posix().endswith("collections/__init__.py")
|
||||
|
||||
|
||||
def test_get_external_base_class_inits_skips_project_classes(tmp_path: Path) -> None:
|
||||
"""Returns empty when base class is from the project, not external."""
|
||||
def test_enrich_testgen_context_skips_unresolvable_base_classes(tmp_path: Path) -> None:
|
||||
"""Returns empty when base class module cannot be resolved."""
|
||||
child_code = """from base import ProjectBase
|
||||
|
||||
class Child(ProjectBase):
|
||||
|
|
@ -3902,12 +3900,12 @@ class Child(ProjectBase):
|
|||
child_path.write_text(child_code, encoding="utf-8")
|
||||
|
||||
context = CodeStringsMarkdown(code_strings=[CodeString(code=child_code, file_path=child_path)])
|
||||
result = get_external_base_class_inits(context, tmp_path)
|
||||
result = enrich_testgen_context(context, tmp_path)
|
||||
|
||||
assert result.code_strings == []
|
||||
|
||||
|
||||
def test_get_external_base_class_inits_skips_builtins(tmp_path: Path) -> None:
|
||||
def test_enrich_testgen_context_skips_builtin_base_classes(tmp_path: Path) -> None:
|
||||
"""Returns empty for builtin classes like list that have no inspectable source."""
|
||||
code = """class MyList(list):
|
||||
pass
|
||||
|
|
@ -3916,12 +3914,12 @@ def test_get_external_base_class_inits_skips_builtins(tmp_path: Path) -> None:
|
|||
code_path.write_text(code, encoding="utf-8")
|
||||
|
||||
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
|
||||
result = get_external_base_class_inits(context, tmp_path)
|
||||
result = enrich_testgen_context(context, tmp_path)
|
||||
|
||||
assert result.code_strings == []
|
||||
|
||||
|
||||
def test_get_external_base_class_inits_deduplicates(tmp_path: Path) -> None:
|
||||
def test_enrich_testgen_context_deduplicates(tmp_path: Path) -> None:
|
||||
"""Extracts the same external base class only once even when inherited multiple times."""
|
||||
code = """from collections import UserDict
|
||||
|
||||
|
|
@ -3935,7 +3933,7 @@ class MyDict2(UserDict):
|
|||
code_path.write_text(code, encoding="utf-8")
|
||||
|
||||
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
|
||||
result = get_external_base_class_inits(context, tmp_path)
|
||||
result = enrich_testgen_context(context, tmp_path)
|
||||
|
||||
assert len(result.code_strings) == 1
|
||||
expected_code = """\
|
||||
|
|
@ -3950,7 +3948,7 @@ class UserDict:
|
|||
assert result.code_strings[0].code == expected_code
|
||||
|
||||
|
||||
def test_get_external_base_class_inits_empty_when_no_inheritance(tmp_path: Path) -> None:
|
||||
def test_enrich_testgen_context_empty_when_no_inheritance(tmp_path: Path) -> None:
|
||||
"""Returns empty when there are no external base classes."""
|
||||
code = """class SimpleClass:
|
||||
pass
|
||||
|
|
@ -3959,7 +3957,7 @@ def test_get_external_base_class_inits_empty_when_no_inheritance(tmp_path: Path)
|
|||
code_path.write_text(code, encoding="utf-8")
|
||||
|
||||
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
|
||||
result = get_external_base_class_inits(context, tmp_path)
|
||||
result = enrich_testgen_context(context, tmp_path)
|
||||
|
||||
assert result.code_strings == []
|
||||
|
||||
|
|
@ -4103,127 +4101,8 @@ class MyCustomDict(UserDict):
|
|||
assert "self.data = {}" in testgen_context, "UserDict __init__ body should be included"
|
||||
|
||||
|
||||
def test_read_only_code_removed_when_exceeds_limit(tmp_path: Path) -> None:
|
||||
"""Test read-only code is completely removed when it exceeds token limit even without docstrings.
|
||||
|
||||
This covers lines 152-153 in code_context_extractor.py where read_only_context_code is set
|
||||
to empty string when it still exceeds the token limit after docstring removal.
|
||||
"""
|
||||
# Create a second-degree helper with large implementation that has no docstrings
|
||||
# Second-degree helpers go into read-only context
|
||||
long_lines = [" x = 0"]
|
||||
for i in range(150):
|
||||
long_lines.append(f" x = x + {i}")
|
||||
long_lines.append(" return x")
|
||||
long_body = "\n".join(long_lines)
|
||||
|
||||
code = f"""
|
||||
class MyClass:
|
||||
def __init__(self):
|
||||
self.x = 1
|
||||
|
||||
def target_method(self):
|
||||
return first_helper()
|
||||
|
||||
|
||||
def first_helper():
|
||||
# First degree helper - calls second degree
|
||||
return second_helper()
|
||||
|
||||
|
||||
def second_helper():
|
||||
# Second degree helper - goes into read-only context
|
||||
{long_body}
|
||||
"""
|
||||
file_path = tmp_path / "test_code.py"
|
||||
file_path.write_text(code, encoding="utf-8")
|
||||
|
||||
func_to_optimize = FunctionToOptimize(
|
||||
function_name="target_method", file_path=file_path, parents=[FunctionParent(name="MyClass", type="ClassDef")]
|
||||
)
|
||||
|
||||
# Use a small optim_token_limit that allows read-writable but not read-only
|
||||
# Read-writable is ~48 tokens, read-only is ~600 tokens
|
||||
code_ctx = get_code_optimization_context(
|
||||
function_to_optimize=func_to_optimize,
|
||||
project_root_path=tmp_path,
|
||||
optim_token_limit=100, # Small limit to trigger read-only removal
|
||||
)
|
||||
|
||||
# The read-only context should be empty because it exceeded the limit
|
||||
assert code_ctx.read_only_context_code == "", "Read-only code should be removed when exceeding token limit"
|
||||
|
||||
|
||||
def test_testgen_removes_imported_classes_on_overflow(tmp_path: Path) -> None:
|
||||
"""Test testgen context removes imported class definitions when exceeding token limit.
|
||||
|
||||
This covers lines 176-186 in code_context_extractor.py where:
|
||||
- Testgen context exceeds limit (line 175)
|
||||
- Removing docstrings still exceeds (line 175 again)
|
||||
- Removing imported classes succeeds (line 177-183)
|
||||
"""
|
||||
# Create a package structure with a large type class used only in type annotations
|
||||
# This ensures get_imported_class_definitions extracts the full class
|
||||
package_dir = tmp_path / "mypackage"
|
||||
package_dir.mkdir()
|
||||
(package_dir / "__init__.py").write_text("", encoding="utf-8")
|
||||
|
||||
# Create a large class with methods that will be extracted via get_imported_class_definitions
|
||||
# Use methods WITHOUT docstrings so removing docstrings won't help much
|
||||
many_methods = "\n".join([f" def method_{i}(self):\n return {i}" for i in range(100)])
|
||||
type_class_code = f'''
|
||||
class TypeClass:
|
||||
"""A type class for annotations."""
|
||||
|
||||
def __init__(self, value: int):
|
||||
self.value = value
|
||||
|
||||
{many_methods}
|
||||
'''
|
||||
type_class_path = package_dir / "types.py"
|
||||
type_class_path.write_text(type_class_code, encoding="utf-8")
|
||||
|
||||
# Main module uses TypeClass only in annotation (not instantiated)
|
||||
# This triggers get_imported_class_definitions to extract the full class
|
||||
main_code = """
|
||||
from mypackage.types import TypeClass
|
||||
|
||||
def target_function(obj: TypeClass) -> int:
|
||||
return obj.value
|
||||
"""
|
||||
main_path = package_dir / "main.py"
|
||||
main_path.write_text(main_code, encoding="utf-8")
|
||||
|
||||
func_to_optimize = FunctionToOptimize(function_name="target_function", file_path=main_path, parents=[])
|
||||
|
||||
# Use a testgen_token_limit that:
|
||||
# - Is exceeded by full context with imported class (~1500 tokens)
|
||||
# - Is exceeded even after removing docstrings
|
||||
# - But fits when imported class is removed (~40 tokens)
|
||||
code_ctx = get_code_optimization_context(
|
||||
function_to_optimize=func_to_optimize,
|
||||
project_root_path=tmp_path,
|
||||
testgen_token_limit=200, # Small limit to trigger imported class removal
|
||||
)
|
||||
|
||||
# The testgen context should exist (didn't raise ValueError)
|
||||
testgen_context = code_ctx.testgen_context.markdown
|
||||
assert testgen_context, "Testgen context should not be empty"
|
||||
|
||||
# The target function should still be there
|
||||
assert "def target_function" in testgen_context, "Target function should be in testgen context"
|
||||
|
||||
# The large imported class should NOT be included (removed due to token limit)
|
||||
assert "class TypeClass" not in testgen_context, (
|
||||
"TypeClass should be removed from testgen context when exceeding token limit"
|
||||
)
|
||||
|
||||
|
||||
def test_testgen_raises_when_all_fallbacks_fail(tmp_path: Path) -> None:
|
||||
"""Test that ValueError is raised when testgen context exceeds limit even after all fallbacks.
|
||||
|
||||
This covers line 186 in code_context_extractor.py.
|
||||
"""
|
||||
def test_testgen_raises_when_exceeds_limit(tmp_path: Path) -> None:
|
||||
"""Test that ValueError is raised when testgen context exceeds token limit."""
|
||||
# Create a function with a very long body that exceeds limits even without imports/docstrings
|
||||
long_lines = [" x = 0"]
|
||||
for i in range(200):
|
||||
|
|
@ -4249,7 +4128,7 @@ def target_function():
|
|||
)
|
||||
|
||||
|
||||
def test_get_external_base_class_inits_attribute_base(tmp_path: Path) -> None:
|
||||
def test_enrich_testgen_context_attribute_base(tmp_path: Path) -> None:
|
||||
"""Test handling of base class accessed as module.ClassName (ast.Attribute).
|
||||
|
||||
This covers line 616 in code_context_extractor.py.
|
||||
|
|
@ -4265,7 +4144,7 @@ class MyDict(UserDict):
|
|||
code_path.write_text(code, encoding="utf-8")
|
||||
|
||||
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
|
||||
result = get_external_base_class_inits(context, tmp_path)
|
||||
result = enrich_testgen_context(context, tmp_path)
|
||||
|
||||
# Should extract UserDict __init__
|
||||
assert len(result.code_strings) == 1
|
||||
|
|
@ -4273,7 +4152,7 @@ class MyDict(UserDict):
|
|||
assert "def __init__" in result.code_strings[0].code
|
||||
|
||||
|
||||
def test_get_external_base_class_inits_no_init_method(tmp_path: Path) -> None:
|
||||
def test_enrich_testgen_context_no_init_method(tmp_path: Path) -> None:
|
||||
"""Test handling when base class has no __init__ method.
|
||||
|
||||
This covers line 641 in code_context_extractor.py.
|
||||
|
|
@ -4288,7 +4167,7 @@ class MyProtocol(Protocol):
|
|||
code_path.write_text(code, encoding="utf-8")
|
||||
|
||||
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
|
||||
result = get_external_base_class_inits(context, tmp_path)
|
||||
result = enrich_testgen_context(context, tmp_path)
|
||||
|
||||
# Protocol's __init__ can't be easily inspected, should handle gracefully
|
||||
# Result may be empty or contain Protocol based on implementation
|
||||
|
|
@ -4377,7 +4256,7 @@ class MyClass:
|
|||
|
||||
|
||||
def test_imported_class_definitions_module_path_none(tmp_path: Path) -> None:
|
||||
"""Test handling when module_path is None in get_imported_class_definitions.
|
||||
"""Test handling when module_path is None in enrich_testgen_context.
|
||||
|
||||
This covers line 560 in code_context_extractor.py.
|
||||
"""
|
||||
|
|
@ -4393,123 +4272,12 @@ class MyClass:
|
|||
code_path.write_text(code, encoding="utf-8")
|
||||
|
||||
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
|
||||
result = get_imported_class_definitions(context, tmp_path)
|
||||
result = enrich_testgen_context(context, tmp_path)
|
||||
|
||||
# Should handle gracefully and return empty or partial results
|
||||
assert isinstance(result.code_strings, list)
|
||||
|
||||
|
||||
def test_get_imported_names_import_star(tmp_path: Path) -> None:
|
||||
"""Test get_imported_names handles import * correctly.
|
||||
|
||||
This covers lines 808-809 and 824-825 in code_context_extractor.py.
|
||||
"""
|
||||
import libcst as cst
|
||||
|
||||
# Test regular import *
|
||||
# Note: "import *" is not valid Python, but "from x import *" is
|
||||
from_import_star = cst.parse_statement("from os import *")
|
||||
assert isinstance(from_import_star, cst.SimpleStatementLine)
|
||||
import_node = from_import_star.body[0]
|
||||
assert isinstance(import_node, cst.ImportFrom)
|
||||
|
||||
from codeflash.context.code_context_extractor import get_imported_names
|
||||
|
||||
result = get_imported_names(import_node)
|
||||
assert result == {"*"}
|
||||
|
||||
|
||||
def test_get_imported_names_aliased_import(tmp_path: Path) -> None:
|
||||
"""Test get_imported_names handles aliased imports correctly.
|
||||
|
||||
This covers lines 812-813 and 828-829 in code_context_extractor.py.
|
||||
"""
|
||||
import libcst as cst
|
||||
|
||||
from codeflash.context.code_context_extractor import get_imported_names
|
||||
|
||||
# Test import with alias
|
||||
import_stmt = cst.parse_statement("import numpy as np")
|
||||
assert isinstance(import_stmt, cst.SimpleStatementLine)
|
||||
import_node = import_stmt.body[0]
|
||||
assert isinstance(import_node, cst.Import)
|
||||
|
||||
result = get_imported_names(import_node)
|
||||
assert "np" in result
|
||||
|
||||
# Test from import with alias
|
||||
from_import_stmt = cst.parse_statement("from os import path as ospath")
|
||||
assert isinstance(from_import_stmt, cst.SimpleStatementLine)
|
||||
from_import_node = from_import_stmt.body[0]
|
||||
assert isinstance(from_import_node, cst.ImportFrom)
|
||||
|
||||
result2 = get_imported_names(from_import_node)
|
||||
assert "ospath" in result2
|
||||
|
||||
|
||||
def test_get_imported_names_dotted_import(tmp_path: Path) -> None:
|
||||
"""Test get_imported_names handles dotted imports correctly.
|
||||
|
||||
This covers lines 816-822 in code_context_extractor.py.
|
||||
"""
|
||||
import libcst as cst
|
||||
|
||||
from codeflash.context.code_context_extractor import get_imported_names
|
||||
|
||||
# Test dotted import like "import os.path"
|
||||
import_stmt = cst.parse_statement("import os.path")
|
||||
assert isinstance(import_stmt, cst.SimpleStatementLine)
|
||||
import_node = import_stmt.body[0]
|
||||
assert isinstance(import_node, cst.Import)
|
||||
|
||||
result = get_imported_names(import_node)
|
||||
assert "os" in result
|
||||
|
||||
|
||||
def test_used_name_collector_comprehensive(tmp_path: Path) -> None:
|
||||
"""Test UsedNameCollector handles various node types.
|
||||
|
||||
This covers lines 767-801 in code_context_extractor.py.
|
||||
"""
|
||||
import libcst as cst
|
||||
|
||||
from codeflash.context.code_context_extractor import UsedNameCollector
|
||||
|
||||
code = """
|
||||
import os
|
||||
from typing import List
|
||||
|
||||
x: int = 1
|
||||
y = os.path.join("a", "b")
|
||||
|
||||
class MyClass:
|
||||
z = 10
|
||||
|
||||
def my_func():
|
||||
pass
|
||||
"""
|
||||
module = cst.parse_module(code)
|
||||
collector = UsedNameCollector()
|
||||
# In libcst, the walker traverses the module
|
||||
cst.MetadataWrapper(module).visit(collector)
|
||||
|
||||
# Check used names
|
||||
assert "os" in collector.used_names
|
||||
assert "int" in collector.used_names
|
||||
assert "List" in collector.used_names
|
||||
|
||||
# Check defined names
|
||||
assert "x" in collector.defined_names
|
||||
assert "y" in collector.defined_names
|
||||
assert "MyClass" in collector.defined_names
|
||||
assert "my_func" in collector.defined_names
|
||||
|
||||
# Check external names (used but not defined)
|
||||
external = collector.get_external_names()
|
||||
assert "os" in external
|
||||
assert "x" not in external # x is defined
|
||||
|
||||
|
||||
def test_imported_class_with_base_in_same_module(tmp_path: Path) -> None:
|
||||
"""Test that imported classes with bases in the same module are extracted correctly.
|
||||
|
||||
|
|
@ -4549,52 +4317,13 @@ def target_function(obj: DerivedClass) -> bool:
|
|||
main_path.write_text(main_code, encoding="utf-8")
|
||||
|
||||
context = CodeStringsMarkdown(code_strings=[CodeString(code=main_code, file_path=main_path)])
|
||||
result = get_imported_class_definitions(context, tmp_path)
|
||||
result = enrich_testgen_context(context, tmp_path)
|
||||
|
||||
# Should extract the inheritance chain
|
||||
all_code = "\n".join(cs.code for cs in result.code_strings)
|
||||
assert "class BaseClass" in all_code or "class DerivedClass" in all_code
|
||||
|
||||
|
||||
def test_get_imported_names_from_import_without_alias(tmp_path: Path) -> None:
|
||||
"""Test get_imported_names handles from imports without aliases.
|
||||
|
||||
This covers lines 830-831 in code_context_extractor.py.
|
||||
"""
|
||||
import libcst as cst
|
||||
|
||||
from codeflash.context.code_context_extractor import get_imported_names
|
||||
|
||||
# Test from import without alias
|
||||
from_import_stmt = cst.parse_statement("from os import path, getcwd")
|
||||
assert isinstance(from_import_stmt, cst.SimpleStatementLine)
|
||||
from_import_node = from_import_stmt.body[0]
|
||||
assert isinstance(from_import_node, cst.ImportFrom)
|
||||
|
||||
result = get_imported_names(from_import_node)
|
||||
assert "path" in result
|
||||
assert "getcwd" in result
|
||||
|
||||
|
||||
def test_get_imported_names_regular_import(tmp_path: Path) -> None:
|
||||
"""Test get_imported_names handles regular imports.
|
||||
|
||||
This covers lines 814-815 in code_context_extractor.py.
|
||||
"""
|
||||
import libcst as cst
|
||||
|
||||
from codeflash.context.code_context_extractor import get_imported_names
|
||||
|
||||
# Test regular import without alias
|
||||
import_stmt = cst.parse_statement("import json")
|
||||
assert isinstance(import_stmt, cst.SimpleStatementLine)
|
||||
import_node = import_stmt.body[0]
|
||||
assert isinstance(import_node, cst.Import)
|
||||
|
||||
result = get_imported_names(import_node)
|
||||
assert "json" in result
|
||||
|
||||
|
||||
def test_augmented_assignment_not_in_context(tmp_path: Path) -> None:
|
||||
"""Test that augmented assignments are handled but not included unless used.
|
||||
|
||||
|
|
@ -4625,7 +4354,7 @@ class MyClass:
|
|||
assert "counter" in read_writable
|
||||
|
||||
|
||||
def test_get_external_class_inits_extracts_click_option(tmp_path: Path) -> None:
|
||||
def test_enrich_testgen_context_extracts_click_option(tmp_path: Path) -> None:
|
||||
"""Extracts __init__ from click.Option when directly imported."""
|
||||
code = """from click import Option
|
||||
|
||||
|
|
@ -4636,7 +4365,7 @@ def my_func(opt: Option) -> None:
|
|||
code_path.write_text(code, encoding="utf-8")
|
||||
|
||||
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
|
||||
result = get_external_class_inits(context, tmp_path)
|
||||
result = enrich_testgen_context(context, tmp_path)
|
||||
|
||||
assert len(result.code_strings) == 1
|
||||
code_string = result.code_strings[0]
|
||||
|
|
@ -4645,8 +4374,8 @@ def my_func(opt: Option) -> None:
|
|||
assert code_string.file_path is not None and "click" in code_string.file_path.as_posix()
|
||||
|
||||
|
||||
def test_get_external_class_inits_skips_project_classes(tmp_path: Path) -> None:
|
||||
"""Returns empty when imported class is from the project, not external."""
|
||||
def test_enrich_testgen_context_extracts_project_class_defs(tmp_path: Path) -> None:
|
||||
"""Extracts project class definitions via jedi resolution."""
|
||||
# Create a project module with a class
|
||||
(tmp_path / "mymodule.py").write_text("class ProjectClass:\n pass\n", encoding="utf-8")
|
||||
|
||||
|
|
@ -4659,12 +4388,13 @@ def my_func(obj: ProjectClass) -> None:
|
|||
code_path.write_text(code, encoding="utf-8")
|
||||
|
||||
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
|
||||
result = get_external_class_inits(context, tmp_path)
|
||||
result = enrich_testgen_context(context, tmp_path)
|
||||
|
||||
assert result.code_strings == []
|
||||
assert len(result.code_strings) == 1
|
||||
assert "class ProjectClass" in result.code_strings[0].code
|
||||
|
||||
|
||||
def test_get_external_class_inits_skips_non_classes(tmp_path: Path) -> None:
|
||||
def test_enrich_testgen_context_skips_non_classes(tmp_path: Path) -> None:
|
||||
"""Returns empty when imported name is a function, not a class."""
|
||||
code = """from collections import OrderedDict
|
||||
from os.path import join
|
||||
|
|
@ -4676,7 +4406,7 @@ def my_func() -> None:
|
|||
code_path.write_text(code, encoding="utf-8")
|
||||
|
||||
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
|
||||
result = get_external_class_inits(context, tmp_path)
|
||||
result = enrich_testgen_context(context, tmp_path)
|
||||
|
||||
# join is a function, not a class — should be skipped
|
||||
# OrderedDict is a class and should be included
|
||||
|
|
@ -4684,8 +4414,8 @@ def my_func() -> None:
|
|||
assert not any("join" in name for name in class_names)
|
||||
|
||||
|
||||
def test_get_external_class_inits_skips_already_defined_classes(tmp_path: Path) -> None:
|
||||
"""Skips classes already defined in the context (e.g., added by get_imported_class_definitions)."""
|
||||
def test_enrich_testgen_context_skips_already_defined_classes(tmp_path: Path) -> None:
|
||||
"""Skips classes already defined in the context (e.g., added by enrich_testgen_context)."""
|
||||
code = """from collections import UserDict
|
||||
|
||||
class UserDict:
|
||||
|
|
@ -4699,14 +4429,14 @@ def my_func(d: UserDict) -> None:
|
|||
code_path.write_text(code, encoding="utf-8")
|
||||
|
||||
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
|
||||
result = get_external_class_inits(context, tmp_path)
|
||||
result = enrich_testgen_context(context, tmp_path)
|
||||
|
||||
# UserDict is already defined in the context, so it should be skipped
|
||||
assert result.code_strings == []
|
||||
|
||||
|
||||
def test_get_external_class_inits_skips_builtins(tmp_path: Path) -> None:
|
||||
"""Returns empty for builtin classes like list/dict that have no inspectable source."""
|
||||
def test_enrich_testgen_context_skips_builtin_annotations(tmp_path: Path) -> None:
|
||||
"""Returns empty for builtin type annotations like list/dict that are not imported."""
|
||||
code = """x: list = []
|
||||
y: dict = {}
|
||||
|
||||
|
|
@ -4717,12 +4447,12 @@ def my_func() -> None:
|
|||
code_path.write_text(code, encoding="utf-8")
|
||||
|
||||
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
|
||||
result = get_external_class_inits(context, tmp_path)
|
||||
result = enrich_testgen_context(context, tmp_path)
|
||||
|
||||
assert result.code_strings == []
|
||||
|
||||
|
||||
def test_get_external_class_inits_skips_object_init(tmp_path: Path) -> None:
|
||||
def test_enrich_testgen_context_skips_object_init(tmp_path: Path) -> None:
|
||||
"""Skips classes whose __init__ is just object.__init__ (trivial)."""
|
||||
# enum.Enum has a metaclass-based __init__, but individual enum members
|
||||
# effectively use object.__init__. Use a class we know has object.__init__.
|
||||
|
|
@ -4735,14 +4465,14 @@ def my_func(q: QName) -> None:
|
|||
code_path.write_text(code, encoding="utf-8")
|
||||
|
||||
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
|
||||
result = get_external_class_inits(context, tmp_path)
|
||||
result = enrich_testgen_context(context, tmp_path)
|
||||
|
||||
# QName has its own __init__, so it should be included if it's in site-packages.
|
||||
# But since it's stdlib (not site-packages), it should be skipped.
|
||||
assert result.code_strings == []
|
||||
|
||||
|
||||
def test_get_external_class_inits_empty_when_no_imports(tmp_path: Path) -> None:
|
||||
def test_enrich_testgen_context_empty_when_no_imports(tmp_path: Path) -> None:
|
||||
"""Returns empty when there are no from-imports."""
|
||||
code = """def my_func() -> None:
|
||||
pass
|
||||
|
|
@ -4751,7 +4481,7 @@ def test_get_external_class_inits_empty_when_no_imports(tmp_path: Path) -> None:
|
|||
code_path.write_text(code, encoding="utf-8")
|
||||
|
||||
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
|
||||
result = get_external_class_inits(context, tmp_path)
|
||||
result = enrich_testgen_context(context, tmp_path)
|
||||
|
||||
assert result.code_strings == []
|
||||
|
||||
|
|
@ -4840,17 +4570,17 @@ def test_resolve_transitive_type_deps_handles_failure_gracefully() -> None:
|
|||
"""Returns empty list for a class where get_type_hints fails."""
|
||||
|
||||
class BadClass:
|
||||
def __init__(self, x: "NonexistentType") -> None: # type: ignore[name-defined] # noqa: F821
|
||||
def __init__(self, x: NonexistentType) -> None: # type: ignore[name-defined] # noqa: F821
|
||||
pass
|
||||
|
||||
result = resolve_transitive_type_deps(BadClass)
|
||||
assert result == []
|
||||
|
||||
|
||||
# --- Integration tests for transitive resolution in get_external_class_inits ---
|
||||
# --- Integration tests for transitive resolution in enrich_testgen_context ---
|
||||
|
||||
|
||||
def test_get_external_class_inits_transitive_deps(tmp_path: Path) -> None:
|
||||
def test_enrich_testgen_context_transitive_deps(tmp_path: Path) -> None:
|
||||
"""Extracts transitive type dependencies from __init__ annotations."""
|
||||
code = """from click import Context
|
||||
|
||||
|
|
@ -4861,7 +4591,7 @@ def my_func(ctx: Context) -> None:
|
|||
code_path.write_text(code, encoding="utf-8")
|
||||
|
||||
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
|
||||
result = get_external_class_inits(context, tmp_path)
|
||||
result = enrich_testgen_context(context, tmp_path)
|
||||
|
||||
class_names = {cs.code.split("\n")[0].replace("class ", "").rstrip(":") for cs in result.code_strings}
|
||||
assert "Context" in class_names
|
||||
|
|
@ -4869,7 +4599,7 @@ def my_func(ctx: Context) -> None:
|
|||
assert "Command" in class_names
|
||||
|
||||
|
||||
def test_get_external_class_inits_no_infinite_loops(tmp_path: Path) -> None:
|
||||
def test_enrich_testgen_context_no_infinite_loops(tmp_path: Path) -> None:
|
||||
"""Handles classes with circular type references without infinite loops."""
|
||||
# click.Context references Command, and Command references Context back
|
||||
# This should terminate without issues due to the processed_classes set
|
||||
|
|
@ -4882,13 +4612,13 @@ def my_func(ctx: Context) -> None:
|
|||
code_path.write_text(code, encoding="utf-8")
|
||||
|
||||
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
|
||||
result = get_external_class_inits(context, tmp_path)
|
||||
result = enrich_testgen_context(context, tmp_path)
|
||||
|
||||
# Should complete without hanging; just verify we got results
|
||||
assert len(result.code_strings) >= 1
|
||||
|
||||
|
||||
def test_get_external_class_inits_no_duplicate_stubs(tmp_path: Path) -> None:
|
||||
def test_enrich_testgen_context_no_duplicate_stubs(tmp_path: Path) -> None:
|
||||
"""Does not emit duplicate stubs for the same class name."""
|
||||
code = """from click import Context
|
||||
|
||||
|
|
@ -4899,7 +4629,7 @@ def my_func(ctx: Context) -> None:
|
|||
code_path.write_text(code, encoding="utf-8")
|
||||
|
||||
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
|
||||
result = get_external_class_inits(context, tmp_path)
|
||||
result = enrich_testgen_context(context, tmp_path)
|
||||
|
||||
class_names = [cs.code.split("\n")[0].replace("class ", "").rstrip(":") for cs in result.code_strings]
|
||||
assert len(class_names) == len(set(class_names)), f"Duplicate class stubs found: {class_names}"
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ from textwrap import dedent
|
|||
|
||||
import pytest
|
||||
|
||||
from codeflash.context.code_context_extractor import parse_code_and_prune_cst
|
||||
from codeflash.languages.python.context.code_context_extractor import parse_code_and_prune_cst
|
||||
from codeflash.models.models import CodeContextType
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ from textwrap import dedent
|
|||
|
||||
import pytest
|
||||
|
||||
from codeflash.context.code_context_extractor import parse_code_and_prune_cst
|
||||
from codeflash.languages.python.context.code_context_extractor import parse_code_and_prune_cst
|
||||
from codeflash.models.models import CodeContextType
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ from textwrap import dedent
|
|||
|
||||
import pytest
|
||||
|
||||
from codeflash.context.code_context_extractor import parse_code_and_prune_cst
|
||||
from codeflash.languages.python.context.code_context_extractor import parse_code_and_prune_cst
|
||||
from codeflash.models.models import CodeContextType
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -20,14 +20,12 @@ All assertions use strict string equality to verify exact extraction output.
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from codeflash.context.code_context_extractor import get_code_optimization_context_for_language
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.languages.base import Language
|
||||
from codeflash.languages.javascript.support import JavaScriptSupport, TypeScriptSupport
|
||||
from codeflash.languages.python.context.code_context_extractor import get_code_optimization_context_for_language
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
|
|||
|
|
@ -106,9 +106,9 @@ class TestJavaScriptCodeContext:
|
|||
def test_extract_code_context_for_javascript(self, js_project_dir):
|
||||
"""Test extracting code context for a JavaScript function."""
|
||||
skip_if_js_not_supported()
|
||||
from codeflash.context.code_context_extractor import get_code_optimization_context
|
||||
from codeflash.discovery.functions_to_optimize import find_all_functions_in_file
|
||||
from codeflash.languages import current as lang_current
|
||||
from codeflash.languages.python.context.code_context_extractor import get_code_optimization_context
|
||||
|
||||
lang_current._current_language = Language.JAVASCRIPT
|
||||
|
||||
|
|
|
|||
|
|
@ -9,7 +9,6 @@ These tests verify the full optimization pipeline including:
|
|||
This is the JavaScript equivalent of test_instrument_tests.py for Python.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
|
@ -71,9 +70,9 @@ module.exports = { add };
|
|||
def test_code_context_preserves_language(self, tmp_path):
|
||||
"""Verify language is preserved in code context extraction."""
|
||||
skip_if_js_not_supported()
|
||||
from codeflash.context.code_context_extractor import get_code_optimization_context
|
||||
from codeflash.discovery.functions_to_optimize import find_all_functions_in_file
|
||||
from codeflash.languages import current as lang_current
|
||||
from codeflash.languages.python.context.code_context_extractor import get_code_optimization_context
|
||||
|
||||
lang_current._current_language = Language.TYPESCRIPT
|
||||
|
||||
|
|
@ -164,7 +163,7 @@ export function add(a: number, b: number): number {
|
|||
|
||||
# Mock the AI service request
|
||||
ai_client = AiServiceClient()
|
||||
with patch.object(ai_client, 'make_ai_service_request') as mock_request:
|
||||
with patch.object(ai_client, "make_ai_service_request") as mock_request:
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
|
|
@ -191,8 +190,8 @@ export function add(a: number, b: number): number {
|
|||
# Verify the request was made with correct language
|
||||
assert mock_request.called, "API request should have been made"
|
||||
call_args = mock_request.call_args
|
||||
payload = call_args[1].get('payload', call_args[0][1] if len(call_args[0]) > 1 else {})
|
||||
assert payload.get('language') == 'typescript', \
|
||||
payload = call_args[1].get("payload", call_args[0][1] if len(call_args[0]) > 1 else {})
|
||||
assert payload.get("language") == "typescript", \
|
||||
f"Expected language='typescript', got language='{payload.get('language')}'"
|
||||
|
||||
|
||||
|
|
@ -462,7 +461,7 @@ class TestHelperFunctionLanguageAttribute:
|
|||
"""Verify helper functions have language='javascript' for .js files."""
|
||||
skip_if_js_not_supported()
|
||||
from codeflash.discovery.functions_to_optimize import find_all_functions_in_file
|
||||
from codeflash.languages import current as lang_current, get_language_support
|
||||
from codeflash.languages import current as lang_current
|
||||
from codeflash.optimization.function_optimizer import FunctionOptimizer
|
||||
|
||||
lang_current._current_language = Language.JAVASCRIPT
|
||||
|
|
|
|||
|
|
@ -69,7 +69,7 @@ class TestTypeScriptFunctionDiscovery:
|
|||
from codeflash.discovery.functions_to_optimize import find_all_functions_in_file
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=".ts", mode="w", delete=False) as f:
|
||||
f.write("""
|
||||
f.write(r"""
|
||||
export function add(a: number, b: number): number {
|
||||
return a + b;
|
||||
}
|
||||
|
|
@ -123,9 +123,9 @@ class TestTypeScriptCodeContext:
|
|||
def test_extract_code_context_for_typescript(self, ts_project_dir):
|
||||
"""Test extracting code context for a TypeScript function."""
|
||||
skip_if_ts_not_supported()
|
||||
from codeflash.context.code_context_extractor import get_code_optimization_context
|
||||
from codeflash.discovery.functions_to_optimize import find_all_functions_in_file
|
||||
from codeflash.languages import current as lang_current
|
||||
from codeflash.languages.python.context.code_context_extractor import get_code_optimization_context
|
||||
|
||||
lang_current._current_language = Language.TYPESCRIPT
|
||||
|
||||
|
|
@ -201,7 +201,7 @@ function multiply(a: number, b: number): number {
|
|||
from codeflash.languages import get_language_support
|
||||
from codeflash.languages.base import FunctionInfo
|
||||
|
||||
original_source = """
|
||||
original_source = r"""
|
||||
interface Config {
|
||||
timeout: number;
|
||||
retries: number;
|
||||
|
|
@ -212,7 +212,7 @@ function processConfig(config: Config): string {
|
|||
}
|
||||
"""
|
||||
|
||||
new_function = """function processConfig(config: Config): string {
|
||||
new_function = r"""function processConfig(config: Config): string {
|
||||
// Optimized with template caching
|
||||
const { timeout, retries } = config;
|
||||
return `timeout=\${timeout}, retries=\${retries}`;
|
||||
|
|
|
|||
|
|
@ -117,10 +117,10 @@ class TestVitestCodeContext:
|
|||
def test_extract_code_context_for_typescript(self, vitest_project_dir):
|
||||
"""Test extracting code context for a TypeScript function."""
|
||||
skip_if_js_not_supported()
|
||||
from codeflash.context.code_context_extractor import get_code_optimization_context
|
||||
from codeflash.discovery.functions_to_optimize import find_all_functions_in_file
|
||||
from codeflash.languages import current as lang_current
|
||||
from codeflash.languages.base import Language
|
||||
from codeflash.languages.python.context.code_context_extractor import get_code_optimization_context
|
||||
|
||||
lang_current._current_language = Language.TYPESCRIPT
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
|
||||
|
||||
from codeflash.context.unused_definition_remover import remove_unused_definitions_by_function_names
|
||||
from codeflash.languages.python.context.unused_definition_remover import remove_unused_definitions_by_function_names
|
||||
|
||||
|
||||
def test_variable_removal_only() -> None:
|
||||
|
|
|
|||
|
|
@ -5,8 +5,11 @@ from pathlib import Path
|
|||
|
||||
import pytest
|
||||
|
||||
from codeflash.context.unused_definition_remover import detect_unused_helper_functions, revert_unused_helper_functions
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.languages.python.context.unused_definition_remover import (
|
||||
detect_unused_helper_functions,
|
||||
revert_unused_helper_functions,
|
||||
)
|
||||
from codeflash.models.models import CodeStringsMarkdown
|
||||
from codeflash.optimization.function_optimizer import FunctionOptimizer
|
||||
from codeflash.verification.verification_utils import TestConfig
|
||||
|
|
|
|||
Loading…
Reference in a new issue