mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
handle new added classes
This commit is contained in:
parent
6696f079e9
commit
b4294f8a90
3 changed files with 198 additions and 45 deletions
|
|
@ -72,6 +72,33 @@ class GlobalAssignmentCollector(cst.CSTVisitor):
|
|||
return True
|
||||
|
||||
|
||||
def find_insertion_index_after_imports(node: cst.Module) -> int:
|
||||
"""Find the position of the last import statement in the top-level of the module."""
|
||||
insert_index = 0
|
||||
for i, stmt in enumerate(node.body):
|
||||
is_top_level_import = isinstance(stmt, cst.SimpleStatementLine) and any(
|
||||
isinstance(child, (cst.Import, cst.ImportFrom)) for child in stmt.body
|
||||
)
|
||||
|
||||
is_conditional_import = isinstance(stmt, cst.If) and all(
|
||||
isinstance(inner, cst.SimpleStatementLine)
|
||||
and all(isinstance(child, (cst.Import, cst.ImportFrom)) for child in inner.body)
|
||||
for inner in stmt.body.body
|
||||
)
|
||||
|
||||
if is_top_level_import or is_conditional_import:
|
||||
insert_index = i + 1
|
||||
|
||||
# Stop scanning once we reach a class or function definition.
|
||||
# Imports are supposed to be at the top of the file, but they can technically appear anywhere, even at the bottom of the file.
|
||||
# Without this check, a stray import later in the file
|
||||
# would incorrectly shift our insertion index below actual code definitions.
|
||||
if isinstance(stmt, (cst.ClassDef, cst.FunctionDef)):
|
||||
break
|
||||
|
||||
return insert_index
|
||||
|
||||
|
||||
class GlobalAssignmentTransformer(cst.CSTTransformer):
|
||||
"""Transforms global assignments in the original file with those from the new file."""
|
||||
|
||||
|
|
@ -122,32 +149,6 @@ class GlobalAssignmentTransformer(cst.CSTTransformer):
|
|||
|
||||
return updated_node
|
||||
|
||||
def _find_insertion_index(self, updated_node: cst.Module) -> int:
|
||||
"""Find the position of the last import statement in the top-level of the module."""
|
||||
insert_index = 0
|
||||
for i, stmt in enumerate(updated_node.body):
|
||||
is_top_level_import = isinstance(stmt, cst.SimpleStatementLine) and any(
|
||||
isinstance(child, (cst.Import, cst.ImportFrom)) for child in stmt.body
|
||||
)
|
||||
|
||||
is_conditional_import = isinstance(stmt, cst.If) and all(
|
||||
isinstance(inner, cst.SimpleStatementLine)
|
||||
and all(isinstance(child, (cst.Import, cst.ImportFrom)) for child in inner.body)
|
||||
for inner in stmt.body.body
|
||||
)
|
||||
|
||||
if is_top_level_import or is_conditional_import:
|
||||
insert_index = i + 1
|
||||
|
||||
# Stop scanning once we reach a class or function definition.
|
||||
# Imports are supposed to be at the top of the file, but they can technically appear anywhere, even at the bottom of the file.
|
||||
# Without this check, a stray import later in the file
|
||||
# would incorrectly shift our insertion index below actual code definitions.
|
||||
if isinstance(stmt, (cst.ClassDef, cst.FunctionDef)):
|
||||
break
|
||||
|
||||
return insert_index
|
||||
|
||||
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: # noqa: ARG002
|
||||
# Add any new assignments that weren't in the original file
|
||||
new_statements = list(updated_node.body)
|
||||
|
|
@ -161,7 +162,7 @@ class GlobalAssignmentTransformer(cst.CSTTransformer):
|
|||
|
||||
if assignments_to_append:
|
||||
# after last top-level imports
|
||||
insert_index = self._find_insertion_index(updated_node)
|
||||
insert_index = find_insertion_index_after_imports(updated_node)
|
||||
|
||||
assignment_lines = [
|
||||
cst.SimpleStatementLine([assignment], leading_lines=[cst.EmptyLine()])
|
||||
|
|
|
|||
|
|
@ -3,13 +3,18 @@ from __future__ import annotations
|
|||
import ast
|
||||
from collections import defaultdict
|
||||
from functools import lru_cache
|
||||
from itertools import chain
|
||||
from typing import TYPE_CHECKING, Optional, TypeVar
|
||||
|
||||
import libcst as cst
|
||||
from libcst.metadata import PositionProvider
|
||||
|
||||
from codeflash.cli_cmds.console import logger
|
||||
from codeflash.code_utils.code_extractor import add_global_assignments, add_needed_imports_from_module
|
||||
from codeflash.code_utils.code_extractor import (
|
||||
add_global_assignments,
|
||||
add_needed_imports_from_module,
|
||||
find_insertion_index_after_imports,
|
||||
)
|
||||
from codeflash.code_utils.config_parser import find_conftest_files
|
||||
from codeflash.code_utils.formatter import sort_imports
|
||||
from codeflash.code_utils.line_profile_utils import ImportAdder
|
||||
|
|
@ -249,6 +254,7 @@ class OptimFunctionCollector(cst.CSTVisitor):
|
|||
] = {} # keys are (class_name, function_name)
|
||||
self.new_functions: list[cst.FunctionDef] = []
|
||||
self.new_class_functions: dict[str, list[cst.FunctionDef]] = defaultdict(list)
|
||||
self.new_classes: list[cst.ClassDef] = []
|
||||
self.current_class = None
|
||||
self.modified_init_functions: dict[str, cst.FunctionDef] = {}
|
||||
|
||||
|
|
@ -271,6 +277,11 @@ class OptimFunctionCollector(cst.CSTVisitor):
|
|||
self.current_class = node.name.value
|
||||
|
||||
parents = (FunctionParent(name=node.name.value, type="ClassDef"),)
|
||||
|
||||
# check if the class is new
|
||||
if (node.name.value, ()) not in self.preexisting_objects:
|
||||
self.new_classes.append(node)
|
||||
|
||||
for child_node in node.body.body:
|
||||
if (
|
||||
self.preexisting_objects
|
||||
|
|
@ -290,6 +301,7 @@ class OptimFunctionReplacer(cst.CSTTransformer):
|
|||
def __init__(
|
||||
self,
|
||||
modified_functions: Optional[dict[tuple[str | None, str], cst.FunctionDef]] = None,
|
||||
new_classes: Optional[list[cst.ClassDef]] = None,
|
||||
new_functions: Optional[list[cst.FunctionDef]] = None,
|
||||
new_class_functions: Optional[dict[str, list[cst.FunctionDef]]] = None,
|
||||
modified_init_functions: Optional[dict[str, cst.FunctionDef]] = None,
|
||||
|
|
@ -297,6 +309,7 @@ class OptimFunctionReplacer(cst.CSTTransformer):
|
|||
super().__init__()
|
||||
self.modified_functions = modified_functions if modified_functions is not None else {}
|
||||
self.new_functions = new_functions if new_functions is not None else []
|
||||
self.new_classes = new_classes if new_classes is not None else []
|
||||
self.new_class_functions = new_class_functions if new_class_functions is not None else defaultdict(list)
|
||||
self.modified_init_functions: dict[str, cst.FunctionDef] = (
|
||||
modified_init_functions if modified_init_functions is not None else {}
|
||||
|
|
@ -335,19 +348,33 @@ class OptimFunctionReplacer(cst.CSTTransformer):
|
|||
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: # noqa: ARG002
|
||||
node = updated_node
|
||||
max_function_index = None
|
||||
class_index = None
|
||||
max_class_index = None
|
||||
for index, _node in enumerate(node.body):
|
||||
if isinstance(_node, cst.FunctionDef):
|
||||
max_function_index = index
|
||||
if isinstance(_node, cst.ClassDef):
|
||||
class_index = index
|
||||
max_class_index = index
|
||||
|
||||
if self.new_classes:
|
||||
existing_class_names = {_node.name.value for _node in node.body if isinstance(_node, cst.ClassDef)}
|
||||
|
||||
unique_classes = [
|
||||
new_class for new_class in self.new_classes if new_class.name.value not in existing_class_names
|
||||
]
|
||||
if unique_classes:
|
||||
new_classes_insertion_idx = max_class_index or find_insertion_index_after_imports(node)
|
||||
new_body = list(
|
||||
chain(node.body[:new_classes_insertion_idx], unique_classes, node.body[new_classes_insertion_idx:])
|
||||
)
|
||||
node = node.with_changes(body=new_body)
|
||||
|
||||
if max_function_index is not None:
|
||||
node = node.with_changes(
|
||||
body=(*node.body[: max_function_index + 1], *self.new_functions, *node.body[max_function_index + 1 :])
|
||||
)
|
||||
elif class_index is not None:
|
||||
elif max_class_index is not None:
|
||||
node = node.with_changes(
|
||||
body=(*node.body[: class_index + 1], *self.new_functions, *node.body[class_index + 1 :])
|
||||
body=(*node.body[: max_class_index + 1], *self.new_functions, *node.body[max_class_index + 1 :])
|
||||
)
|
||||
else:
|
||||
node = node.with_changes(body=(*self.new_functions, *node.body))
|
||||
|
|
@ -373,18 +400,20 @@ def replace_functions_in_file(
|
|||
parsed_function_names.append((class_name, function_name))
|
||||
|
||||
# Collect functions we want to modify from the optimized code
|
||||
module = cst.metadata.MetadataWrapper(cst.parse_module(optimized_code))
|
||||
optimized_module = cst.metadata.MetadataWrapper(cst.parse_module(optimized_code))
|
||||
original_module = cst.parse_module(source_code)
|
||||
|
||||
visitor = OptimFunctionCollector(preexisting_objects, set(parsed_function_names))
|
||||
module.visit(visitor)
|
||||
optimized_module.visit(visitor)
|
||||
|
||||
# Replace these functions in the original code
|
||||
transformer = OptimFunctionReplacer(
|
||||
modified_functions=visitor.modified_functions,
|
||||
new_classes=visitor.new_classes,
|
||||
new_functions=visitor.new_functions,
|
||||
new_class_functions=visitor.new_class_functions,
|
||||
modified_init_functions=visitor.modified_init_functions,
|
||||
)
|
||||
original_module = cst.parse_module(source_code)
|
||||
modified_tree = original_module.visit(transformer)
|
||||
return modified_tree.code
|
||||
|
||||
|
|
|
|||
|
|
@ -215,7 +215,7 @@ class NewClass:
|
|||
return other_function(self.name)
|
||||
def new_function2(value):
|
||||
return value
|
||||
"""
|
||||
"""
|
||||
|
||||
original_code = """import libcst as cst
|
||||
from typing import Mandatory
|
||||
|
|
@ -230,19 +230,28 @@ def other_function(st):
|
|||
|
||||
print("Salut monde")
|
||||
"""
|
||||
expected = """from typing import Mandatory
|
||||
expected = """import libcst as cst
|
||||
from typing import Mandatory
|
||||
|
||||
class NewClass:
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
def new_function(self, value: cst.Name):
|
||||
return other_function(self.name)
|
||||
def new_function2(value):
|
||||
return value
|
||||
|
||||
print("Au revoir")
|
||||
|
||||
def yet_another_function(values):
|
||||
return len(values)
|
||||
|
||||
def other_function(st):
|
||||
return(st * 2)
|
||||
|
||||
def totally_new_function(value):
|
||||
return value
|
||||
|
||||
def other_function(st):
|
||||
return(st * 2)
|
||||
|
||||
print("Salut monde")
|
||||
"""
|
||||
|
||||
|
|
@ -279,7 +288,7 @@ class NewClass:
|
|||
return other_function(self.name)
|
||||
def new_function2(value):
|
||||
return value
|
||||
"""
|
||||
"""
|
||||
|
||||
original_code = """import libcst as cst
|
||||
from typing import Mandatory
|
||||
|
|
@ -296,17 +305,25 @@ print("Salut monde")
|
|||
"""
|
||||
expected = """from typing import Mandatory
|
||||
|
||||
class NewClass:
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
def new_function(self, value):
|
||||
return other_function(self.name)
|
||||
def new_function2(value):
|
||||
return value
|
||||
|
||||
print("Au revoir")
|
||||
|
||||
def yet_another_function(values):
|
||||
return len(values) + 2
|
||||
|
||||
def other_function(st):
|
||||
return(st * 2)
|
||||
|
||||
def totally_new_function(value):
|
||||
return value
|
||||
|
||||
def other_function(st):
|
||||
return(st * 2)
|
||||
|
||||
print("Salut monde")
|
||||
"""
|
||||
|
||||
|
|
@ -3619,4 +3636,110 @@ async def task():
|
|||
await asyncio.sleep(1)
|
||||
return "done"
|
||||
'''
|
||||
assert is_zero_diff(original_code, optimized_code)
|
||||
assert is_zero_diff(original_code, optimized_code)
|
||||
|
||||
|
||||
|
||||
def test_code_replacement_with_new_helper_class() -> None:
|
||||
optim_code = """from __future__ import annotations
|
||||
|
||||
import itertools
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Iterator, Sequence
|
||||
|
||||
from bokeh.models import HoverTool, Plot, Tool
|
||||
|
||||
|
||||
# Move the Item dataclass to module-level to avoid redefining it on every function call
|
||||
@dataclass(frozen=True)
|
||||
class _RepeatedToolItem:
|
||||
obj: Tool
|
||||
properties: dict[str, Any]
|
||||
|
||||
def _collect_repeated_tools(tool_objs: list[Tool]) -> Iterator[Tool]:
|
||||
key: Callable[[Tool], str] = lambda obj: obj.__class__.__name__
|
||||
# Pre-collect properties for all objects by group to avoid repeated calls
|
||||
for _, group in itertools.groupby(sorted(tool_objs, key=key), key=key):
|
||||
grouped = list(group)
|
||||
n = len(grouped)
|
||||
if n > 1:
|
||||
# Precompute all properties once for this group
|
||||
props = [_RepeatedToolItem(obj, obj.properties_with_values()) for obj in grouped]
|
||||
i = 0
|
||||
while i < len(props) - 1:
|
||||
head = props[i]
|
||||
for j in range(i+1, len(props)):
|
||||
item = props[j]
|
||||
if item.properties == head.properties:
|
||||
yield item.obj
|
||||
i += 1
|
||||
"""
|
||||
|
||||
original_code = """from __future__ import annotations
|
||||
import itertools
|
||||
import re
|
||||
from bokeh.models import HoverTool, Plot, Tool
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Iterator, Sequence
|
||||
|
||||
def _collect_repeated_tools(tool_objs: list[Tool]) -> Iterator[Tool]:
|
||||
@dataclass(frozen=True)
|
||||
class Item:
|
||||
obj: Tool
|
||||
properties: dict[str, Any]
|
||||
|
||||
key: Callable[[Tool], str] = lambda obj: obj.__class__.__name__
|
||||
|
||||
for _, group in itertools.groupby(sorted(tool_objs, key=key), key=key):
|
||||
rest = [ Item(obj, obj.properties_with_values()) for obj in group ]
|
||||
while len(rest) > 1:
|
||||
head, *rest = rest
|
||||
for item in rest:
|
||||
if item.properties == head.properties:
|
||||
yield item.obj
|
||||
"""
|
||||
|
||||
expected = """from __future__ import annotations
|
||||
import itertools
|
||||
from bokeh.models import Tool
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Iterator
|
||||
|
||||
|
||||
# Move the Item dataclass to module-level to avoid redefining it on every function call
|
||||
@dataclass(frozen=True)
|
||||
class _RepeatedToolItem:
|
||||
obj: Tool
|
||||
properties: dict[str, Any]
|
||||
|
||||
def _collect_repeated_tools(tool_objs: list[Tool]) -> Iterator[Tool]:
|
||||
key: Callable[[Tool], str] = lambda obj: obj.__class__.__name__
|
||||
# Pre-collect properties for all objects by group to avoid repeated calls
|
||||
for _, group in itertools.groupby(sorted(tool_objs, key=key), key=key):
|
||||
grouped = list(group)
|
||||
n = len(grouped)
|
||||
if n > 1:
|
||||
# Precompute all properties once for this group
|
||||
props = [_RepeatedToolItem(obj, obj.properties_with_values()) for obj in grouped]
|
||||
i = 0
|
||||
while i < len(props) - 1:
|
||||
head = props[i]
|
||||
for j in range(i+1, len(props)):
|
||||
item = props[j]
|
||||
if item.properties == head.properties:
|
||||
yield item.obj
|
||||
i += 1
|
||||
"""
|
||||
|
||||
function_names: list[str] = ["_collect_repeated_tools"]
|
||||
preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] = find_preexisting_objects(original_code)
|
||||
new_code: str = replace_functions_and_add_imports(
|
||||
source_code=original_code,
|
||||
function_names=function_names,
|
||||
optimized_code=optim_code,
|
||||
module_abspath=Path(__file__).resolve(),
|
||||
preexisting_objects=preexisting_objects,
|
||||
project_root_path=Path(__file__).resolve().parent.resolve(),
|
||||
)
|
||||
assert new_code == expected
|
||||
|
|
|
|||
Loading…
Reference in a new issue