Modified code replacer to work file by file. now compatible with new code context extractor.
This commit is contained in:
parent
334e2a2952
commit
c41f710f5f
9 changed files with 291 additions and 246 deletions
|
|
@ -1,20 +1,19 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
from collections import defaultdict
|
||||
from functools import lru_cache
|
||||
from typing import TYPE_CHECKING, TypeVar
|
||||
|
||||
import libcst as cst
|
||||
|
||||
from codeflash.cli_cmds.console import console, logger
|
||||
from codeflash.cli_cmds.console import logger
|
||||
from codeflash.code_utils.code_extractor import add_needed_imports_from_module
|
||||
from codeflash.models.models import FunctionParent
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
|
||||
from libcst import FunctionDef
|
||||
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.models.models import CodeOptimizationContext, OptimizedCandidate, ValidCode
|
||||
|
||||
|
|
@ -39,93 +38,89 @@ class OptimFunctionCollector(cst.CSTVisitor):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
function_name: str,
|
||||
class_name: str | None,
|
||||
contextual_functions: set[tuple[str, str]],
|
||||
preexisting_objects: list[tuple[str, list[FunctionParent]]] | None = None,
|
||||
function_names: set[tuple[str | None, str]] | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
if preexisting_objects is None:
|
||||
preexisting_objects = []
|
||||
self.function_name = function_name
|
||||
self.class_name = class_name
|
||||
self.optim_body: FunctionDef | None = None
|
||||
self.optim_new_class_functions: list[cst.FunctionDef] = []
|
||||
self.optim_new_functions: list[cst.FunctionDef] = []
|
||||
self.preexisting_objects = preexisting_objects
|
||||
self.contextual_functions = contextual_functions.union({(self.class_name, self.function_name)})
|
||||
self.preexisting_objects = preexisting_objects if preexisting_objects is not None else []
|
||||
|
||||
def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
|
||||
parent = self.get_metadata(cst.metadata.ParentNodeProvider, node)
|
||||
parent2 = None
|
||||
try:
|
||||
if parent is not None and isinstance(parent, cst.Module):
|
||||
parent2 = self.get_metadata(cst.metadata.ParentNodeProvider, parent)
|
||||
except:
|
||||
pass
|
||||
if node.name.value == self.function_name:
|
||||
self.optim_body = node
|
||||
self.function_names = function_names # set of (class_name, function_name)
|
||||
self.modified_functions: dict[
|
||||
tuple[str | None, str], cst.FunctionDef
|
||||
] = {} # keys are (class_name, function_name)
|
||||
self.new_functions: list[cst.FunctionDef] = []
|
||||
self.new_class_functions: dict[str, list[cst.FunctionDef]] = defaultdict(list)
|
||||
self.current_class = None
|
||||
|
||||
def visit_FunctionDef(self, node: cst.FunctionDef) -> bool:
|
||||
if (self.current_class, node.name.value) in self.function_names:
|
||||
self.modified_functions[(self.current_class, node.name.value)] = node
|
||||
elif (
|
||||
self.preexisting_objects
|
||||
and (node.name.value, []) not in self.preexisting_objects
|
||||
and (isinstance(parent, cst.Module) or (parent2 is not None and not isinstance(parent2, cst.ClassDef)))
|
||||
and self.current_class is None
|
||||
):
|
||||
self.optim_new_functions.append(node)
|
||||
self.new_functions.append(node)
|
||||
return False
|
||||
|
||||
def visit_ClassDef(self, node: cst.ClassDef) -> bool:
|
||||
if self.current_class:
|
||||
return False # If already in a class, do not recurse deeper
|
||||
self.current_class = node.name.value
|
||||
|
||||
def visit_ClassDef_body(self, node: cst.ClassDef) -> None:
|
||||
parents = [FunctionParent(name=node.name.value, type="ClassDef")]
|
||||
for child_node in node.body.body:
|
||||
if (
|
||||
self.preexisting_objects
|
||||
and isinstance(child_node, cst.FunctionDef)
|
||||
and (node.name.value, child_node.name.value) not in self.contextual_functions
|
||||
and (child_node.name.value, parents) not in self.preexisting_objects
|
||||
):
|
||||
self.optim_new_class_functions.append(child_node)
|
||||
self.new_class_functions[node.name.value].append(child_node)
|
||||
return True
|
||||
|
||||
def leave_ClassDef(self, node: cst.ClassDef) -> None:
|
||||
if self.current_class:
|
||||
self.current_class = None
|
||||
|
||||
|
||||
class OptimFunctionReplacer(cst.CSTTransformer):
|
||||
def __init__(
|
||||
self,
|
||||
function_name: str,
|
||||
optim_body: cst.FunctionDef,
|
||||
optim_new_class_functions: list[cst.FunctionDef],
|
||||
optim_new_functions: list[cst.FunctionDef],
|
||||
class_name: str | None = None,
|
||||
modified_functions: dict[tuple[str | None, str], cst.FunctionDef] = {},
|
||||
new_functions: list[cst.FunctionDef] = [],
|
||||
new_class_functions: dict[str, list[cst.FunctionDef]] = defaultdict(list),
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.function_name = function_name
|
||||
self.optim_body = optim_body
|
||||
self.optim_new_class_functions = optim_new_class_functions
|
||||
self.optim_new_functions = optim_new_functions
|
||||
self.class_name = class_name
|
||||
self.depth: int = 0
|
||||
self.in_class: bool = False
|
||||
self.modified_functions = modified_functions
|
||||
self.new_functions = new_functions
|
||||
self.new_class_functions = new_class_functions
|
||||
self.current_class = None
|
||||
|
||||
def visit_FunctionDef(self, node: cst.FunctionDef) -> bool:
|
||||
return False
|
||||
|
||||
def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef:
|
||||
if original_node.name.value == self.function_name and (self.depth == 0 or (self.depth == 1 and self.in_class)):
|
||||
return updated_node.with_changes(body=self.optim_body.body, decorators=self.optim_body.decorators)
|
||||
if (self.current_class, original_node.name.value) in self.modified_functions:
|
||||
node = self.modified_functions[(self.current_class, original_node.name.value)]
|
||||
return updated_node.with_changes(body=node.body, decorators=node.decorators)
|
||||
|
||||
return updated_node
|
||||
|
||||
def visit_ClassDef(self, node: cst.ClassDef) -> bool:
|
||||
self.depth += 1
|
||||
if self.in_class:
|
||||
return False
|
||||
self.in_class = (self.depth == 1) and (node.name.value == self.class_name)
|
||||
return self.in_class
|
||||
if self.current_class:
|
||||
return False # If already in a class, do not recurse deeper
|
||||
self.current_class = node.name.value
|
||||
return True
|
||||
|
||||
def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef:
|
||||
self.depth -= 1
|
||||
if self.in_class and (self.depth == 0) and (original_node.name.value == self.class_name):
|
||||
self.in_class = False
|
||||
return updated_node.with_changes(
|
||||
body=updated_node.body.with_changes(
|
||||
body=(list(updated_node.body.body) + self.optim_new_class_functions)
|
||||
if self.current_class and self.current_class == original_node.name.value:
|
||||
self.current_class = None
|
||||
if original_node.name.value in self.new_class_functions:
|
||||
return updated_node.with_changes(
|
||||
body=updated_node.body.with_changes(
|
||||
body=(list(updated_node.body.body) + list(self.new_class_functions[original_node.name.value]))
|
||||
)
|
||||
)
|
||||
)
|
||||
return updated_node
|
||||
|
||||
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
|
||||
|
|
@ -139,18 +134,14 @@ class OptimFunctionReplacer(cst.CSTTransformer):
|
|||
class_index = index
|
||||
if max_function_index is not None:
|
||||
node = node.with_changes(
|
||||
body=(
|
||||
*node.body[: max_function_index + 1],
|
||||
*self.optim_new_functions,
|
||||
*node.body[max_function_index + 1 :],
|
||||
)
|
||||
body=(*node.body[: max_function_index + 1], *self.new_functions, *node.body[max_function_index + 1 :])
|
||||
)
|
||||
elif class_index is not None:
|
||||
node = node.with_changes(
|
||||
body=(*node.body[: class_index + 1], *self.optim_new_functions, *node.body[class_index + 1 :])
|
||||
body=(*node.body[: class_index + 1], *self.new_functions, *node.body[class_index + 1 :])
|
||||
)
|
||||
else:
|
||||
node = node.with_changes(body=(*self.optim_new_functions, *node.body))
|
||||
node = node.with_changes(body=(*self.new_functions, *node.body))
|
||||
return node
|
||||
|
||||
|
||||
|
|
@ -159,7 +150,6 @@ def replace_functions_in_file(
|
|||
original_function_names: list[str],
|
||||
optimized_code: str,
|
||||
preexisting_objects: list[tuple[str, list[FunctionParent]]],
|
||||
contextual_functions: set[tuple[str, str]],
|
||||
) -> str:
|
||||
parsed_function_names = []
|
||||
for original_function_name in original_function_names:
|
||||
|
|
@ -171,34 +161,22 @@ def replace_functions_in_file(
|
|||
msg = f"Unable to find {original_function_name}. Returning unchanged source code."
|
||||
logger.error(msg)
|
||||
return source_code
|
||||
parsed_function_names.append((function_name, class_name))
|
||||
parsed_function_names.append((class_name, function_name))
|
||||
|
||||
# Collect functions we want to modify from the optimized code
|
||||
module = cst.metadata.MetadataWrapper(cst.parse_module(optimized_code))
|
||||
visitor = OptimFunctionCollector(preexisting_objects, set(parsed_function_names))
|
||||
module.visit(visitor)
|
||||
|
||||
for function_name, class_name in parsed_function_names:
|
||||
visitor = OptimFunctionCollector(function_name, class_name, contextual_functions, preexisting_objects)
|
||||
module.visit(visitor)
|
||||
|
||||
if visitor.optim_body is None and not preexisting_objects:
|
||||
continue
|
||||
if visitor.optim_body is None:
|
||||
msg = f"Unable to find function {function_name} in optimized code. Returning unchanged source code."
|
||||
logger.error(msg)
|
||||
console.rule()
|
||||
return source_code
|
||||
|
||||
transformer = OptimFunctionReplacer(
|
||||
visitor.function_name,
|
||||
visitor.optim_body,
|
||||
visitor.optim_new_class_functions,
|
||||
visitor.optim_new_functions,
|
||||
class_name=class_name,
|
||||
)
|
||||
original_module = cst.parse_module(source_code)
|
||||
modified_tree = original_module.visit(transformer)
|
||||
source_code = modified_tree.code
|
||||
|
||||
return source_code
|
||||
# Replace these functions in the original code
|
||||
transformer = OptimFunctionReplacer(
|
||||
modified_functions=visitor.modified_functions,
|
||||
new_functions=visitor.new_functions,
|
||||
new_class_functions=visitor.new_class_functions,
|
||||
)
|
||||
original_module = cst.parse_module(source_code)
|
||||
modified_tree = original_module.visit(transformer)
|
||||
return modified_tree.code
|
||||
|
||||
|
||||
def replace_functions_and_add_imports(
|
||||
|
|
@ -208,14 +186,11 @@ def replace_functions_and_add_imports(
|
|||
file_path_of_module_with_function_to_optimize: Path,
|
||||
module_abspath: Path,
|
||||
preexisting_objects: list[tuple[str, list[FunctionParent]]],
|
||||
contextual_functions: set[tuple[str, str]],
|
||||
project_root_path: Path,
|
||||
) -> str:
|
||||
return add_needed_imports_from_module(
|
||||
optimized_code,
|
||||
replace_functions_in_file(
|
||||
source_code, function_names, optimized_code, preexisting_objects, contextual_functions
|
||||
),
|
||||
replace_functions_in_file(source_code, function_names, optimized_code, preexisting_objects),
|
||||
file_path_of_module_with_function_to_optimize,
|
||||
module_abspath,
|
||||
project_root_path,
|
||||
|
|
@ -228,7 +203,6 @@ def replace_function_definitions_in_module(
|
|||
file_path_of_module_with_function_to_optimize: Path,
|
||||
module_abspath: Path,
|
||||
preexisting_objects: list[tuple[str, list[FunctionParent]]],
|
||||
contextual_functions: set[tuple[str, str]],
|
||||
project_root_path: Path,
|
||||
) -> bool:
|
||||
source_code: str = module_abspath.read_text(encoding="utf8")
|
||||
|
|
@ -239,7 +213,6 @@ def replace_function_definitions_in_module(
|
|||
file_path_of_module_with_function_to_optimize,
|
||||
module_abspath,
|
||||
preexisting_objects,
|
||||
contextual_functions,
|
||||
project_root_path,
|
||||
)
|
||||
if is_zero_diff(source_code, new_code):
|
||||
|
|
@ -268,7 +241,6 @@ def replace_optimized_code(
|
|||
function_to_optimize.file_path,
|
||||
function_to_optimize.file_path,
|
||||
code_context.preexisting_objects,
|
||||
code_context.contextual_dunder_methods,
|
||||
project_root,
|
||||
)
|
||||
for candidate in candidates
|
||||
|
|
@ -298,7 +270,6 @@ def replace_optimized_code(
|
|||
function_to_optimize.file_path,
|
||||
module_path,
|
||||
[],
|
||||
code_context.contextual_dunder_methods,
|
||||
project_root,
|
||||
)
|
||||
for module_path in module_paths
|
||||
|
|
|
|||
|
|
@ -11,6 +11,8 @@ from codeflash.cli_cmds.console import logger
|
|||
|
||||
|
||||
def get_qualified_name(module_name: str, full_qualified_name: str) -> str:
|
||||
if not full_qualified_name:
|
||||
raise ValueError("full_qualified_name cannot be empty")
|
||||
if not full_qualified_name.startswith(module_name):
|
||||
msg = f"{full_qualified_name} does not start with {module_name}"
|
||||
raise ValueError(msg)
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ from __future__ import annotations
|
|||
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from itertools import chain
|
||||
from pathlib import Path
|
||||
|
||||
import jedi
|
||||
|
|
@ -11,24 +12,23 @@ from jedi.api.classes import Name
|
|||
from libcst import CSTNode
|
||||
|
||||
from codeflash.cli_cmds.console import logger
|
||||
from codeflash.code_utils.code_extractor import add_needed_imports_from_module
|
||||
from codeflash.code_utils.code_extractor import add_needed_imports_from_module, find_preexisting_objects
|
||||
from codeflash.code_utils.code_utils import get_qualified_name, path_belongs_to_site_packages
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.models.models import CodeString, CodeStringsMarkdown
|
||||
from codeflash.models.models import CodeOptimizationContext, CodeString, CodeStringsMarkdown, FunctionSource
|
||||
from codeflash.optimization.function_context import belongs_to_function_qualified
|
||||
|
||||
|
||||
def get_code_optimization_context(
|
||||
function_to_optimize: FunctionToOptimize, project_root_path: Path, token_limit: int = 8000
|
||||
) -> tuple[str, str]:
|
||||
) -> CodeOptimizationContext:
|
||||
# Get qualified names and fully qualified names(fqn) of helpers
|
||||
helpers_of_fto, helpers_of_fto_fqn = get_file_path_to_helper_functions_dict(
|
||||
helpers_of_fto, helpers_of_fto_fqn, helpers_of_fto_obj_list = get_file_path_to_helper_functions_dict(
|
||||
{function_to_optimize.file_path: {function_to_optimize.qualified_name}}, project_root_path
|
||||
)
|
||||
helpers_of_helpers, helpers_of_helpers_fqn = get_file_path_to_helper_functions_dict(
|
||||
helpers_of_helpers, helpers_of_helpers_fqn, _ = get_file_path_to_helper_functions_dict(
|
||||
helpers_of_fto, project_root_path
|
||||
)
|
||||
|
||||
# Add function to optimize
|
||||
helpers_of_fto[function_to_optimize.file_path].add(function_to_optimize.qualified_name)
|
||||
helpers_of_fto_fqn[function_to_optimize.file_path].add(
|
||||
|
|
@ -36,7 +36,7 @@ def get_code_optimization_context(
|
|||
)
|
||||
|
||||
# Extract code
|
||||
final_read_writable_code = get_all_read_writable_code(helpers_of_fto, helpers_of_fto_fqn, project_root_path)
|
||||
final_read_writable_code = get_all_read_writable_code(helpers_of_fto, helpers_of_fto_fqn, project_root_path).code
|
||||
read_only_code_markdown = get_all_read_only_code_context(
|
||||
helpers_of_fto,
|
||||
helpers_of_fto_fqn,
|
||||
|
|
@ -52,10 +52,24 @@ def get_code_optimization_context(
|
|||
if final_read_writable_tokens > token_limit:
|
||||
raise ValueError("Read-writable code has exceeded token limit, cannot proceed")
|
||||
|
||||
# Setup preexisting objects for code replacer TODO: should remove duplicates
|
||||
preexisting_objects = list(
|
||||
chain(
|
||||
find_preexisting_objects(final_read_writable_code),
|
||||
*(find_preexisting_objects(codestring.code) for codestring in read_only_code_markdown.code_strings),
|
||||
)
|
||||
)
|
||||
read_only_code_markdown_tokens = len(tokenizer.encode(read_only_code_markdown.markdown))
|
||||
total_tokens = final_read_writable_tokens + read_only_code_markdown_tokens
|
||||
if total_tokens <= token_limit:
|
||||
return CodeString(code=final_read_writable_code).code, read_only_code_markdown.markdown
|
||||
return CodeOptimizationContext(
|
||||
code_to_optimize_with_helpers="",
|
||||
read_writable_code=CodeString(code=final_read_writable_code).code,
|
||||
read_only_context_code=read_only_code_markdown.markdown,
|
||||
helper_functions=helpers_of_fto_obj_list,
|
||||
preexisting_objects=preexisting_objects,
|
||||
)
|
||||
|
||||
logger.debug("Code context has exceeded token limit, removing docstrings from read-only code")
|
||||
|
||||
# Extract read only code without docstrings
|
||||
|
|
@ -70,15 +84,29 @@ def get_code_optimization_context(
|
|||
read_only_code_no_docstring_markdown_tokens = len(tokenizer.encode(read_only_code_no_docstring_markdown.markdown))
|
||||
total_tokens = final_read_writable_tokens + read_only_code_no_docstring_markdown_tokens
|
||||
if total_tokens <= token_limit:
|
||||
return CodeString(code=final_read_writable_code).code, read_only_code_no_docstring_markdown.markdown
|
||||
return CodeOptimizationContext(
|
||||
code_to_optimize_with_helpers="",
|
||||
read_writable_code=CodeString(code=final_read_writable_code).code,
|
||||
read_only_context_code=read_only_code_no_docstring_markdown.markdown,
|
||||
contextual_dunder_methods=set(),
|
||||
helper_functions=helpers_of_fto_obj_list,
|
||||
preexisting_objects=preexisting_objects,
|
||||
)
|
||||
|
||||
logger.debug("Code context has exceeded token limit, removing read-only code")
|
||||
return CodeString(code=final_read_writable_code).code, ""
|
||||
return CodeOptimizationContext(
|
||||
code_to_optimize_with_helpers="",
|
||||
read_writable_code=CodeString(code=final_read_writable_code).code,
|
||||
read_only_context_code="",
|
||||
contextual_dunder_methods=set(),
|
||||
helper_functions=helpers_of_fto_obj_list,
|
||||
preexisting_objects=preexisting_objects,
|
||||
)
|
||||
|
||||
|
||||
def get_all_read_writable_code(
|
||||
helpers_of_fto: dict[Path, set[str]], helpers_of_fto_fqn: dict[Path, set[str]], project_root_path: Path
|
||||
) -> str:
|
||||
) -> CodeString:
|
||||
final_read_writable_code = ""
|
||||
# Extract code from file paths that contain fto and first degree helpers
|
||||
for file_path, qualified_function_names in helpers_of_fto.items():
|
||||
|
|
@ -103,7 +131,7 @@ def get_all_read_writable_code(
|
|||
project_root=project_root_path,
|
||||
helper_functions_fqn=helpers_of_fto_fqn[file_path],
|
||||
)
|
||||
return final_read_writable_code
|
||||
return CodeString(code=final_read_writable_code)
|
||||
|
||||
|
||||
def get_all_read_only_code_context(
|
||||
|
|
@ -189,9 +217,10 @@ def get_all_read_only_code_context(
|
|||
|
||||
def get_file_path_to_helper_functions_dict(
|
||||
file_path_to_qualified_function_names: dict[Path, set[str]], project_root_path: Path
|
||||
) -> tuple[dict[Path, set[str]], dict[Path, set[str]]]:
|
||||
) -> tuple[dict[Path, set[str]], dict[Path, set[str]], list[FunctionSource]]:
|
||||
file_path_to_helper_function_qualified_names = defaultdict(set)
|
||||
file_path_to_helper_function_fqn = defaultdict(set)
|
||||
function_source_list: list[FunctionSource] = []
|
||||
for file_path in file_path_to_qualified_function_names:
|
||||
script = jedi.Script(path=file_path, project=jedi.Project(path=project_root_path))
|
||||
file_refs = script.get_names(all_scopes=True, definitions=False, references=True)
|
||||
|
|
@ -229,8 +258,18 @@ def get_file_path_to_helper_functions_dict(
|
|||
get_qualified_name(definition.module_name, definition.full_name)
|
||||
)
|
||||
file_path_to_helper_function_fqn[definition_path].add(definition.full_name)
|
||||
function_source_list.append(
|
||||
FunctionSource(
|
||||
file_path=definition_path,
|
||||
qualified_name=get_qualified_name(definition.module_name, definition.full_name),
|
||||
fully_qualified_name=definition.full_name,
|
||||
only_function_name=definition.name,
|
||||
source_code=definition.get_line_code(),
|
||||
jedi_definition=definition,
|
||||
)
|
||||
)
|
||||
|
||||
return file_path_to_helper_function_qualified_names, file_path_to_helper_function_fqn
|
||||
return file_path_to_helper_function_qualified_names, file_path_to_helper_function_fqn, function_source_list
|
||||
|
||||
|
||||
def is_dunder_method(name: str) -> bool:
|
||||
|
|
|
|||
|
|
@ -80,7 +80,6 @@ class CodeOptimizationContext(BaseModel):
|
|||
code_to_optimize_with_helpers: str
|
||||
read_writable_code: str = Field(min_length=1)
|
||||
read_only_context_code: str = ""
|
||||
contextual_dunder_methods: set[tuple[str, str]]
|
||||
helper_functions: list[FunctionSource]
|
||||
preexisting_objects: list[tuple[str, list[FunctionParent]]]
|
||||
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ from rich.tree import Tree
|
|||
from codeflash.api.aiservice import AiServiceClient, LocalAiServiceClient
|
||||
from codeflash.cli_cmds.console import code_print, console, logger, progress_bar
|
||||
from codeflash.code_utils import env_utils
|
||||
from codeflash.code_utils.code_extractor import add_needed_imports_from_module, extract_code, find_preexisting_objects
|
||||
from codeflash.code_utils.code_extractor import add_needed_imports_from_module, extract_code
|
||||
from codeflash.code_utils.code_replacer import normalize_code, normalize_node, replace_function_definitions_in_module
|
||||
from codeflash.code_utils.code_utils import (
|
||||
cleanup_paths,
|
||||
|
|
@ -668,7 +668,6 @@ class Optimizer:
|
|||
file_path_of_module_with_function_to_optimize=function_to_optimize_file_path,
|
||||
module_abspath=function_to_optimize_file_path,
|
||||
preexisting_objects=code_context.preexisting_objects,
|
||||
contextual_functions=code_context.contextual_dunder_methods,
|
||||
project_root_path=self.args.project_root,
|
||||
)
|
||||
helper_functions_by_module_abspath = defaultdict(set)
|
||||
|
|
@ -681,8 +680,7 @@ class Optimizer:
|
|||
optimized_code=optimized_code,
|
||||
file_path_of_module_with_function_to_optimize=function_to_optimize_file_path,
|
||||
module_abspath=module_abspath,
|
||||
preexisting_objects=[],
|
||||
contextual_functions=code_context.contextual_dunder_methods,
|
||||
preexisting_objects=code_context.preexisting_objects,
|
||||
project_root_path=self.args.project_root,
|
||||
)
|
||||
return did_update
|
||||
|
|
@ -733,25 +731,19 @@ class Optimizer:
|
|||
project_root,
|
||||
helper_functions,
|
||||
)
|
||||
preexisting_objects = find_preexisting_objects(code_to_optimize_with_helpers)
|
||||
contextual_dunder_methods.update(helper_dunder_methods)
|
||||
|
||||
# Will eventually refactor to use this function instead of the above
|
||||
try:
|
||||
read_writable_code, read_only_context_code = code_context_extractor.get_code_optimization_context(
|
||||
function_to_optimize, project_root
|
||||
)
|
||||
new_code_ctx = code_context_extractor.get_code_optimization_context(function_to_optimize, project_root)
|
||||
except ValueError as e:
|
||||
return Failure(str(e))
|
||||
|
||||
return Success(
|
||||
CodeOptimizationContext(
|
||||
code_to_optimize_with_helpers=code_to_optimize_with_helpers_and_imports,
|
||||
read_writable_code=read_writable_code,
|
||||
read_only_context_code=read_only_context_code,
|
||||
contextual_dunder_methods=contextual_dunder_methods,
|
||||
helper_functions=helper_functions,
|
||||
preexisting_objects=preexisting_objects,
|
||||
read_writable_code=new_code_ctx.read_writable_code,
|
||||
read_only_context_code=new_code_ctx.read_only_context_code,
|
||||
helper_functions=new_code_ctx.helper_functions, # only functions that are read writable
|
||||
preexisting_objects=new_code_ctx.preexisting_objects,
|
||||
)
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -72,9 +72,8 @@ def test_code_replacement10() -> None:
|
|||
function_name="main_method", file_path=file_path, parents=[FunctionParent("MainClass", "ClassDef")]
|
||||
)
|
||||
|
||||
read_write_context, read_only_context = get_code_optimization_context(
|
||||
function_to_optimize=func_top_optimize, project_root_path=file_path.parent
|
||||
)
|
||||
code_ctx = get_code_optimization_context(function_to_optimize=func_top_optimize, project_root_path=file_path.parent)
|
||||
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
|
||||
|
||||
expected_read_write_context = """
|
||||
from __future__ import annotations
|
||||
|
|
@ -119,10 +118,8 @@ def test_class_method_dependencies() -> None:
|
|||
ending_line=None,
|
||||
)
|
||||
|
||||
read_write_context, read_only_context = get_code_optimization_context(
|
||||
function_to_optimize, file_path.parent.resolve()
|
||||
)
|
||||
|
||||
code_ctx = get_code_optimization_context(function_to_optimize, file_path.parent.resolve())
|
||||
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
|
||||
expected_read_write_context = """
|
||||
from __future__ import annotations
|
||||
|
||||
|
|
@ -181,9 +178,8 @@ def test_bubble_sort_helper() -> None:
|
|||
ending_line=None,
|
||||
)
|
||||
|
||||
read_write_context, read_only_context = get_code_optimization_context(
|
||||
function_to_optimize, Path(__file__).resolve().parent.parent
|
||||
)
|
||||
code_ctx = get_code_optimization_context(function_to_optimize, Path(__file__).resolve().parent.parent)
|
||||
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
|
||||
|
||||
expected_read_write_context = """
|
||||
import math
|
||||
|
|
@ -397,9 +393,8 @@ class _PersistentCache(Generic[_P, _R, _CacheBackendT]):
|
|||
ending_line=None,
|
||||
)
|
||||
|
||||
read_write_context, read_only_context = get_code_optimization_context(
|
||||
function_to_optimize, opt.args.project_root
|
||||
)
|
||||
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
|
||||
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
|
||||
expected_read_write_context = """
|
||||
class AbstractCacheBackend(CacheBackend, Protocol[_KEY_T, _STORE_T]):
|
||||
|
||||
|
|
@ -599,9 +594,8 @@ class HelperClass:
|
|||
ending_line=None,
|
||||
)
|
||||
|
||||
read_write_context, read_only_context = get_code_optimization_context(
|
||||
function_to_optimize, opt.args.project_root
|
||||
)
|
||||
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
|
||||
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
|
||||
expected_read_write_context = """
|
||||
class MyClass:
|
||||
def target_method(self):
|
||||
|
|
@ -680,9 +674,8 @@ class HelperClass:
|
|||
ending_line=None,
|
||||
)
|
||||
|
||||
read_write_context, read_only_context = get_code_optimization_context(
|
||||
function_to_optimize, opt.args.project_root
|
||||
)
|
||||
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
|
||||
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
|
||||
# In this scenario, the read-only code context is too long, so the read-only docstrings are removed.
|
||||
expected_read_write_context = """
|
||||
class MyClass:
|
||||
|
|
@ -759,9 +752,8 @@ class HelperClass:
|
|||
ending_line=None,
|
||||
)
|
||||
|
||||
read_write_context, read_only_context = get_code_optimization_context(
|
||||
function_to_optimize, opt.args.project_root
|
||||
)
|
||||
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
|
||||
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
|
||||
# In this scenario, the read-only code context is too long even after removing docstrings, hence we remove it completely.
|
||||
expected_read_write_context = """
|
||||
class MyClass:
|
||||
|
|
@ -826,9 +818,7 @@ class HelperClass:
|
|||
)
|
||||
# In this scenario, the read-writable code is too long, so we abort.
|
||||
with pytest.raises(ValueError, match="Read-writable code has exceeded token limit, cannot proceed"):
|
||||
read_write_context, read_only_context = get_code_optimization_context(
|
||||
function_to_optimize, opt.args.project_root
|
||||
)
|
||||
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
|
||||
|
||||
|
||||
def test_repo_helper() -> None:
|
||||
|
|
@ -843,7 +833,8 @@ def test_repo_helper() -> None:
|
|||
ending_line=None,
|
||||
)
|
||||
|
||||
read_write_context, read_only_context = get_code_optimization_context(function_to_optimize, project_root)
|
||||
code_ctx = get_code_optimization_context(function_to_optimize, project_root)
|
||||
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
|
||||
expected_read_write_context = """
|
||||
import requests
|
||||
from globals import API_URL
|
||||
|
|
@ -920,7 +911,8 @@ def test_repo_helper_of_helper() -> None:
|
|||
ending_line=None,
|
||||
)
|
||||
|
||||
read_write_context, read_only_context = get_code_optimization_context(function_to_optimize, project_root)
|
||||
code_ctx = get_code_optimization_context(function_to_optimize, project_root)
|
||||
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
|
||||
expected_read_write_context = """
|
||||
from transform_utils import DataTransformer
|
||||
import requests
|
||||
|
|
@ -1006,7 +998,8 @@ def test_repo_helper_of_helper_same_class() -> None:
|
|||
ending_line=None,
|
||||
)
|
||||
|
||||
read_write_context, read_only_context = get_code_optimization_context(function_to_optimize, project_root)
|
||||
code_ctx = get_code_optimization_context(function_to_optimize, project_root)
|
||||
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
|
||||
expected_read_write_context = """
|
||||
from transform_utils import DataTransformer
|
||||
|
||||
|
|
@ -1073,7 +1066,8 @@ def test_repo_helper_of_helper_same_file() -> None:
|
|||
ending_line=None,
|
||||
)
|
||||
|
||||
read_write_context, read_only_context = get_code_optimization_context(function_to_optimize, project_root)
|
||||
code_ctx = get_code_optimization_context(function_to_optimize, project_root)
|
||||
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
|
||||
expected_read_write_context = """
|
||||
from transform_utils import DataTransformer
|
||||
|
||||
|
|
@ -1137,7 +1131,8 @@ def test_repo_helper_all_same_file() -> None:
|
|||
ending_line=None,
|
||||
)
|
||||
|
||||
read_write_context, read_only_context = get_code_optimization_context(function_to_optimize, project_root)
|
||||
code_ctx = get_code_optimization_context(function_to_optimize, project_root)
|
||||
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
|
||||
expected_read_write_context = """
|
||||
class DataTransformer:
|
||||
|
||||
|
|
@ -1181,7 +1176,8 @@ def test_repo_helper_circular_dependency() -> None:
|
|||
ending_line=None,
|
||||
)
|
||||
|
||||
read_write_context, read_only_context = get_code_optimization_context(function_to_optimize, project_root)
|
||||
code_ctx = get_code_optimization_context(function_to_optimize, project_root)
|
||||
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
|
||||
expected_read_write_context = """
|
||||
from transform_utils import DataTransformer
|
||||
from code_to_optimize.code_directories.retriever.utils import DataProcessor
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from argparse import Namespace
|
|||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
|
||||
from codeflash.code_utils.code_extractor import delete___future___aliased_imports
|
||||
from codeflash.code_utils.code_extractor import delete___future___aliased_imports, find_preexisting_objects
|
||||
from codeflash.code_utils.code_replacer import (
|
||||
is_zero_diff,
|
||||
replace_functions_and_add_imports,
|
||||
|
|
@ -74,10 +74,7 @@ print("Hello world")
|
|||
"""
|
||||
|
||||
function_name: str = "NewClass.new_function"
|
||||
preexisting_objects: list[tuple[str, list[FunctionParent]]] = [
|
||||
("new_function", [FunctionParent(name="NewClass", type="ClassDef")])
|
||||
]
|
||||
contextual_functions: set[tuple[str, str]] = {("NewClass", "__init__")}
|
||||
preexisting_objects: list[tuple[str, list[FunctionParent]]] = find_preexisting_objects(original_code)
|
||||
new_code: str = replace_functions_and_add_imports(
|
||||
source_code=original_code,
|
||||
function_names=[function_name],
|
||||
|
|
@ -85,7 +82,6 @@ print("Hello world")
|
|||
file_path_of_module_with_function_to_optimize=Path(__file__).resolve(),
|
||||
module_abspath=Path(__file__).resolve(),
|
||||
preexisting_objects=preexisting_objects,
|
||||
contextual_functions=contextual_functions,
|
||||
project_root_path=Path(__file__).resolve().parent.resolve(),
|
||||
)
|
||||
assert new_code == expected
|
||||
|
|
@ -133,12 +129,14 @@ class NewClass:
|
|||
def totally_new_function(value):
|
||||
return value
|
||||
|
||||
def other_function(st):
|
||||
return(st * 2)
|
||||
|
||||
print("Hello world")
|
||||
"""
|
||||
|
||||
function_name: str = "NewClass.new_function"
|
||||
preexisting_objects: list[tuple[str, list[FunctionParent]]] = [("new_function", []), ("other_function", [])]
|
||||
contextual_functions: set[tuple[str, str]] = {("NewClass", "__init__")}
|
||||
preexisting_objects: list[tuple[str, list[FunctionParent]]] = find_preexisting_objects(original_code)
|
||||
new_code: str = replace_functions_and_add_imports(
|
||||
source_code=original_code,
|
||||
function_names=[function_name],
|
||||
|
|
@ -146,7 +144,6 @@ print("Hello world")
|
|||
file_path_of_module_with_function_to_optimize=Path(__file__).resolve(),
|
||||
module_abspath=Path(__file__).resolve(),
|
||||
preexisting_objects=preexisting_objects,
|
||||
contextual_functions=contextual_functions,
|
||||
project_root_path=Path(__file__).resolve().parent.resolve(),
|
||||
)
|
||||
assert new_code == expected
|
||||
|
|
@ -194,12 +191,14 @@ def yet_another_function(values):
|
|||
def other_function(st):
|
||||
return(st * 2)
|
||||
|
||||
def totally_new_function(value):
|
||||
return value
|
||||
|
||||
print("Salut monde")
|
||||
"""
|
||||
|
||||
function_names: list[str] = ["module.other_function"]
|
||||
preexisting_objects: list[tuple[str, list[FunctionParent]]] = []
|
||||
contextual_functions: set[tuple[str, str]] = set()
|
||||
function_names: list[str] = ["other_function"]
|
||||
preexisting_objects: list[tuple[str, list[FunctionParent]]] = find_preexisting_objects(original_code)
|
||||
new_code: str = replace_functions_and_add_imports(
|
||||
source_code=original_code,
|
||||
function_names=function_names,
|
||||
|
|
@ -207,7 +206,6 @@ print("Salut monde")
|
|||
file_path_of_module_with_function_to_optimize=Path(__file__).resolve(),
|
||||
module_abspath=Path(__file__).resolve(),
|
||||
preexisting_objects=preexisting_objects,
|
||||
contextual_functions=contextual_functions,
|
||||
project_root_path=Path(__file__).resolve().parent.resolve(),
|
||||
)
|
||||
assert new_code == expected
|
||||
|
|
@ -258,12 +256,14 @@ def yet_another_function(values):
|
|||
def other_function(st):
|
||||
return(st * 2)
|
||||
|
||||
def totally_new_function(value):
|
||||
return value
|
||||
|
||||
print("Salut monde")
|
||||
"""
|
||||
|
||||
function_names: list[str] = ["module.yet_another_function", "module.other_function"]
|
||||
preexisting_objects: list[tuple[str, list[FunctionParent]]] = []
|
||||
contextual_functions: set[tuple[str, str]] = set()
|
||||
function_names: list[str] = ["yet_another_function", "other_function"]
|
||||
preexisting_objects: list[tuple[str, list[FunctionParent]]] = find_preexisting_objects(original_code)
|
||||
new_code: str = replace_functions_and_add_imports(
|
||||
source_code=original_code,
|
||||
function_names=function_names,
|
||||
|
|
@ -271,7 +271,6 @@ print("Salut monde")
|
|||
file_path_of_module_with_function_to_optimize=Path(__file__).resolve(),
|
||||
module_abspath=Path(__file__).resolve(),
|
||||
preexisting_objects=preexisting_objects,
|
||||
contextual_functions=contextual_functions,
|
||||
project_root_path=Path(__file__).resolve().parent.resolve(),
|
||||
)
|
||||
assert new_code == expected
|
||||
|
|
@ -318,8 +317,7 @@ def supersort(doink):
|
|||
"""
|
||||
|
||||
function_names: list[str] = ["sorter_deps"]
|
||||
preexisting_objects: list[tuple[str, list[FunctionParent]]] = [("sorter_deps", [])]
|
||||
contextual_functions: set[tuple[str, str]] = set()
|
||||
preexisting_objects: list[tuple[str, list[FunctionParent]]] = find_preexisting_objects(original_code)
|
||||
new_code: str = replace_functions_and_add_imports(
|
||||
source_code=original_code,
|
||||
function_names=function_names,
|
||||
|
|
@ -327,7 +325,6 @@ def supersort(doink):
|
|||
file_path_of_module_with_function_to_optimize=Path(__file__).resolve(),
|
||||
module_abspath=Path(__file__).resolve(),
|
||||
preexisting_objects=preexisting_objects,
|
||||
contextual_functions=contextual_functions,
|
||||
project_root_path=Path(__file__).resolve().parent.resolve(),
|
||||
)
|
||||
assert new_code == expected
|
||||
|
|
@ -396,14 +393,14 @@ def blab(st):
|
|||
|
||||
print("Not cool")
|
||||
"""
|
||||
preexisting_objects = find_preexisting_objects(original_code_main) + find_preexisting_objects(original_code_helper)
|
||||
new_main_code: str = replace_functions_and_add_imports(
|
||||
source_code=original_code_main,
|
||||
function_names=["other_function"],
|
||||
optimized_code=optim_code,
|
||||
file_path_of_module_with_function_to_optimize=Path(__file__).resolve(),
|
||||
module_abspath=Path(__file__).resolve(),
|
||||
preexisting_objects=[("other_function", []), ("yet_another_function", []), ("blob", [])],
|
||||
contextual_functions=set(),
|
||||
preexisting_objects=preexisting_objects,
|
||||
project_root_path=Path(__file__).resolve().parent.resolve(),
|
||||
)
|
||||
assert new_main_code == expected_main
|
||||
|
|
@ -414,8 +411,7 @@ print("Not cool")
|
|||
optimized_code=optim_code,
|
||||
file_path_of_module_with_function_to_optimize=Path(__file__).resolve(),
|
||||
module_abspath=Path(__file__).resolve(),
|
||||
preexisting_objects=[],
|
||||
contextual_functions=set(),
|
||||
preexisting_objects=preexisting_objects,
|
||||
project_root_path=Path(__file__).resolve().parent.resolve(),
|
||||
)
|
||||
assert new_helper_code == expected_helper
|
||||
|
|
@ -602,14 +598,8 @@ class CacheConfig(BaseConfig):
|
|||
)
|
||||
"""
|
||||
function_names: list[str] = ["CacheSimilarityEvalConfig.from_config"]
|
||||
parents = [FunctionParent(name="CacheConfig", type="ClassDef")]
|
||||
preexisting_objects: list[tuple[str, list[FunctionParent]]] = [("__init__", parents), ("from_config", parents)]
|
||||
preexisting_objects: list[tuple[str, list[FunctionParent]]] = find_preexisting_objects(original_code)
|
||||
|
||||
contextual_functions: set[tuple[str, str]] = {
|
||||
("CacheSimilarityEvalConfig", "__init__"),
|
||||
("CacheConfig", "__init__"),
|
||||
("CacheInitConfig", "__init__"),
|
||||
}
|
||||
new_code: str = replace_functions_and_add_imports(
|
||||
source_code=original_code,
|
||||
function_names=function_names,
|
||||
|
|
@ -617,7 +607,6 @@ class CacheConfig(BaseConfig):
|
|||
file_path_of_module_with_function_to_optimize=Path(__file__).resolve(),
|
||||
module_abspath=Path(__file__).resolve(),
|
||||
preexisting_objects=preexisting_objects,
|
||||
contextual_functions=contextual_functions,
|
||||
project_root_path=Path(__file__).resolve().parent.resolve(),
|
||||
)
|
||||
assert new_code == expected
|
||||
|
|
@ -681,10 +670,7 @@ def test_test_libcst_code_replacement8() -> None:
|
|||
return np.sum(a != b) / a.size
|
||||
'''
|
||||
function_names: list[str] = ["_EmbeddingDistanceChainMixin._hamming_distance"]
|
||||
preexisting_objects: list[tuple[str, list[FunctionParent]]] = [
|
||||
("_hamming_distance", [FunctionParent("_EmbeddingDistanceChainMixin", "ClassDef")])
|
||||
]
|
||||
contextual_functions: set[tuple[str, str]] = set()
|
||||
preexisting_objects: list[tuple[str, list[FunctionParent]]] = find_preexisting_objects(original_code)
|
||||
new_code: str = replace_functions_and_add_imports(
|
||||
source_code=original_code,
|
||||
function_names=function_names,
|
||||
|
|
@ -692,7 +678,6 @@ def test_test_libcst_code_replacement8() -> None:
|
|||
file_path_of_module_with_function_to_optimize=Path(__file__).resolve(),
|
||||
module_abspath=Path(__file__).resolve(),
|
||||
preexisting_objects=preexisting_objects,
|
||||
contextual_functions=contextual_functions,
|
||||
project_root_path=Path(__file__).resolve().parent.resolve(),
|
||||
)
|
||||
assert new_code == expected
|
||||
|
|
@ -738,10 +723,8 @@ def totally_new_function(value: Optional[str]):
|
|||
|
||||
print("Hello world")
|
||||
"""
|
||||
parents = [FunctionParent(name="NewClass", type="ClassDef")]
|
||||
function_name: str = "NewClass.__init__"
|
||||
preexisting_objects: list[tuple[str, list[FunctionParent]]] = [("__init__", parents), ("__call__", parents)]
|
||||
contextual_functions: set[tuple[str, str]] = {("NewClass", "__init__"), ("NewClass", "__call__")}
|
||||
preexisting_objects: list[tuple[str, list[FunctionParent]]] = find_preexisting_objects(original_code)
|
||||
new_code: str = replace_functions_and_add_imports(
|
||||
source_code=original_code,
|
||||
function_names=[function_name],
|
||||
|
|
@ -749,7 +732,6 @@ print("Hello world")
|
|||
file_path_of_module_with_function_to_optimize=Path(__file__).resolve(),
|
||||
module_abspath=Path(__file__).resolve(),
|
||||
preexisting_objects=preexisting_objects,
|
||||
contextual_functions=contextual_functions,
|
||||
project_root_path=Path(__file__).resolve().parent.resolve(),
|
||||
)
|
||||
assert new_code == expected
|
||||
|
|
@ -847,13 +829,11 @@ def test_code_replacement11() -> None:
|
|||
function_name: str = "Fu.foo"
|
||||
parents = [FunctionParent("Fu", "ClassDef")]
|
||||
preexisting_objects: list[tuple[str, list[FunctionParent]]] = [("foo", parents), ("real_bar", parents)]
|
||||
contextual_functions: set[tuple[str, str]] = set()
|
||||
new_code: str = replace_functions_in_file(
|
||||
source_code=original_code,
|
||||
original_function_names=[function_name],
|
||||
optimized_code=optim_code,
|
||||
preexisting_objects=preexisting_objects,
|
||||
contextual_functions=contextual_functions,
|
||||
)
|
||||
assert new_code == expected_code
|
||||
|
||||
|
|
@ -888,13 +868,11 @@ def test_code_replacement12() -> None:
|
|||
'''
|
||||
|
||||
preexisting_objects: list[tuple[str, list[FunctionParent]]] = []
|
||||
contextual_functions: set[tuple[str, str]] = set()
|
||||
new_code: str = replace_functions_in_file(
|
||||
source_code=original_code,
|
||||
original_function_names=["Fu.real_bar"],
|
||||
optimized_code=optim_code,
|
||||
preexisting_objects=preexisting_objects,
|
||||
contextual_functions=contextual_functions,
|
||||
)
|
||||
assert new_code == expected_code
|
||||
|
||||
|
|
@ -924,9 +902,8 @@ def test_test_libcst_code_replacement13() -> None:
|
|||
return self.name
|
||||
"""
|
||||
|
||||
function_names: list[str] = ["module.yet_another_function", "module.other_function"]
|
||||
function_names: list[str] = ["yet_another_function", "other_function"]
|
||||
preexisting_objects: list[tuple[str, list[FunctionParent]]] = []
|
||||
contextual_functions: set[tuple[str, str]] = set()
|
||||
new_code: str = replace_functions_and_add_imports(
|
||||
source_code=original_code,
|
||||
function_names=function_names,
|
||||
|
|
@ -934,7 +911,6 @@ def test_test_libcst_code_replacement13() -> None:
|
|||
file_path_of_module_with_function_to_optimize=Path(__file__).resolve(),
|
||||
module_abspath=Path(__file__).resolve(),
|
||||
preexisting_objects=preexisting_objects,
|
||||
contextual_functions=contextual_functions,
|
||||
project_root_path=Path(__file__).resolve().parent.resolve(),
|
||||
)
|
||||
assert new_code == original_code
|
||||
|
|
@ -1100,31 +1076,7 @@ class TestResults(BaseModel):
|
|||
report[test_result.test_type][key] += 1
|
||||
return report"""
|
||||
|
||||
preexisting_objects = [
|
||||
("__contains__", [FunctionParent(name="TestResults", type="ClassDef")]),
|
||||
("__len__", [FunctionParent(name="TestResults", type="ClassDef")]),
|
||||
("__bool__", [FunctionParent(name="TestResults", type="ClassDef")]),
|
||||
("__eq__", [FunctionParent(name="TestResults", type="ClassDef")]),
|
||||
("__delitem__", [FunctionParent(name="TestResults", type="ClassDef")]),
|
||||
("__iter__", [FunctionParent(name="TestResults", type="ClassDef")]),
|
||||
("__setitem__", [FunctionParent(name="TestResults", type="ClassDef")]),
|
||||
("__getitem__", [FunctionParent(name="TestResults", type="ClassDef")]),
|
||||
("get_test_pass_fail_report_by_type", [FunctionParent(name="TestResults", type="ClassDef")]),
|
||||
("TestType", []),
|
||||
("TestResults", []),
|
||||
("to_name", [FunctionParent(name="TestType", type="ClassDef")]),
|
||||
]
|
||||
|
||||
contextual_functions = {
|
||||
("TestResults", "__bool__"),
|
||||
("TestResults", "__contains__"),
|
||||
("TestResults", "__delitem__"),
|
||||
("TestResults", "__eq__"),
|
||||
("TestResults", "__getitem__"),
|
||||
("TestResults", "__iter__"),
|
||||
("TestResults", "__len__"),
|
||||
("TestResults", "__setitem__"),
|
||||
}
|
||||
preexisting_objects = find_preexisting_objects(original_code)
|
||||
|
||||
helper_functions = [
|
||||
FakeFunctionSource(
|
||||
|
|
@ -1146,7 +1098,6 @@ class TestResults(BaseModel):
|
|||
file_path_of_module_with_function_to_optimize=Path(__file__).resolve(),
|
||||
module_abspath=Path(__file__).resolve(),
|
||||
preexisting_objects=preexisting_objects,
|
||||
contextual_functions=contextual_functions,
|
||||
project_root_path=Path(__file__).parent.resolve(),
|
||||
)
|
||||
|
||||
|
|
@ -1162,7 +1113,6 @@ class TestResults(BaseModel):
|
|||
file_path_of_module_with_function_to_optimize=Path(__file__).resolve(),
|
||||
module_abspath=module_abspath,
|
||||
preexisting_objects=preexisting_objects,
|
||||
contextual_functions=contextual_functions,
|
||||
project_root_path=Path(__file__).parent.resolve(),
|
||||
)
|
||||
|
||||
|
|
@ -1343,13 +1293,8 @@ def cosine_similarity_top_k(
|
|||
|
||||
return ret_idxs, scores
|
||||
'''
|
||||
preexisting_objects: list[tuple[str, list[FunctionParent]]] = [
|
||||
("cosine_similarity_top_k", []),
|
||||
("Matrix", []),
|
||||
("cosine_similarity", []),
|
||||
]
|
||||
preexisting_objects: list[tuple[str, list[FunctionParent]]] = find_preexisting_objects(original_code)
|
||||
|
||||
contextual_functions: set[tuple[str, str]] = set()
|
||||
helper_functions = [
|
||||
FakeFunctionSource(
|
||||
file_path=(Path(__file__).parent / "code_to_optimize" / "math_utils.py").resolve(),
|
||||
|
|
@ -1376,7 +1321,6 @@ def cosine_similarity_top_k(
|
|||
file_path_of_module_with_function_to_optimize=Path(__file__).resolve(),
|
||||
module_abspath=(Path(__file__).parent / "code_to_optimize").resolve(),
|
||||
preexisting_objects=preexisting_objects,
|
||||
contextual_functions=contextual_functions,
|
||||
project_root_path=Path(__file__).parent.parent.resolve(),
|
||||
)
|
||||
assert (
|
||||
|
|
@ -1436,7 +1380,6 @@ def cosine_similarity_top_k(
|
|||
file_path_of_module_with_function_to_optimize=Path(__file__).resolve(),
|
||||
module_abspath=module_abspath,
|
||||
preexisting_objects=preexisting_objects,
|
||||
contextual_functions=contextual_functions,
|
||||
project_root_path=Path(__file__).parent.parent.resolve(),
|
||||
)
|
||||
|
||||
|
|
@ -1599,3 +1542,105 @@ def functionA():
|
|||
return np.array([1, 2, 3])
|
||||
'''
|
||||
assert is_zero_diff(original_code, optim_code_e)
|
||||
|
||||
|
||||
def test_nested_class() -> None:
|
||||
optim_code = """import libcst as cst
|
||||
from typing import Optional
|
||||
|
||||
class NewClass:
|
||||
def __init__(self, name):
|
||||
self.name = str(name)
|
||||
def __call__(self, value):
|
||||
return self.name
|
||||
def new_function2(value):
|
||||
return cst.ensure_type(value, int)
|
||||
|
||||
class NestedClass:
|
||||
def nested_function(self):
|
||||
return "I am nested and modified"
|
||||
"""
|
||||
|
||||
original_code = """class NewClass:
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
def __call__(self, value):
|
||||
return "I am still old"
|
||||
def new_function2(value):
|
||||
return cst.ensure_type(value, str)
|
||||
|
||||
class NestedClass:
|
||||
def nested_function(self):
|
||||
return "I am nested"
|
||||
|
||||
print("Hello world")
|
||||
"""
|
||||
expected = """import libcst as cst
|
||||
|
||||
class NewClass:
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
def __call__(self, value):
|
||||
return "I am still old"
|
||||
def new_function2(value):
|
||||
return cst.ensure_type(value, int)
|
||||
|
||||
class NestedClass:
|
||||
def nested_function(self):
|
||||
return "I am nested"
|
||||
|
||||
print("Hello world")
|
||||
"""
|
||||
|
||||
function_names: list[str] = [
|
||||
"NewClass.new_function2",
|
||||
"NestedClass.nested_function",
|
||||
] # Nested classes should be ignored, even if provided as target
|
||||
preexisting_objects: list[tuple[str, list[FunctionParent]]] = find_preexisting_objects(original_code)
|
||||
new_code: str = replace_functions_and_add_imports(
|
||||
source_code=original_code,
|
||||
function_names=function_names,
|
||||
optimized_code=optim_code,
|
||||
file_path_of_module_with_function_to_optimize=Path(__file__).resolve(),
|
||||
module_abspath=Path(__file__).resolve(),
|
||||
preexisting_objects=preexisting_objects,
|
||||
project_root_path=Path(__file__).resolve().parent.resolve(),
|
||||
)
|
||||
assert new_code == expected
|
||||
|
||||
|
||||
def test_modify_back_to_original() -> None:
|
||||
optim_code = """class NewClass:
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
def __call__(self, value):
|
||||
return "I am still old"
|
||||
def new_function2(value):
|
||||
return cst.ensure_type(value, str)
|
||||
|
||||
print("Hello world")
|
||||
"""
|
||||
|
||||
original_code = """class NewClass:
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
def __call__(self, value):
|
||||
return "I am still old"
|
||||
def new_function2(value):
|
||||
return cst.ensure_type(value, str)
|
||||
|
||||
print("Hello world")
|
||||
"""
|
||||
|
||||
function_names: list[str] = ["NewClass.__init__", "NewClass.__call__", "NewClass.new_function2"]
|
||||
preexisting_objects: list[tuple[str, list[FunctionParent]]] = find_preexisting_objects(original_code)
|
||||
new_code: str = replace_functions_and_add_imports(
|
||||
source_code=original_code,
|
||||
function_names=function_names,
|
||||
optimized_code=optim_code,
|
||||
file_path_of_module_with_function_to_optimize=Path(__file__).resolve(),
|
||||
module_abspath=Path(__file__).resolve(),
|
||||
preexisting_objects=preexisting_objects,
|
||||
project_root_path=Path(__file__).resolve().parent.resolve(),
|
||||
)
|
||||
assert new_code == original_code
|
||||
|
|
|
|||
|
|
@ -61,8 +61,10 @@ class A:
|
|||
def nested_function(self):
|
||||
def nested():
|
||||
return global_dependency_3(1)
|
||||
|
||||
return nested() + self.add_two(3)
|
||||
|
||||
|
||||
class B:
|
||||
def calculate_something_2(self, num):
|
||||
return num + 1
|
||||
|
|
@ -217,7 +219,6 @@ def test_class_method_dependencies() -> None:
|
|||
code_context.helper_functions[0].fully_qualified_name == "test_function_dependencies.Graph.topologicalSortUtil"
|
||||
)
|
||||
assert code_context.helper_functions[0].qualified_name == "Graph.topologicalSortUtil"
|
||||
assert code_context.contextual_dunder_methods == {("Graph", "__init__")}
|
||||
assert (
|
||||
code_context.code_to_optimize_with_helpers
|
||||
== """from collections import defaultdict
|
||||
|
|
@ -303,8 +304,9 @@ def test_recursive_function_context() -> None:
|
|||
if not is_successful(ctx_result):
|
||||
pytest.fail()
|
||||
code_context = ctx_result.unwrap()
|
||||
assert len(code_context.helper_functions) == 1
|
||||
assert len(code_context.helper_functions) == 2
|
||||
assert code_context.helper_functions[0].fully_qualified_name == "test_function_dependencies.C.calculate_something_3"
|
||||
assert code_context.helper_functions[1].fully_qualified_name == "test_function_dependencies.C.recursive"
|
||||
assert (
|
||||
code_context.code_to_optimize_with_helpers
|
||||
== """class C:
|
||||
|
|
@ -360,6 +362,7 @@ def test_method_in_method_list_comprehension() -> None:
|
|||
assert len(helper_functions) == 1
|
||||
assert helper_functions[0].jedi_definition.full_name == "test_function_dependencies.A.add_two"
|
||||
|
||||
|
||||
def test_nested_method() -> None:
|
||||
file_path = pathlib.Path(__file__).resolve()
|
||||
function_to_optimize = FunctionToOptimize(
|
||||
|
|
|
|||
|
|
@ -237,9 +237,7 @@ class _PersistentCache(Generic[_P, _R, _CacheBackendT]):
|
|||
if not is_successful(ctx_result):
|
||||
pytest.fail()
|
||||
code_context = ctx_result.unwrap()
|
||||
assert code_context.helper_functions[0].qualified_name == "_R"
|
||||
assert code_context.helper_functions[1].qualified_name == "AbstractCacheBackend.get_cache_or_call"
|
||||
assert len(code_context.contextual_dunder_methods) == 2
|
||||
assert code_context.helper_functions[0].qualified_name == "AbstractCacheBackend.get_cache_or_call"
|
||||
|
||||
assert (
|
||||
code_context.code_to_optimize_with_helpers
|
||||
|
|
|
|||
Loading…
Reference in a new issue