Handle newly added classes

This commit is contained in:
ali 2025-11-11 16:20:11 +02:00
parent 6696f079e9
commit b4294f8a90
No known key found for this signature in database
GPG key ID: 44F9B42770617B9B
3 changed files with 198 additions and 45 deletions

View file

@ -72,6 +72,33 @@ class GlobalAssignmentCollector(cst.CSTVisitor):
return True
def find_insertion_index_after_imports(node: cst.Module) -> int:
    """Return the index just past the last leading import of a module.

    The top-level statements of ``node`` are walked in order and the slot
    right after the most recent import is remembered. A statement counts as
    an import when it is either:

    * a simple statement line containing at least one ``import`` /
      ``from ... import``; or
    * an ``if`` statement whose body consists solely of statement lines that
      contain nothing but imports (only the ``if`` body is inspected).

    Args:
        node: The parsed module whose top-level body is scanned.

    Returns:
        The position in ``node.body`` where new statements can be inserted so
        that they follow the imports. ``0`` when no import is found before the
        first class or function definition (or before the end of the module).
    """
    import_types = (cst.Import, cst.ImportFrom)
    definition_types = (cst.ClassDef, cst.FunctionDef)

    def _is_import_statement(statement) -> bool:
        # Plain line: at least one of its small statements is an import.
        if isinstance(statement, cst.SimpleStatementLine):
            return any(isinstance(part, import_types) for part in statement.body)
        # Conditional block: every line of the `if` body must be import-only.
        if isinstance(statement, cst.If):
            return all(
                isinstance(line, cst.SimpleStatementLine)
                and all(isinstance(part, import_types) for part in line.body)
                for line in statement.body.body
            )
        return False

    last_import_end = 0
    for position, statement in enumerate(node.body, start=1):
        # Imports normally sit at the top of a file, yet they may legally
        # appear anywhere, even at the very bottom. Scanning stops at the
        # first class or function definition so that a stray late import
        # cannot push the insertion point below real code definitions.
        if isinstance(statement, definition_types):
            break
        if _is_import_statement(statement):
            last_import_end = position
    return last_import_end
