keep the referenced global definitions

This commit is contained in:
ali 2025-11-21 20:10:26 +02:00
parent ad09525b7d
commit 15d2027bb0
No known key found for this signature in database
GPG key ID: 44F9B42770617B9B
3 changed files with 218 additions and 54 deletions

View file

@ -12,7 +12,11 @@ import libcst as cst
from codeflash.cli_cmds.console import logger
from codeflash.code_utils.code_extractor import add_needed_imports_from_module, find_preexisting_objects
from codeflash.code_utils.code_utils import encoded_tokens_len, get_qualified_name, path_belongs_to_site_packages
from codeflash.context.unused_definition_remover import remove_unused_definitions_by_function_names
from codeflash.context.unused_definition_remover import (
collect_top_level_defs_with_usages,
extract_names_from_targets,
remove_unused_definitions_by_function_names,
)
from codeflash.discovery.functions_to_optimize import FunctionToOptimize # noqa: TC001
from codeflash.models.models import (
CodeContextType,
@ -29,6 +33,8 @@ if TYPE_CHECKING:
from jedi.api.classes import Name
from libcst import CSTNode
from codeflash.context.unused_definition_remover import UsageInfo
def get_code_optimization_context(
function_to_optimize: FunctionToOptimize,
@ -498,8 +504,10 @@ def parse_code_and_prune_cst(
) -> str:
"""Create a read-only version of the code by parsing and filtering the code to keep only class contextual information, and other module scoped variables."""
module = cst.parse_module(code)
defs_with_usages = collect_top_level_defs_with_usages(module, target_functions | helpers_of_helper_functions)
if code_context_type == CodeContextType.READ_WRITABLE:
filtered_node, found_target = prune_cst_for_read_writable_code(module, target_functions)
filtered_node, found_target = prune_cst_for_read_writable_code(module, target_functions, defs_with_usages)
elif code_context_type == CodeContextType.READ_ONLY:
filtered_node, found_target = prune_cst_for_read_only_code(
module, target_functions, helpers_of_helper_functions, remove_docstrings=remove_docstrings
@ -524,7 +532,7 @@ def parse_code_and_prune_cst(
def prune_cst_for_read_writable_code( # noqa: PLR0911
node: cst.CSTNode, target_functions: set[str], prefix: str = ""
node: cst.CSTNode, target_functions: set[str], defs_with_usages: dict[str, UsageInfo], prefix: str = ""
) -> tuple[cst.CSTNode | None, bool]:
"""Recursively filter the node and its children to build the read-writable codeblock. This contains nodes that lead to target functions.
@ -569,6 +577,21 @@ def prune_cst_for_read_writable_code( # noqa: PLR0911
return node.with_changes(body=cst.IndentedBlock(body=new_body)), found_target
if isinstance(node, cst.Assign):
for target in node.targets:
names = extract_names_from_targets(target.target)
for name in names:
if name in defs_with_usages and defs_with_usages[name].used_by_qualified_function:
return node, True
return None, False
if isinstance(node, (cst.AnnAssign, cst.AugAssign)):
names = extract_names_from_targets(node.target)
for name in names:
if name in defs_with_usages and defs_with_usages[name].used_by_qualified_function:
return node, True
return None, False
# For other nodes, we preserve them only if they contain target functions in their children.
section_names = get_section_names(node)
if not section_names:
@ -583,7 +606,9 @@ def prune_cst_for_read_writable_code( # noqa: PLR0911
new_children = []
section_found_target = False
for child in original_content:
filtered, found_target = prune_cst_for_read_writable_code(child, target_functions, prefix)
filtered, found_target = prune_cst_for_read_writable_code(
child, target_functions, defs_with_usages, prefix
)
if filtered:
new_children.append(filtered)
section_found_target |= found_target
@ -592,7 +617,9 @@ def prune_cst_for_read_writable_code( # noqa: PLR0911
found_any_target = True
updates[section] = new_children
elif original_content is not None:
filtered, found_target = prune_cst_for_read_writable_code(original_content, target_functions, prefix)
filtered, found_target = prune_cst_for_read_writable_code(
original_content, target_functions, defs_with_usages, prefix
)
if found_target:
found_any_target = True
if filtered:
@ -600,7 +627,6 @@ def prune_cst_for_read_writable_code( # noqa: PLR0911
if not found_any_target:
return None, False
return (node.with_changes(**updates) if updates else node), True

View file

@ -5,7 +5,7 @@ from collections import defaultdict
from dataclasses import dataclass, field
from itertools import chain
from pathlib import Path
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Optional, Union
import libcst as cst
@ -122,6 +122,8 @@ def get_section_names(node: cst.CSTNode) -> list[str]:
class DependencyCollector(cst.CSTVisitor):
"""Collects dependencies between definitions using the visitor pattern with depth tracking."""
METADATA_DEPENDENCIES = (cst.metadata.ParentNodeProvider,)
def __init__(self, definitions: dict[str, UsageInfo]) -> None:
super().__init__()
self.definitions = definitions
@ -259,8 +261,12 @@ class DependencyCollector(cst.CSTVisitor):
if self.processing_variable and name in self.current_variable_names:
return
# Check if name is a top-level definition we're tracking
if name in self.definitions and name != self.current_top_level_name:
# skip if we are referencing a class attribute and not a top-level definition
if self.class_depth > 0:
parent = self.get_metadata(cst.metadata.ParentNodeProvider, node)
if parent is not None and isinstance(parent, cst.Attribute):
return
self.definitions[self.current_top_level_name].dependencies.add(name)
@ -293,13 +299,19 @@ class QualifiedFunctionUsageMarker:
def mark_used_definitions(self) -> None:
"""Find all qualified functions and mark them and their dependencies as used."""
# First identify all specified functions (including expanded ones)
functions_to_mark = [name for name in self.expanded_qualified_functions if name in self.definitions]
# Avoid list comprehension for set intersection
expanded_names = self.expanded_qualified_functions
defs = self.definitions
functions_to_mark = (
expanded_names & defs.keys()
if isinstance(expanded_names, set)
else [name for name in expanded_names if name in defs]
)
# For each specified function, mark it and all its dependencies as used
for func_name in functions_to_mark:
self.definitions[func_name].used_by_qualified_function = True
for dep in self.definitions[func_name].dependencies:
defs[func_name].used_by_qualified_function = True
for dep in defs[func_name].dependencies:
self.mark_as_used_recursively(dep)
def mark_as_used_recursively(self, name: str) -> None:
@ -457,7 +469,28 @@ def remove_unused_definitions_recursively( # noqa: PLR0911
return node, False
def remove_unused_definitions_by_function_names(code: str, qualified_function_names: set[str]) -> str:
def collect_top_level_defs_with_usages(
    code: Union[str, cst.Module], qualified_function_names: set[str]
) -> dict[str, UsageInfo]:
    """Collect all top level definitions (classes, variables or functions) and their usages.

    Returns a mapping from each top-level name to its ``UsageInfo``; entries
    reachable from ``qualified_function_names`` (directly or transitively)
    are marked as used.
    """
    # Accept either raw source text or an already-parsed libcst module.
    if isinstance(code, cst.Module):
        module = code
    else:
        module = cst.parse_module(code)

    # Gather every top-level class, variable and function definition.
    definitions = collect_top_level_definitions(module)

    # Resolve dependencies between definitions; the MetadataWrapper supplies
    # the parent-node metadata the collector declares as a dependency.
    collector = DependencyCollector(definitions)
    cst.MetadataWrapper(module).visit(collector)

    # Mark the qualified functions and everything they depend on, recursively.
    QualifiedFunctionUsageMarker(definitions, qualified_function_names).mark_used_definitions()
    return definitions
def remove_unused_definitions_by_function_names(
code: str, qualified_function_names: set[str]
) -> tuple[str, dict[str, UsageInfo]]:
"""Analyze a file and remove top level definitions not used by specified functions.
Top level definitions, in this context, are only classes, variables or functions.
@ -476,19 +509,10 @@ def remove_unused_definitions_by_function_names(code: str, qualified_function_na
return code
try:
# Collect all definitions (top level classes, variables or function)
definitions = collect_top_level_definitions(module)
# Collect dependencies between definitions using the visitor pattern
dependency_collector = DependencyCollector(definitions)
module.visit(dependency_collector)
# Mark definitions used by specified functions, and their dependencies recursively
usage_marker = QualifiedFunctionUsageMarker(definitions, qualified_function_names)
usage_marker.mark_used_definitions()
defs_with_usages = collect_top_level_defs_with_usages(module, qualified_function_names)
# Apply the recursive removal transformation
modified_module, _ = remove_unused_definitions_recursively(module, definitions)
modified_module, _ = remove_unused_definitions_recursively(module, defs_with_usages)
return modified_module.code if modified_module else "" # noqa: TRY300
except Exception as e:

View file

@ -459,6 +459,9 @@ class _PersistentCache(Generic[_P, _R, _CacheBackendT]):
hashing_context = code_ctx.hashing_code_context
expected_read_write_context = f"""
```python:{file_path.relative_to(opt.args.project_root)}
_P = ParamSpec("_P")
_KEY_T = TypeVar("_KEY_T")
_STORE_T = TypeVar("_STORE_T")
class AbstractCacheBackend(CacheBackend, Protocol[_KEY_T, _STORE_T]):
def __init__(self) -> None: ...
@ -517,6 +520,10 @@ class AbstractCacheBackend(CacheBackend, Protocol[_KEY_T, _STORE_T]):
# If encoding fails, we should still return the result.
return result
_P = ParamSpec("_P")
_R = TypeVar("_R")
_CacheBackendT = TypeVar("_CacheBackendT", bound=CacheBackend)
class _PersistentCache(Generic[_P, _R, _CacheBackendT]):
@ -752,7 +759,7 @@ def test_example_class_token_limit_1(tmp_path: Path) -> None:
)
code = f"""
class MyClass:
\"\"\"A class with a helper method.
\"\"\"A class with a helper method.
{docstring_filler}\"\"\"
def __init__(self):
self.x = 1
@ -910,7 +917,17 @@ class HelperClass:
return self.x
```
"""
expected_read_only_context = ""
expected_read_only_context = f'''```python:{file_path.relative_to(opt.args.project_root)}
class MyClass:
"""A class with a helper method. """
class HelperClass:
"""A helper class for MyClass."""
def __repr__(self):
"""Return a string representation of the HelperClass."""
return "HelperClass" + str(self.x)
```
'''
expected_hashing_context = f"""
```python:{file_path.relative_to(opt.args.project_root)}
class MyClass:
@ -984,6 +1001,59 @@ def test_example_class_token_limit_4(tmp_path: Path) -> None:
["This is a long string that will be used to fill up the token limit." for _ in range(1000)]
)
code = f"""
class MyClass:
\"\"\"A class with a helper method. \"\"\"
def __init__(self):
global x
x = 1
def target_method(self):
\"\"\"Docstring for target method\"\"\"
y = HelperClass().helper_method()
x = '{string_filler}'
class HelperClass:
\"\"\"A helper class for MyClass.\"\"\"
def __init__(self):
\"\"\"Initialize the HelperClass.\"\"\"
self.x = 1
def __repr__(self):
\"\"\"Return a string representation of the HelperClass.\"\"\"
return "HelperClass" + str(self.x)
def helper_method(self):
return self.x
"""
# Create a temporary Python file using pytest's tmp_path fixture
file_path = tmp_path / "test_code.py"
file_path.write_text(code, encoding="utf-8")
opt = Optimizer(
Namespace(
project_root=file_path.parent.resolve(),
disable_telemetry=True,
tests_root="tests",
test_framework="pytest",
pytest_cmd="pytest",
experiment_id=None,
test_project_root=Path().resolve(),
)
)
function_to_optimize = FunctionToOptimize(
function_name="target_method",
file_path=file_path,
parents=[FunctionParent(name="MyClass", type="ClassDef")],
starting_line=None,
ending_line=None,
)
# In this scenario, the read-writable code context is too long because the __init__ function is referencing the global x variable, not the class attribute (x), so we abort.
with pytest.raises(ValueError, match="Read-writable code has exceeded token limit, cannot proceed"):
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
def test_example_class_token_limit_5(tmp_path: Path) -> None:
string_filler = " ".join(
["This is a long string that will be used to fill up the token limit." for _ in range(1000)]
)
code = f"""
class MyClass:
\"\"\"A class with a helper method. \"\"\"
def __init__(self):
@ -1026,9 +1096,44 @@ class HelperClass:
ending_line=None,
)
# In this scenario, the testgen code context is too long, so we abort.
with pytest.raises(ValueError, match="Testgen code context has exceeded token limit, cannot proceed"):
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
# the global x variable shouldn't be included in any context type
assert code_ctx.read_writable_code.flat == '''# file: test_code.py
class MyClass:
def __init__(self):
self.x = 1
def target_method(self):
"""Docstring for target method"""
y = HelperClass().helper_method()
class HelperClass:
def __init__(self):
"""Initialize the HelperClass."""
self.x = 1
def helper_method(self):
return self.x
'''
assert code_ctx.testgen_context.flat == '''# file: test_code.py
class MyClass:
"""A class with a helper method. """
def __init__(self):
self.x = 1
def target_method(self):
"""Docstring for target method"""
y = HelperClass().helper_method()
class HelperClass:
"""A helper class for MyClass."""
def __init__(self):
"""Initialize the HelperClass."""
self.x = 1
def __repr__(self):
"""Return a string representation of the HelperClass."""
return "HelperClass" + str(self.x)
def helper_method(self):
return self.x
'''
def test_repo_helper() -> None:
@ -2070,8 +2175,17 @@ def get_system_details():
relative_path = file_path.relative_to(project_root)
expected_read_write_context = f"""
```python:utility_module.py
# Function that will be used in the main code
DEFAULT_PRECISION = "medium"
# Try-except block with variable definitions
try:
# Used variable in try block
CALCULATION_BACKEND = "numpy"
except ImportError:
# Used variable in except block
CALCULATION_BACKEND = "python"
# Function that will be used in the main code
def select_precision(precision, fallback_precision):
if precision is None:
return fallback_precision or DEFAULT_PRECISION
@ -2466,12 +2580,12 @@ def test_circular_deps():
project_root_path= Path(path_to_root),
)
assert "import ApiClient" not in new_code, "Error: Circular dependency found"
assert "import urllib.parse" in new_code, "Make sure imports for optimization global assignments exist"
assert "import urllib.parse" in new_code, "Make sure imports for optimization global assignments exist"
def test_global_assignment_collector_with_async_function():
"""Test GlobalAssignmentCollector correctly identifies global assignments outside async functions."""
import libcst as cst
source_code = """
# Global assignment
GLOBAL_VAR = "global_value"
@ -2486,21 +2600,21 @@ async def async_function():
# Another global assignment
ANOTHER_GLOBAL = "another_global"
"""
tree = cst.parse_module(source_code)
collector = GlobalAssignmentCollector()
tree.visit(collector)
# Should collect global assignments but not the ones inside async function
assert len(collector.assignments) == 3
assert "GLOBAL_VAR" in collector.assignments
assert "OTHER_GLOBAL" in collector.assignments
assert "ANOTHER_GLOBAL" in collector.assignments
# Should not collect assignments from inside async function
assert "local_var" not in collector.assignments
assert "INNER_ASSIGNMENT" not in collector.assignments
# Verify assignment order
expected_order = ["GLOBAL_VAR", "OTHER_GLOBAL", "ANOTHER_GLOBAL"]
assert collector.assignment_order == expected_order
@ -2509,7 +2623,7 @@ ANOTHER_GLOBAL = "another_global"
def test_global_assignment_collector_nested_async_functions():
"""Test GlobalAssignmentCollector handles nested async functions correctly."""
import libcst as cst
source_code = """
# Global assignment
CONFIG = {"key": "value"}
@ -2517,38 +2631,38 @@ CONFIG = {"key": "value"}
def sync_function():
# Inside sync function - should not be collected
sync_local = "sync"
async def nested_async():
# Inside nested async function - should not be collected
nested_var = "nested"
return nested_var
return sync_local
async def async_function():
# Inside async function - should not be collected
async_local = "async"
def nested_sync():
# Inside nested function - should not be collected
deeply_nested = "deep"
return deeply_nested
return async_local
# Another global assignment
FINAL_GLOBAL = "final"
"""
tree = cst.parse_module(source_code)
collector = GlobalAssignmentCollector()
tree.visit(collector)
# Should only collect global-level assignments
assert len(collector.assignments) == 2
assert "CONFIG" in collector.assignments
assert "FINAL_GLOBAL" in collector.assignments
# Should not collect any assignments from inside functions
assert "sync_local" not in collector.assignments
assert "nested_var" not in collector.assignments
@ -2559,20 +2673,20 @@ FINAL_GLOBAL = "final"
def test_global_assignment_collector_mixed_async_sync_with_classes():
"""Test GlobalAssignmentCollector with async functions, sync functions, and classes."""
import libcst as cst
source_code = """
# Global assignments
GLOBAL_CONSTANT = "constant"
class TestClass:
# Class-level assignment - should not be collected
# Class-level assignment - should not be collected
class_var = "class_value"
def sync_method(self):
# Method assignment - should not be collected
method_var = "method"
return method_var
async def async_method(self):
# Async method assignment - should not be collected
async_method_var = "async_method"
@ -2592,24 +2706,24 @@ async def async_function():
ANOTHER_CONSTANT = 100
FINAL_ASSIGNMENT = {"data": "value"}
"""
tree = cst.parse_module(source_code)
collector = GlobalAssignmentCollector()
tree.visit(collector)
# Should only collect global-level assignments
assert len(collector.assignments) == 3
assert "GLOBAL_CONSTANT" in collector.assignments
assert "GLOBAL_CONSTANT" in collector.assignments
assert "ANOTHER_CONSTANT" in collector.assignments
assert "FINAL_ASSIGNMENT" in collector.assignments
# Should not collect assignments from inside any scoped blocks
assert "class_var" not in collector.assignments
assert "method_var" not in collector.assignments
assert "async_method_var" not in collector.assignments
assert "func_var" not in collector.assignments
assert "async_func_var" not in collector.assignments
# Verify correct order
expected_order = ["GLOBAL_CONSTANT", "ANOTHER_CONSTANT", "FINAL_ASSIGNMENT"]
assert collector.assignment_order == expected_order