Modified code replacer to work file by file. now compatible with new code context extractor.

This commit is contained in:
Alvin Ryanputra 2025-01-08 14:56:53 -08:00
parent 334e2a2952
commit c41f710f5f
9 changed files with 291 additions and 246 deletions

View file

@ -1,20 +1,19 @@
from __future__ import annotations
import ast
from collections import defaultdict
from functools import lru_cache
from typing import TYPE_CHECKING, TypeVar
import libcst as cst
from codeflash.cli_cmds.console import console, logger
from codeflash.cli_cmds.console import logger
from codeflash.code_utils.code_extractor import add_needed_imports_from_module
from codeflash.models.models import FunctionParent
if TYPE_CHECKING:
from pathlib import Path
from libcst import FunctionDef
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import CodeOptimizationContext, OptimizedCandidate, ValidCode
@ -39,93 +38,89 @@ class OptimFunctionCollector(cst.CSTVisitor):
def __init__(
self,
function_name: str,
class_name: str | None,
contextual_functions: set[tuple[str, str]],
preexisting_objects: list[tuple[str, list[FunctionParent]]] | None = None,
function_names: set[tuple[str | None, str]] | None = None,
) -> None:
super().__init__()
if preexisting_objects is None:
preexisting_objects = []
self.function_name = function_name
self.class_name = class_name
self.optim_body: FunctionDef | None = None
self.optim_new_class_functions: list[cst.FunctionDef] = []
self.optim_new_functions: list[cst.FunctionDef] = []
self.preexisting_objects = preexisting_objects
self.contextual_functions = contextual_functions.union({(self.class_name, self.function_name)})
self.preexisting_objects = preexisting_objects if preexisting_objects is not None else []
def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
parent = self.get_metadata(cst.metadata.ParentNodeProvider, node)
parent2 = None
try:
if parent is not None and isinstance(parent, cst.Module):
parent2 = self.get_metadata(cst.metadata.ParentNodeProvider, parent)
except:
pass
if node.name.value == self.function_name:
self.optim_body = node
self.function_names = function_names # set of (class_name, function_name)
self.modified_functions: dict[
tuple[str | None, str], cst.FunctionDef
] = {} # keys are (class_name, function_name)
self.new_functions: list[cst.FunctionDef] = []
self.new_class_functions: dict[str, list[cst.FunctionDef]] = defaultdict(list)
self.current_class = None
def visit_FunctionDef(self, node: cst.FunctionDef) -> bool:
if (self.current_class, node.name.value) in self.function_names:
self.modified_functions[(self.current_class, node.name.value)] = node
elif (
self.preexisting_objects
and (node.name.value, []) not in self.preexisting_objects
and (isinstance(parent, cst.Module) or (parent2 is not None and not isinstance(parent2, cst.ClassDef)))
and self.current_class is None
):
self.optim_new_functions.append(node)
self.new_functions.append(node)
return False
def visit_ClassDef(self, node: cst.ClassDef) -> bool:
if self.current_class:
return False # If already in a class, do not recurse deeper
self.current_class = node.name.value
def visit_ClassDef_body(self, node: cst.ClassDef) -> None:
parents = [FunctionParent(name=node.name.value, type="ClassDef")]
for child_node in node.body.body:
if (
self.preexisting_objects
and isinstance(child_node, cst.FunctionDef)
and (node.name.value, child_node.name.value) not in self.contextual_functions
and (child_node.name.value, parents) not in self.preexisting_objects
):
self.optim_new_class_functions.append(child_node)
self.new_class_functions[node.name.value].append(child_node)
return True
def leave_ClassDef(self, node: cst.ClassDef) -> None:
if self.current_class:
self.current_class = None
class OptimFunctionReplacer(cst.CSTTransformer):
def __init__(
self,
function_name: str,
optim_body: cst.FunctionDef,
optim_new_class_functions: list[cst.FunctionDef],
optim_new_functions: list[cst.FunctionDef],
class_name: str | None = None,
modified_functions: dict[tuple[str | None, str], cst.FunctionDef] = {},
new_functions: list[cst.FunctionDef] = [],
new_class_functions: dict[str, list[cst.FunctionDef]] = defaultdict(list),
) -> None:
super().__init__()
self.function_name = function_name
self.optim_body = optim_body
self.optim_new_class_functions = optim_new_class_functions
self.optim_new_functions = optim_new_functions
self.class_name = class_name
self.depth: int = 0
self.in_class: bool = False
self.modified_functions = modified_functions
self.new_functions = new_functions
self.new_class_functions = new_class_functions
self.current_class = None
def visit_FunctionDef(self, node: cst.FunctionDef) -> bool:
return False
def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef:
if original_node.name.value == self.function_name and (self.depth == 0 or (self.depth == 1 and self.in_class)):
return updated_node.with_changes(body=self.optim_body.body, decorators=self.optim_body.decorators)
if (self.current_class, original_node.name.value) in self.modified_functions:
node = self.modified_functions[(self.current_class, original_node.name.value)]
return updated_node.with_changes(body=node.body, decorators=node.decorators)
return updated_node
def visit_ClassDef(self, node: cst.ClassDef) -> bool:
self.depth += 1
if self.in_class:
return False
self.in_class = (self.depth == 1) and (node.name.value == self.class_name)
return self.in_class
if self.current_class:
return False # If already in a class, do not recurse deeper
self.current_class = node.name.value
return True
def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef:
self.depth -= 1
if self.in_class and (self.depth == 0) and (original_node.name.value == self.class_name):
self.in_class = False
return updated_node.with_changes(
body=updated_node.body.with_changes(
body=(list(updated_node.body.body) + self.optim_new_class_functions)
if self.current_class and self.current_class == original_node.name.value:
self.current_class = None
if original_node.name.value in self.new_class_functions:
return updated_node.with_changes(
body=updated_node.body.with_changes(
body=(list(updated_node.body.body) + list(self.new_class_functions[original_node.name.value]))
)
)
)
return updated_node
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
@ -139,18 +134,14 @@ class OptimFunctionReplacer(cst.CSTTransformer):
class_index = index
if max_function_index is not None:
node = node.with_changes(
body=(
*node.body[: max_function_index + 1],
*self.optim_new_functions,
*node.body[max_function_index + 1 :],
)
body=(*node.body[: max_function_index + 1], *self.new_functions, *node.body[max_function_index + 1 :])
)
elif class_index is not None:
node = node.with_changes(
body=(*node.body[: class_index + 1], *self.optim_new_functions, *node.body[class_index + 1 :])
body=(*node.body[: class_index + 1], *self.new_functions, *node.body[class_index + 1 :])
)
else:
node = node.with_changes(body=(*self.optim_new_functions, *node.body))
node = node.with_changes(body=(*self.new_functions, *node.body))
return node
@ -159,7 +150,6 @@ def replace_functions_in_file(
original_function_names: list[str],
optimized_code: str,
preexisting_objects: list[tuple[str, list[FunctionParent]]],
contextual_functions: set[tuple[str, str]],
) -> str:
parsed_function_names = []
for original_function_name in original_function_names:
@ -171,34 +161,22 @@ def replace_functions_in_file(
msg = f"Unable to find {original_function_name}. Returning unchanged source code."
logger.error(msg)
return source_code
parsed_function_names.append((function_name, class_name))
parsed_function_names.append((class_name, function_name))
# Collect functions we want to modify from the optimized code
module = cst.metadata.MetadataWrapper(cst.parse_module(optimized_code))
visitor = OptimFunctionCollector(preexisting_objects, set(parsed_function_names))
module.visit(visitor)
for function_name, class_name in parsed_function_names:
visitor = OptimFunctionCollector(function_name, class_name, contextual_functions, preexisting_objects)
module.visit(visitor)
if visitor.optim_body is None and not preexisting_objects:
continue
if visitor.optim_body is None:
msg = f"Unable to find function {function_name} in optimized code. Returning unchanged source code."
logger.error(msg)
console.rule()
return source_code
transformer = OptimFunctionReplacer(
visitor.function_name,
visitor.optim_body,
visitor.optim_new_class_functions,
visitor.optim_new_functions,
class_name=class_name,
)
original_module = cst.parse_module(source_code)
modified_tree = original_module.visit(transformer)
source_code = modified_tree.code
return source_code
# Replace these functions in the original code
transformer = OptimFunctionReplacer(
modified_functions=visitor.modified_functions,
new_functions=visitor.new_functions,
new_class_functions=visitor.new_class_functions,
)
original_module = cst.parse_module(source_code)
modified_tree = original_module.visit(transformer)
return modified_tree.code
def replace_functions_and_add_imports(
@ -208,14 +186,11 @@ def replace_functions_and_add_imports(
file_path_of_module_with_function_to_optimize: Path,
module_abspath: Path,
preexisting_objects: list[tuple[str, list[FunctionParent]]],
contextual_functions: set[tuple[str, str]],
project_root_path: Path,
) -> str:
return add_needed_imports_from_module(
optimized_code,
replace_functions_in_file(
source_code, function_names, optimized_code, preexisting_objects, contextual_functions
),
replace_functions_in_file(source_code, function_names, optimized_code, preexisting_objects),
file_path_of_module_with_function_to_optimize,
module_abspath,
project_root_path,
@ -228,7 +203,6 @@ def replace_function_definitions_in_module(
file_path_of_module_with_function_to_optimize: Path,
module_abspath: Path,
preexisting_objects: list[tuple[str, list[FunctionParent]]],
contextual_functions: set[tuple[str, str]],
project_root_path: Path,
) -> bool:
source_code: str = module_abspath.read_text(encoding="utf8")
@ -239,7 +213,6 @@ def replace_function_definitions_in_module(
file_path_of_module_with_function_to_optimize,
module_abspath,
preexisting_objects,
contextual_functions,
project_root_path,
)
if is_zero_diff(source_code, new_code):
@ -268,7 +241,6 @@ def replace_optimized_code(
function_to_optimize.file_path,
function_to_optimize.file_path,
code_context.preexisting_objects,
code_context.contextual_dunder_methods,
project_root,
)
for candidate in candidates
@ -298,7 +270,6 @@ def replace_optimized_code(
function_to_optimize.file_path,
module_path,
[],
code_context.contextual_dunder_methods,
project_root,
)
for module_path in module_paths

View file

@ -11,6 +11,8 @@ from codeflash.cli_cmds.console import logger
def get_qualified_name(module_name: str, full_qualified_name: str) -> str:
if not full_qualified_name:
raise ValueError("full_qualified_name cannot be empty")
if not full_qualified_name.startswith(module_name):
msg = f"{full_qualified_name} does not start with {module_name}"
raise ValueError(msg)

View file

@ -2,6 +2,7 @@ from __future__ import annotations
import os
from collections import defaultdict
from itertools import chain
from pathlib import Path
import jedi
@ -11,24 +12,23 @@ from jedi.api.classes import Name
from libcst import CSTNode
from codeflash.cli_cmds.console import logger
from codeflash.code_utils.code_extractor import add_needed_imports_from_module
from codeflash.code_utils.code_extractor import add_needed_imports_from_module, find_preexisting_objects
from codeflash.code_utils.code_utils import get_qualified_name, path_belongs_to_site_packages
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import CodeString, CodeStringsMarkdown
from codeflash.models.models import CodeOptimizationContext, CodeString, CodeStringsMarkdown, FunctionSource
from codeflash.optimization.function_context import belongs_to_function_qualified
def get_code_optimization_context(
function_to_optimize: FunctionToOptimize, project_root_path: Path, token_limit: int = 8000
) -> tuple[str, str]:
) -> CodeOptimizationContext:
# Get qualified names and fully qualified names(fqn) of helpers
helpers_of_fto, helpers_of_fto_fqn = get_file_path_to_helper_functions_dict(
helpers_of_fto, helpers_of_fto_fqn, helpers_of_fto_obj_list = get_file_path_to_helper_functions_dict(
{function_to_optimize.file_path: {function_to_optimize.qualified_name}}, project_root_path
)
helpers_of_helpers, helpers_of_helpers_fqn = get_file_path_to_helper_functions_dict(
helpers_of_helpers, helpers_of_helpers_fqn, _ = get_file_path_to_helper_functions_dict(
helpers_of_fto, project_root_path
)
# Add function to optimize
helpers_of_fto[function_to_optimize.file_path].add(function_to_optimize.qualified_name)
helpers_of_fto_fqn[function_to_optimize.file_path].add(
@ -36,7 +36,7 @@ def get_code_optimization_context(
)
# Extract code
final_read_writable_code = get_all_read_writable_code(helpers_of_fto, helpers_of_fto_fqn, project_root_path)
final_read_writable_code = get_all_read_writable_code(helpers_of_fto, helpers_of_fto_fqn, project_root_path).code
read_only_code_markdown = get_all_read_only_code_context(
helpers_of_fto,
helpers_of_fto_fqn,
@ -52,10 +52,24 @@ def get_code_optimization_context(
if final_read_writable_tokens > token_limit:
raise ValueError("Read-writable code has exceeded token limit, cannot proceed")
# Setup preexisting objects for code replacer TODO: should remove duplicates
preexisting_objects = list(
chain(
find_preexisting_objects(final_read_writable_code),
*(find_preexisting_objects(codestring.code) for codestring in read_only_code_markdown.code_strings),
)
)
read_only_code_markdown_tokens = len(tokenizer.encode(read_only_code_markdown.markdown))
total_tokens = final_read_writable_tokens + read_only_code_markdown_tokens
if total_tokens <= token_limit:
return CodeString(code=final_read_writable_code).code, read_only_code_markdown.markdown
return CodeOptimizationContext(
code_to_optimize_with_helpers="",
read_writable_code=CodeString(code=final_read_writable_code).code,
read_only_context_code=read_only_code_markdown.markdown,
helper_functions=helpers_of_fto_obj_list,
preexisting_objects=preexisting_objects,
)
logger.debug("Code context has exceeded token limit, removing docstrings from read-only code")
# Extract read only code without docstrings
@ -70,15 +84,29 @@ def get_code_optimization_context(
read_only_code_no_docstring_markdown_tokens = len(tokenizer.encode(read_only_code_no_docstring_markdown.markdown))
total_tokens = final_read_writable_tokens + read_only_code_no_docstring_markdown_tokens
if total_tokens <= token_limit:
return CodeString(code=final_read_writable_code).code, read_only_code_no_docstring_markdown.markdown
return CodeOptimizationContext(
code_to_optimize_with_helpers="",
read_writable_code=CodeString(code=final_read_writable_code).code,
read_only_context_code=read_only_code_no_docstring_markdown.markdown,
contextual_dunder_methods=set(),
helper_functions=helpers_of_fto_obj_list,
preexisting_objects=preexisting_objects,
)
logger.debug("Code context has exceeded token limit, removing read-only code")
return CodeString(code=final_read_writable_code).code, ""
return CodeOptimizationContext(
code_to_optimize_with_helpers="",
read_writable_code=CodeString(code=final_read_writable_code).code,
read_only_context_code="",
contextual_dunder_methods=set(),
helper_functions=helpers_of_fto_obj_list,
preexisting_objects=preexisting_objects,
)
def get_all_read_writable_code(
helpers_of_fto: dict[Path, set[str]], helpers_of_fto_fqn: dict[Path, set[str]], project_root_path: Path
) -> str:
) -> CodeString:
final_read_writable_code = ""
# Extract code from file paths that contain fto and first degree helpers
for file_path, qualified_function_names in helpers_of_fto.items():
@ -103,7 +131,7 @@ def get_all_read_writable_code(
project_root=project_root_path,
helper_functions_fqn=helpers_of_fto_fqn[file_path],
)
return final_read_writable_code
return CodeString(code=final_read_writable_code)
def get_all_read_only_code_context(
@ -189,9 +217,10 @@ def get_all_read_only_code_context(
def get_file_path_to_helper_functions_dict(
file_path_to_qualified_function_names: dict[Path, set[str]], project_root_path: Path
) -> tuple[dict[Path, set[str]], dict[Path, set[str]]]:
) -> tuple[dict[Path, set[str]], dict[Path, set[str]], list[FunctionSource]]:
file_path_to_helper_function_qualified_names = defaultdict(set)
file_path_to_helper_function_fqn = defaultdict(set)
function_source_list: list[FunctionSource] = []
for file_path in file_path_to_qualified_function_names:
script = jedi.Script(path=file_path, project=jedi.Project(path=project_root_path))
file_refs = script.get_names(all_scopes=True, definitions=False, references=True)
@ -229,8 +258,18 @@ def get_file_path_to_helper_functions_dict(
get_qualified_name(definition.module_name, definition.full_name)
)
file_path_to_helper_function_fqn[definition_path].add(definition.full_name)
function_source_list.append(
FunctionSource(
file_path=definition_path,
qualified_name=get_qualified_name(definition.module_name, definition.full_name),
fully_qualified_name=definition.full_name,
only_function_name=definition.name,
source_code=definition.get_line_code(),
jedi_definition=definition,
)
)
return file_path_to_helper_function_qualified_names, file_path_to_helper_function_fqn
return file_path_to_helper_function_qualified_names, file_path_to_helper_function_fqn, function_source_list
def is_dunder_method(name: str) -> bool:

View file

@ -80,7 +80,6 @@ class CodeOptimizationContext(BaseModel):
code_to_optimize_with_helpers: str
read_writable_code: str = Field(min_length=1)
read_only_context_code: str = ""
contextual_dunder_methods: set[tuple[str, str]]
helper_functions: list[FunctionSource]
preexisting_objects: list[tuple[str, list[FunctionParent]]]

View file

@ -22,7 +22,7 @@ from rich.tree import Tree
from codeflash.api.aiservice import AiServiceClient, LocalAiServiceClient
from codeflash.cli_cmds.console import code_print, console, logger, progress_bar
from codeflash.code_utils import env_utils
from codeflash.code_utils.code_extractor import add_needed_imports_from_module, extract_code, find_preexisting_objects
from codeflash.code_utils.code_extractor import add_needed_imports_from_module, extract_code
from codeflash.code_utils.code_replacer import normalize_code, normalize_node, replace_function_definitions_in_module
from codeflash.code_utils.code_utils import (
cleanup_paths,
@ -668,7 +668,6 @@ class Optimizer:
file_path_of_module_with_function_to_optimize=function_to_optimize_file_path,
module_abspath=function_to_optimize_file_path,
preexisting_objects=code_context.preexisting_objects,
contextual_functions=code_context.contextual_dunder_methods,
project_root_path=self.args.project_root,
)
helper_functions_by_module_abspath = defaultdict(set)
@ -681,8 +680,7 @@ class Optimizer:
optimized_code=optimized_code,
file_path_of_module_with_function_to_optimize=function_to_optimize_file_path,
module_abspath=module_abspath,
preexisting_objects=[],
contextual_functions=code_context.contextual_dunder_methods,
preexisting_objects=code_context.preexisting_objects,
project_root_path=self.args.project_root,
)
return did_update
@ -733,25 +731,19 @@ class Optimizer:
project_root,
helper_functions,
)
preexisting_objects = find_preexisting_objects(code_to_optimize_with_helpers)
contextual_dunder_methods.update(helper_dunder_methods)
# Will eventually refactor to use this function instead of the above
try:
read_writable_code, read_only_context_code = code_context_extractor.get_code_optimization_context(
function_to_optimize, project_root
)
new_code_ctx = code_context_extractor.get_code_optimization_context(function_to_optimize, project_root)
except ValueError as e:
return Failure(str(e))
return Success(
CodeOptimizationContext(
code_to_optimize_with_helpers=code_to_optimize_with_helpers_and_imports,
read_writable_code=read_writable_code,
read_only_context_code=read_only_context_code,
contextual_dunder_methods=contextual_dunder_methods,
helper_functions=helper_functions,
preexisting_objects=preexisting_objects,
read_writable_code=new_code_ctx.read_writable_code,
read_only_context_code=new_code_ctx.read_only_context_code,
helper_functions=new_code_ctx.helper_functions, # only functions that are read writable
preexisting_objects=new_code_ctx.preexisting_objects,
)
)

View file

@ -72,9 +72,8 @@ def test_code_replacement10() -> None:
function_name="main_method", file_path=file_path, parents=[FunctionParent("MainClass", "ClassDef")]
)
read_write_context, read_only_context = get_code_optimization_context(
function_to_optimize=func_top_optimize, project_root_path=file_path.parent
)
code_ctx = get_code_optimization_context(function_to_optimize=func_top_optimize, project_root_path=file_path.parent)
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
expected_read_write_context = """
from __future__ import annotations
@ -119,10 +118,8 @@ def test_class_method_dependencies() -> None:
ending_line=None,
)
read_write_context, read_only_context = get_code_optimization_context(
function_to_optimize, file_path.parent.resolve()
)
code_ctx = get_code_optimization_context(function_to_optimize, file_path.parent.resolve())
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
expected_read_write_context = """
from __future__ import annotations
@ -181,9 +178,8 @@ def test_bubble_sort_helper() -> None:
ending_line=None,
)
read_write_context, read_only_context = get_code_optimization_context(
function_to_optimize, Path(__file__).resolve().parent.parent
)
code_ctx = get_code_optimization_context(function_to_optimize, Path(__file__).resolve().parent.parent)
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
expected_read_write_context = """
import math
@ -397,9 +393,8 @@ class _PersistentCache(Generic[_P, _R, _CacheBackendT]):
ending_line=None,
)
read_write_context, read_only_context = get_code_optimization_context(
function_to_optimize, opt.args.project_root
)
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
expected_read_write_context = """
class AbstractCacheBackend(CacheBackend, Protocol[_KEY_T, _STORE_T]):
@ -599,9 +594,8 @@ class HelperClass:
ending_line=None,
)
read_write_context, read_only_context = get_code_optimization_context(
function_to_optimize, opt.args.project_root
)
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
expected_read_write_context = """
class MyClass:
def target_method(self):
@ -680,9 +674,8 @@ class HelperClass:
ending_line=None,
)
read_write_context, read_only_context = get_code_optimization_context(
function_to_optimize, opt.args.project_root
)
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
# In this scenario, the read-only code context is too long, so the read-only docstrings are removed.
expected_read_write_context = """
class MyClass:
@ -759,9 +752,8 @@ class HelperClass:
ending_line=None,
)
read_write_context, read_only_context = get_code_optimization_context(
function_to_optimize, opt.args.project_root
)
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
# In this scenario, the read-only code context is too long even after removing docstrings, hence we remove it completely.
expected_read_write_context = """
class MyClass:
@ -826,9 +818,7 @@ class HelperClass:
)
# In this scenario, the read-writable code is too long, so we abort.
with pytest.raises(ValueError, match="Read-writable code has exceeded token limit, cannot proceed"):
read_write_context, read_only_context = get_code_optimization_context(
function_to_optimize, opt.args.project_root
)
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
def test_repo_helper() -> None:
@ -843,7 +833,8 @@ def test_repo_helper() -> None:
ending_line=None,
)
read_write_context, read_only_context = get_code_optimization_context(function_to_optimize, project_root)
code_ctx = get_code_optimization_context(function_to_optimize, project_root)
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
expected_read_write_context = """
import requests
from globals import API_URL
@ -920,7 +911,8 @@ def test_repo_helper_of_helper() -> None:
ending_line=None,
)
read_write_context, read_only_context = get_code_optimization_context(function_to_optimize, project_root)
code_ctx = get_code_optimization_context(function_to_optimize, project_root)
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
expected_read_write_context = """
from transform_utils import DataTransformer
import requests
@ -1006,7 +998,8 @@ def test_repo_helper_of_helper_same_class() -> None:
ending_line=None,
)
read_write_context, read_only_context = get_code_optimization_context(function_to_optimize, project_root)
code_ctx = get_code_optimization_context(function_to_optimize, project_root)
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
expected_read_write_context = """
from transform_utils import DataTransformer
@ -1073,7 +1066,8 @@ def test_repo_helper_of_helper_same_file() -> None:
ending_line=None,
)
read_write_context, read_only_context = get_code_optimization_context(function_to_optimize, project_root)
code_ctx = get_code_optimization_context(function_to_optimize, project_root)
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
expected_read_write_context = """
from transform_utils import DataTransformer
@ -1137,7 +1131,8 @@ def test_repo_helper_all_same_file() -> None:
ending_line=None,
)
read_write_context, read_only_context = get_code_optimization_context(function_to_optimize, project_root)
code_ctx = get_code_optimization_context(function_to_optimize, project_root)
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
expected_read_write_context = """
class DataTransformer:
@ -1181,7 +1176,8 @@ def test_repo_helper_circular_dependency() -> None:
ending_line=None,
)
read_write_context, read_only_context = get_code_optimization_context(function_to_optimize, project_root)
code_ctx = get_code_optimization_context(function_to_optimize, project_root)
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
expected_read_write_context = """
from transform_utils import DataTransformer
from code_to_optimize.code_directories.retriever.utils import DataProcessor

View file

@ -6,7 +6,7 @@ from argparse import Namespace
from collections import defaultdict
from pathlib import Path
from codeflash.code_utils.code_extractor import delete___future___aliased_imports
from codeflash.code_utils.code_extractor import delete___future___aliased_imports, find_preexisting_objects
from codeflash.code_utils.code_replacer import (
is_zero_diff,
replace_functions_and_add_imports,
@ -74,10 +74,7 @@ print("Hello world")
"""
function_name: str = "NewClass.new_function"
preexisting_objects: list[tuple[str, list[FunctionParent]]] = [
("new_function", [FunctionParent(name="NewClass", type="ClassDef")])
]
contextual_functions: set[tuple[str, str]] = {("NewClass", "__init__")}
preexisting_objects: list[tuple[str, list[FunctionParent]]] = find_preexisting_objects(original_code)
new_code: str = replace_functions_and_add_imports(
source_code=original_code,
function_names=[function_name],
@ -85,7 +82,6 @@ print("Hello world")
file_path_of_module_with_function_to_optimize=Path(__file__).resolve(),
module_abspath=Path(__file__).resolve(),
preexisting_objects=preexisting_objects,
contextual_functions=contextual_functions,
project_root_path=Path(__file__).resolve().parent.resolve(),
)
assert new_code == expected
@ -133,12 +129,14 @@ class NewClass:
def totally_new_function(value):
return value
def other_function(st):
return(st * 2)
print("Hello world")
"""
function_name: str = "NewClass.new_function"
preexisting_objects: list[tuple[str, list[FunctionParent]]] = [("new_function", []), ("other_function", [])]
contextual_functions: set[tuple[str, str]] = {("NewClass", "__init__")}
preexisting_objects: list[tuple[str, list[FunctionParent]]] = find_preexisting_objects(original_code)
new_code: str = replace_functions_and_add_imports(
source_code=original_code,
function_names=[function_name],
@ -146,7 +144,6 @@ print("Hello world")
file_path_of_module_with_function_to_optimize=Path(__file__).resolve(),
module_abspath=Path(__file__).resolve(),
preexisting_objects=preexisting_objects,
contextual_functions=contextual_functions,
project_root_path=Path(__file__).resolve().parent.resolve(),
)
assert new_code == expected
@ -194,12 +191,14 @@ def yet_another_function(values):
def other_function(st):
return(st * 2)
def totally_new_function(value):
return value
print("Salut monde")
"""
function_names: list[str] = ["module.other_function"]
preexisting_objects: list[tuple[str, list[FunctionParent]]] = []
contextual_functions: set[tuple[str, str]] = set()
function_names: list[str] = ["other_function"]
preexisting_objects: list[tuple[str, list[FunctionParent]]] = find_preexisting_objects(original_code)
new_code: str = replace_functions_and_add_imports(
source_code=original_code,
function_names=function_names,
@ -207,7 +206,6 @@ print("Salut monde")
file_path_of_module_with_function_to_optimize=Path(__file__).resolve(),
module_abspath=Path(__file__).resolve(),
preexisting_objects=preexisting_objects,
contextual_functions=contextual_functions,
project_root_path=Path(__file__).resolve().parent.resolve(),
)
assert new_code == expected
@ -258,12 +256,14 @@ def yet_another_function(values):
def other_function(st):
return(st * 2)
def totally_new_function(value):
return value
print("Salut monde")
"""
function_names: list[str] = ["module.yet_another_function", "module.other_function"]
preexisting_objects: list[tuple[str, list[FunctionParent]]] = []
contextual_functions: set[tuple[str, str]] = set()
function_names: list[str] = ["yet_another_function", "other_function"]
preexisting_objects: list[tuple[str, list[FunctionParent]]] = find_preexisting_objects(original_code)
new_code: str = replace_functions_and_add_imports(
source_code=original_code,
function_names=function_names,
@ -271,7 +271,6 @@ print("Salut monde")
file_path_of_module_with_function_to_optimize=Path(__file__).resolve(),
module_abspath=Path(__file__).resolve(),
preexisting_objects=preexisting_objects,
contextual_functions=contextual_functions,
project_root_path=Path(__file__).resolve().parent.resolve(),
)
assert new_code == expected
@ -318,8 +317,7 @@ def supersort(doink):
"""
function_names: list[str] = ["sorter_deps"]
preexisting_objects: list[tuple[str, list[FunctionParent]]] = [("sorter_deps", [])]
contextual_functions: set[tuple[str, str]] = set()
preexisting_objects: list[tuple[str, list[FunctionParent]]] = find_preexisting_objects(original_code)
new_code: str = replace_functions_and_add_imports(
source_code=original_code,
function_names=function_names,
@ -327,7 +325,6 @@ def supersort(doink):
file_path_of_module_with_function_to_optimize=Path(__file__).resolve(),
module_abspath=Path(__file__).resolve(),
preexisting_objects=preexisting_objects,
contextual_functions=contextual_functions,
project_root_path=Path(__file__).resolve().parent.resolve(),
)
assert new_code == expected
@ -396,14 +393,14 @@ def blab(st):
print("Not cool")
"""
preexisting_objects = find_preexisting_objects(original_code_main) + find_preexisting_objects(original_code_helper)
new_main_code: str = replace_functions_and_add_imports(
source_code=original_code_main,
function_names=["other_function"],
optimized_code=optim_code,
file_path_of_module_with_function_to_optimize=Path(__file__).resolve(),
module_abspath=Path(__file__).resolve(),
preexisting_objects=[("other_function", []), ("yet_another_function", []), ("blob", [])],
contextual_functions=set(),
preexisting_objects=preexisting_objects,
project_root_path=Path(__file__).resolve().parent.resolve(),
)
assert new_main_code == expected_main
@ -414,8 +411,7 @@ print("Not cool")
optimized_code=optim_code,
file_path_of_module_with_function_to_optimize=Path(__file__).resolve(),
module_abspath=Path(__file__).resolve(),
preexisting_objects=[],
contextual_functions=set(),
preexisting_objects=preexisting_objects,
project_root_path=Path(__file__).resolve().parent.resolve(),
)
assert new_helper_code == expected_helper
@ -602,14 +598,8 @@ class CacheConfig(BaseConfig):
)
"""
function_names: list[str] = ["CacheSimilarityEvalConfig.from_config"]
parents = [FunctionParent(name="CacheConfig", type="ClassDef")]
preexisting_objects: list[tuple[str, list[FunctionParent]]] = [("__init__", parents), ("from_config", parents)]
preexisting_objects: list[tuple[str, list[FunctionParent]]] = find_preexisting_objects(original_code)
contextual_functions: set[tuple[str, str]] = {
("CacheSimilarityEvalConfig", "__init__"),
("CacheConfig", "__init__"),
("CacheInitConfig", "__init__"),
}
new_code: str = replace_functions_and_add_imports(
source_code=original_code,
function_names=function_names,
@ -617,7 +607,6 @@ class CacheConfig(BaseConfig):
file_path_of_module_with_function_to_optimize=Path(__file__).resolve(),
module_abspath=Path(__file__).resolve(),
preexisting_objects=preexisting_objects,
contextual_functions=contextual_functions,
project_root_path=Path(__file__).resolve().parent.resolve(),
)
assert new_code == expected
@ -681,10 +670,7 @@ def test_test_libcst_code_replacement8() -> None:
return np.sum(a != b) / a.size
'''
function_names: list[str] = ["_EmbeddingDistanceChainMixin._hamming_distance"]
preexisting_objects: list[tuple[str, list[FunctionParent]]] = [
("_hamming_distance", [FunctionParent("_EmbeddingDistanceChainMixin", "ClassDef")])
]
contextual_functions: set[tuple[str, str]] = set()
preexisting_objects: list[tuple[str, list[FunctionParent]]] = find_preexisting_objects(original_code)
new_code: str = replace_functions_and_add_imports(
source_code=original_code,
function_names=function_names,
@ -692,7 +678,6 @@ def test_test_libcst_code_replacement8() -> None:
file_path_of_module_with_function_to_optimize=Path(__file__).resolve(),
module_abspath=Path(__file__).resolve(),
preexisting_objects=preexisting_objects,
contextual_functions=contextual_functions,
project_root_path=Path(__file__).resolve().parent.resolve(),
)
assert new_code == expected
@ -738,10 +723,8 @@ def totally_new_function(value: Optional[str]):
print("Hello world")
"""
parents = [FunctionParent(name="NewClass", type="ClassDef")]
function_name: str = "NewClass.__init__"
preexisting_objects: list[tuple[str, list[FunctionParent]]] = [("__init__", parents), ("__call__", parents)]
contextual_functions: set[tuple[str, str]] = {("NewClass", "__init__"), ("NewClass", "__call__")}
preexisting_objects: list[tuple[str, list[FunctionParent]]] = find_preexisting_objects(original_code)
new_code: str = replace_functions_and_add_imports(
source_code=original_code,
function_names=[function_name],
@ -749,7 +732,6 @@ print("Hello world")
file_path_of_module_with_function_to_optimize=Path(__file__).resolve(),
module_abspath=Path(__file__).resolve(),
preexisting_objects=preexisting_objects,
contextual_functions=contextual_functions,
project_root_path=Path(__file__).resolve().parent.resolve(),
)
assert new_code == expected
@ -847,13 +829,11 @@ def test_code_replacement11() -> None:
function_name: str = "Fu.foo"
parents = [FunctionParent("Fu", "ClassDef")]
preexisting_objects: list[tuple[str, list[FunctionParent]]] = [("foo", parents), ("real_bar", parents)]
contextual_functions: set[tuple[str, str]] = set()
new_code: str = replace_functions_in_file(
source_code=original_code,
original_function_names=[function_name],
optimized_code=optim_code,
preexisting_objects=preexisting_objects,
contextual_functions=contextual_functions,
)
assert new_code == expected_code
@ -888,13 +868,11 @@ def test_code_replacement12() -> None:
'''
preexisting_objects: list[tuple[str, list[FunctionParent]]] = []
contextual_functions: set[tuple[str, str]] = set()
new_code: str = replace_functions_in_file(
source_code=original_code,
original_function_names=["Fu.real_bar"],
optimized_code=optim_code,
preexisting_objects=preexisting_objects,
contextual_functions=contextual_functions,
)
assert new_code == expected_code
@ -924,9 +902,8 @@ def test_test_libcst_code_replacement13() -> None:
return self.name
"""
function_names: list[str] = ["module.yet_another_function", "module.other_function"]
function_names: list[str] = ["yet_another_function", "other_function"]
preexisting_objects: list[tuple[str, list[FunctionParent]]] = []
contextual_functions: set[tuple[str, str]] = set()
new_code: str = replace_functions_and_add_imports(
source_code=original_code,
function_names=function_names,
@ -934,7 +911,6 @@ def test_test_libcst_code_replacement13() -> None:
file_path_of_module_with_function_to_optimize=Path(__file__).resolve(),
module_abspath=Path(__file__).resolve(),
preexisting_objects=preexisting_objects,
contextual_functions=contextual_functions,
project_root_path=Path(__file__).resolve().parent.resolve(),
)
assert new_code == original_code
@ -1100,31 +1076,7 @@ class TestResults(BaseModel):
report[test_result.test_type][key] += 1
return report"""
preexisting_objects = [
("__contains__", [FunctionParent(name="TestResults", type="ClassDef")]),
("__len__", [FunctionParent(name="TestResults", type="ClassDef")]),
("__bool__", [FunctionParent(name="TestResults", type="ClassDef")]),
("__eq__", [FunctionParent(name="TestResults", type="ClassDef")]),
("__delitem__", [FunctionParent(name="TestResults", type="ClassDef")]),
("__iter__", [FunctionParent(name="TestResults", type="ClassDef")]),
("__setitem__", [FunctionParent(name="TestResults", type="ClassDef")]),
("__getitem__", [FunctionParent(name="TestResults", type="ClassDef")]),
("get_test_pass_fail_report_by_type", [FunctionParent(name="TestResults", type="ClassDef")]),
("TestType", []),
("TestResults", []),
("to_name", [FunctionParent(name="TestType", type="ClassDef")]),
]
contextual_functions = {
("TestResults", "__bool__"),
("TestResults", "__contains__"),
("TestResults", "__delitem__"),
("TestResults", "__eq__"),
("TestResults", "__getitem__"),
("TestResults", "__iter__"),
("TestResults", "__len__"),
("TestResults", "__setitem__"),
}
preexisting_objects = find_preexisting_objects(original_code)
helper_functions = [
FakeFunctionSource(
@ -1146,7 +1098,6 @@ class TestResults(BaseModel):
file_path_of_module_with_function_to_optimize=Path(__file__).resolve(),
module_abspath=Path(__file__).resolve(),
preexisting_objects=preexisting_objects,
contextual_functions=contextual_functions,
project_root_path=Path(__file__).parent.resolve(),
)
@ -1162,7 +1113,6 @@ class TestResults(BaseModel):
file_path_of_module_with_function_to_optimize=Path(__file__).resolve(),
module_abspath=module_abspath,
preexisting_objects=preexisting_objects,
contextual_functions=contextual_functions,
project_root_path=Path(__file__).parent.resolve(),
)
@ -1343,13 +1293,8 @@ def cosine_similarity_top_k(
return ret_idxs, scores
'''
preexisting_objects: list[tuple[str, list[FunctionParent]]] = [
("cosine_similarity_top_k", []),
("Matrix", []),
("cosine_similarity", []),
]
preexisting_objects: list[tuple[str, list[FunctionParent]]] = find_preexisting_objects(original_code)
contextual_functions: set[tuple[str, str]] = set()
helper_functions = [
FakeFunctionSource(
file_path=(Path(__file__).parent / "code_to_optimize" / "math_utils.py").resolve(),
@ -1376,7 +1321,6 @@ def cosine_similarity_top_k(
file_path_of_module_with_function_to_optimize=Path(__file__).resolve(),
module_abspath=(Path(__file__).parent / "code_to_optimize").resolve(),
preexisting_objects=preexisting_objects,
contextual_functions=contextual_functions,
project_root_path=Path(__file__).parent.parent.resolve(),
)
assert (
@ -1436,7 +1380,6 @@ def cosine_similarity_top_k(
file_path_of_module_with_function_to_optimize=Path(__file__).resolve(),
module_abspath=module_abspath,
preexisting_objects=preexisting_objects,
contextual_functions=contextual_functions,
project_root_path=Path(__file__).parent.parent.resolve(),
)
@ -1599,3 +1542,105 @@ def functionA():
return np.array([1, 2, 3])
'''
assert is_zero_diff(original_code, optim_code_e)
def test_nested_class() -> None:
optim_code = """import libcst as cst
from typing import Optional
class NewClass:
def __init__(self, name):
self.name = str(name)
def __call__(self, value):
return self.name
def new_function2(value):
return cst.ensure_type(value, int)
class NestedClass:
def nested_function(self):
return "I am nested and modified"
"""
original_code = """class NewClass:
def __init__(self, name):
self.name = name
def __call__(self, value):
return "I am still old"
def new_function2(value):
return cst.ensure_type(value, str)
class NestedClass:
def nested_function(self):
return "I am nested"
print("Hello world")
"""
expected = """import libcst as cst
class NewClass:
def __init__(self, name):
self.name = name
def __call__(self, value):
return "I am still old"
def new_function2(value):
return cst.ensure_type(value, int)
class NestedClass:
def nested_function(self):
return "I am nested"
print("Hello world")
"""
function_names: list[str] = [
"NewClass.new_function2",
"NestedClass.nested_function",
] # Nested classes should be ignored, even if provided as target
preexisting_objects: list[tuple[str, list[FunctionParent]]] = find_preexisting_objects(original_code)
new_code: str = replace_functions_and_add_imports(
source_code=original_code,
function_names=function_names,
optimized_code=optim_code,
file_path_of_module_with_function_to_optimize=Path(__file__).resolve(),
module_abspath=Path(__file__).resolve(),
preexisting_objects=preexisting_objects,
project_root_path=Path(__file__).resolve().parent.resolve(),
)
assert new_code == expected
def test_modify_back_to_original() -> None:
optim_code = """class NewClass:
def __init__(self, name):
self.name = name
def __call__(self, value):
return "I am still old"
def new_function2(value):
return cst.ensure_type(value, str)
print("Hello world")
"""
original_code = """class NewClass:
def __init__(self, name):
self.name = name
def __call__(self, value):
return "I am still old"
def new_function2(value):
return cst.ensure_type(value, str)
print("Hello world")
"""
function_names: list[str] = ["NewClass.__init__", "NewClass.__call__", "NewClass.new_function2"]
preexisting_objects: list[tuple[str, list[FunctionParent]]] = find_preexisting_objects(original_code)
new_code: str = replace_functions_and_add_imports(
source_code=original_code,
function_names=function_names,
optimized_code=optim_code,
file_path_of_module_with_function_to_optimize=Path(__file__).resolve(),
module_abspath=Path(__file__).resolve(),
preexisting_objects=preexisting_objects,
project_root_path=Path(__file__).resolve().parent.resolve(),
)
assert new_code == original_code

View file

@ -61,8 +61,10 @@ class A:
def nested_function(self):
def nested():
return global_dependency_3(1)
return nested() + self.add_two(3)
class B:
def calculate_something_2(self, num):
return num + 1
@ -217,7 +219,6 @@ def test_class_method_dependencies() -> None:
code_context.helper_functions[0].fully_qualified_name == "test_function_dependencies.Graph.topologicalSortUtil"
)
assert code_context.helper_functions[0].qualified_name == "Graph.topologicalSortUtil"
assert code_context.contextual_dunder_methods == {("Graph", "__init__")}
assert (
code_context.code_to_optimize_with_helpers
== """from collections import defaultdict
@ -303,8 +304,9 @@ def test_recursive_function_context() -> None:
if not is_successful(ctx_result):
pytest.fail()
code_context = ctx_result.unwrap()
assert len(code_context.helper_functions) == 1
assert len(code_context.helper_functions) == 2
assert code_context.helper_functions[0].fully_qualified_name == "test_function_dependencies.C.calculate_something_3"
assert code_context.helper_functions[1].fully_qualified_name == "test_function_dependencies.C.recursive"
assert (
code_context.code_to_optimize_with_helpers
== """class C:
@ -360,6 +362,7 @@ def test_method_in_method_list_comprehension() -> None:
assert len(helper_functions) == 1
assert helper_functions[0].jedi_definition.full_name == "test_function_dependencies.A.add_two"
def test_nested_method() -> None:
file_path = pathlib.Path(__file__).resolve()
function_to_optimize = FunctionToOptimize(

View file

@ -237,9 +237,7 @@ class _PersistentCache(Generic[_P, _R, _CacheBackendT]):
if not is_successful(ctx_result):
pytest.fail()
code_context = ctx_result.unwrap()
assert code_context.helper_functions[0].qualified_name == "_R"
assert code_context.helper_functions[1].qualified_name == "AbstractCacheBackend.get_cache_or_call"
assert len(code_context.contextual_dunder_methods) == 2
assert code_context.helper_functions[0].qualified_name == "AbstractCacheBackend.get_cache_or_call"
assert (
code_context.code_to_optimize_with_helpers