class GlobalAssignmentTransformer(cst.CSTTransformer):
"""Transforms global assignments in the original file with those from the new file."""
@ -122,32 +149,6 @@ class GlobalAssignmentTransformer(cst.CSTTransformer):
return updated_node
def _find_insertion_index(self, updated_node: cst.Module) -> int:
    """Find the position right after the last top-level import of the module.

    This method used to carry its own copy of the import-scanning loop. The
    same logic lives in the module-level ``find_insertion_index_after_imports``
    helper, so the work is delegated there to keep a single implementation.

    Args:
        updated_node: The module whose top-level statements are scanned.

    Returns:
        The index in ``updated_node.body`` at which new statements should be
        inserted so that they follow the leading imports.
    """
    return find_insertion_index_after_imports(updated_node)
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: # noqa: ARG002
# Add any new assignments that weren't in the original file
new_statements = list(updated_node.body)
@ -161,7 +162,7 @@ class GlobalAssignmentTransformer(cst.CSTTransformer):
if assignments_to_append:
# after last top-level imports
insert_index = self._find_insertion_index(updated_node)
insert_index = find_insertion_index_after_imports(updated_node)
assignment_lines = [
cst.SimpleStatementLine([assignment], leading_lines=[cst.EmptyLine()])

View file

@ -3,13 +3,18 @@ from __future__ import annotations
import ast
from collections import defaultdict
from functools import lru_cache
from itertools import chain
from typing import TYPE_CHECKING, Optional, TypeVar
import libcst as cst
from libcst.metadata import PositionProvider
from codeflash.cli_cmds.console import logger
from codeflash.code_utils.code_extractor import add_global_assignments, add_needed_imports_from_module
from codeflash.code_utils.code_extractor import (
add_global_assignments,
add_needed_imports_from_module,
find_insertion_index_after_imports,
)
from codeflash.code_utils.config_parser import find_conftest_files
from codeflash.code_utils.formatter import sort_imports
from codeflash.code_utils.line_profile_utils import ImportAdder
@ -249,6 +254,7 @@ class OptimFunctionCollector(cst.CSTVisitor):
] = {} # keys are (class_name, function_name)
self.new_functions: list[cst.FunctionDef] = []
self.new_class_functions: dict[str, list[cst.FunctionDef]] = defaultdict(list)
self.new_classes: list[cst.ClassDef] = []
self.current_class = None
self.modified_init_functions: dict[str, cst.FunctionDef] = {}
@ -271,6 +277,11 @@ class OptimFunctionCollector(cst.CSTVisitor):
self.current_class = node.name.value
parents = (FunctionParent(name=node.name.value, type="ClassDef"),)
# check if the class is new
if (node.name.value, ()) not in self.preexisting_objects:
self.new_classes.append(node)
for child_node in node.body.body:
if (
self.preexisting_objects
@ -290,6 +301,7 @@ class OptimFunctionReplacer(cst.CSTTransformer):
def __init__(
self,
modified_functions: Optional[dict[tuple[str | None, str], cst.FunctionDef]] = None,
new_classes: Optional[list[cst.ClassDef]] = None,
new_functions: Optional[list[cst.FunctionDef]] = None,
new_class_functions: Optional[dict[str, list[cst.FunctionDef]]] = None,
modified_init_functions: Optional[dict[str, cst.FunctionDef]] = None,
@ -297,6 +309,7 @@ class OptimFunctionReplacer(cst.CSTTransformer):
super().__init__()
self.modified_functions = modified_functions if modified_functions is not None else {}
self.new_functions = new_functions if new_functions is not None else []
self.new_classes = new_classes if new_classes is not None else []
self.new_class_functions = new_class_functions if new_class_functions is not None else defaultdict(list)
self.modified_init_functions: dict[str, cst.FunctionDef] = (
modified_init_functions if modified_init_functions is not None else {}
@ -335,19 +348,33 @@ class OptimFunctionReplacer(cst.CSTTransformer):
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: # noqa: ARG002
node = updated_node
max_function_index = None
class_index = None
max_class_index = None
for index, _node in enumerate(node.body):
if isinstance(_node, cst.FunctionDef):
max_function_index = index
if isinstance(_node, cst.ClassDef):
class_index = index
max_class_index = index
if self.new_classes:
existing_class_names = {_node.name.value for _node in node.body if isinstance(_node, cst.ClassDef)}
unique_classes = [
new_class for new_class in self.new_classes if new_class.name.value not in existing_class_names
]
if unique_classes:
new_classes_insertion_idx = max_class_index or find_insertion_index_after_imports(node)
new_body = list(
chain(node.body[:new_classes_insertion_idx], unique_classes, node.body[new_classes_insertion_idx:])
)
node = node.with_changes(body=new_body)
if max_function_index is not None:
node = node.with_changes(
body=(*node.body[: max_function_index + 1], *self.new_functions, *node.body[max_function_index + 1 :])
)
elif class_index is not None:
elif max_class_index is not None:
node = node.with_changes(
body=(*node.body[: class_index + 1], *self.new_functions, *node.body[class_index + 1 :])
body=(*node.body[: max_class_index + 1], *self.new_functions, *node.body[max_class_index + 1 :])
)
else:
node = node.with_changes(body=(*self.new_functions, *node.body))
@ -373,18 +400,20 @@ def replace_functions_in_file(
parsed_function_names.append((class_name, function_name))
# Collect functions we want to modify from the optimized code
module = cst.metadata.MetadataWrapper(cst.parse_module(optimized_code))
optimized_module = cst.metadata.MetadataWrapper(cst.parse_module(optimized_code))
original_module = cst.parse_module(source_code)
visitor = OptimFunctionCollector(preexisting_objects, set(parsed_function_names))
module.visit(visitor)
optimized_module.visit(visitor)
# Replace these functions in the original code
transformer = OptimFunctionReplacer(
modified_functions=visitor.modified_functions,
new_classes=visitor.new_classes,
new_functions=visitor.new_functions,
new_class_functions=visitor.new_class_functions,
modified_init_functions=visitor.modified_init_functions,
)
original_module = cst.parse_module(source_code)
modified_tree = original_module.visit(transformer)
return modified_tree.code

View file

@ -215,7 +215,7 @@ class NewClass:
return other_function(self.name)
def new_function2(value):
return value
"""
"""
original_code = """import libcst as cst
from typing import Mandatory
@ -230,19 +230,28 @@ def other_function(st):
print("Salut monde")
"""
expected = """from typing import Mandatory
expected = """import libcst as cst
from typing import Mandatory
class NewClass:
def __init__(self, name):
self.name = name
def new_function(self, value: cst.Name):
return other_function(self.name)
def new_function2(value):
return value
print("Au revoir")
def yet_another_function(values):
return len(values)
def other_function(st):
return(st * 2)
def totally_new_function(value):
return value
def other_function(st):
return(st * 2)
print("Salut monde")
"""
@ -279,7 +288,7 @@ class NewClass:
return other_function(self.name)
def new_function2(value):
return value
"""
"""
original_code = """import libcst as cst
from typing import Mandatory
@ -296,17 +305,25 @@ print("Salut monde")
"""
expected = """from typing import Mandatory
class NewClass:
def __init__(self, name):
self.name = name
def new_function(self, value):
return other_function(self.name)
def new_function2(value):
return value
print("Au revoir")
def yet_another_function(values):
return len(values) + 2
def other_function(st):
return(st * 2)
def totally_new_function(value):
return value
def other_function(st):
return(st * 2)
print("Salut monde")
"""
@ -3619,4 +3636,110 @@ async def task():
await asyncio.sleep(1)
return "done"
'''
assert is_zero_diff(original_code, optimized_code)
assert is_zero_diff(original_code, optimized_code)
def test_code_replacement_with_new_helper_class() -> None:
# Checks that a helper class introduced at module level by the optimized code
# is carried over into the rewritten source together with the replaced function.
# Optimized code: defines the module-level `_RepeatedToolItem` dataclass and
# uses it inside `_collect_repeated_tools`.
optim_code = """from __future__ import annotations
import itertools
import re
from dataclasses import dataclass
from typing import Any, Callable, Iterator, Sequence
from bokeh.models import HoverTool, Plot, Tool
# Move the Item dataclass to module-level to avoid redefining it on every function call
@dataclass(frozen=True)
class _RepeatedToolItem:
obj: Tool
properties: dict[str, Any]
def _collect_repeated_tools(tool_objs: list[Tool]) -> Iterator[Tool]:
key: Callable[[Tool], str] = lambda obj: obj.__class__.__name__
# Pre-collect properties for all objects by group to avoid repeated calls
for _, group in itertools.groupby(sorted(tool_objs, key=key), key=key):
grouped = list(group)
n = len(grouped)
if n > 1:
# Precompute all properties once for this group
props = [_RepeatedToolItem(obj, obj.properties_with_values()) for obj in grouped]
i = 0
while i < len(props) - 1:
head = props[i]
for j in range(i+1, len(props)):
item = props[j]
if item.properties == head.properties:
yield item.obj
i += 1
"""
# Original source: the `Item` dataclass is declared inside
# `_collect_repeated_tools` itself and no module-level class exists.
original_code = """from __future__ import annotations
import itertools
import re
from bokeh.models import HoverTool, Plot, Tool
from dataclasses import dataclass
from typing import Any, Callable, Iterator, Sequence
def _collect_repeated_tools(tool_objs: list[Tool]) -> Iterator[Tool]:
@dataclass(frozen=True)
class Item:
obj: Tool
properties: dict[str, Any]
key: Callable[[Tool], str] = lambda obj: obj.__class__.__name__
for _, group in itertools.groupby(sorted(tool_objs, key=key), key=key):
rest = [ Item(obj, obj.properties_with_values()) for obj in group ]
while len(rest) > 1:
head, *rest = rest
for item in rest:
if item.properties == head.properties:
yield item.obj
"""
# Expected result: the new class appears after the imports and before the
# function. The expected text no longer lists `re`, `HoverTool`, `Plot` or
# `Sequence` among the imports.
expected = """from __future__ import annotations
import itertools
from bokeh.models import Tool
from dataclasses import dataclass
from typing import Any, Callable, Iterator
# Move the Item dataclass to module-level to avoid redefining it on every function call
@dataclass(frozen=True)
class _RepeatedToolItem:
obj: Tool
properties: dict[str, Any]
def _collect_repeated_tools(tool_objs: list[Tool]) -> Iterator[Tool]:
key: Callable[[Tool], str] = lambda obj: obj.__class__.__name__
# Pre-collect properties for all objects by group to avoid repeated calls
for _, group in itertools.groupby(sorted(tool_objs, key=key), key=key):
grouped = list(group)
n = len(grouped)
if n > 1:
# Precompute all properties once for this group
props = [_RepeatedToolItem(obj, obj.properties_with_values()) for obj in grouped]
i = 0
while i < len(props) - 1:
head = props[i]
for j in range(i+1, len(props)):
item = props[j]
if item.properties == head.properties:
yield item.obj
i += 1
"""
# Only `_collect_repeated_tools` is requested for replacement; the set of
# pre-existing objects is computed from the original source.
function_names: list[str] = ["_collect_repeated_tools"]
preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] = find_preexisting_objects(original_code)
new_code: str = replace_functions_and_add_imports(
source_code=original_code,
function_names=function_names,
optimized_code=optim_code,
module_abspath=Path(__file__).resolve(),
preexisting_objects=preexisting_objects,
project_root_path=Path(__file__).resolve().parent.resolve(),
)
# The rewritten module must match the expected text exactly.
assert new_code == expected