mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
draft PR for init caching. no instrumentation checks implemented yet
This commit is contained in:
parent
d8ac58c5bb
commit
8de9cebe90
10 changed files with 814 additions and 182 deletions
|
|
@ -9,6 +9,7 @@ import libcst as cst
|
|||
|
||||
from codeflash.cli_cmds.console import logger
|
||||
from codeflash.code_utils.code_extractor import add_needed_imports_from_module
|
||||
from codeflash.code_utils.code_utils import cst_to_code, get_only_code_content
|
||||
from codeflash.models.models import FunctionParent
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
|
@ -51,10 +52,13 @@ class OptimFunctionCollector(cst.CSTVisitor):
|
|||
self.new_functions: list[cst.FunctionDef] = []
|
||||
self.new_class_functions: dict[str, list[cst.FunctionDef]] = defaultdict(list)
|
||||
self.current_class = None
|
||||
self.modified_init_functions: dict[str, cst.FunctionDef] = {}
|
||||
|
||||
def visit_FunctionDef(self, node: cst.FunctionDef) -> bool:
|
||||
if (self.current_class, node.name.value) in self.function_names:
|
||||
self.modified_functions[(self.current_class, node.name.value)] = node
|
||||
elif self.current_class and node.name.value == "__init__":
|
||||
self.modified_init_functions[self.current_class] = node
|
||||
elif (
|
||||
self.preexisting_objects
|
||||
and (node.name.value, []) not in self.preexisting_objects
|
||||
|
|
@ -76,6 +80,7 @@ class OptimFunctionCollector(cst.CSTVisitor):
|
|||
and (child_node.name.value, parents) not in self.preexisting_objects
|
||||
):
|
||||
self.new_class_functions[node.name.value].append(child_node)
|
||||
|
||||
return True
|
||||
|
||||
def leave_ClassDef(self, node: cst.ClassDef) -> None:
|
||||
|
|
@ -89,11 +94,15 @@ class OptimFunctionReplacer(cst.CSTTransformer):
|
|||
modified_functions: dict[tuple[str | None, str], cst.FunctionDef] = None,
|
||||
new_functions: list[cst.FunctionDef] = None,
|
||||
new_class_functions: dict[str, list[cst.FunctionDef]] = None,
|
||||
modified_init_functions: dict[str, cst.FunctionDef] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.modified_functions = modified_functions if modified_functions is not None else {}
|
||||
self.new_functions = new_functions if new_functions is not None else []
|
||||
self.new_class_functions = new_class_functions if new_class_functions is not None else defaultdict(list)
|
||||
self.modified_init_functions: dict[str, cst.FunctionDef] = (
|
||||
modified_init_functions if modified_init_functions is not None else {}
|
||||
)
|
||||
self.current_class = None
|
||||
|
||||
def visit_FunctionDef(self, node: cst.FunctionDef) -> bool:
|
||||
|
|
@ -102,7 +111,15 @@ class OptimFunctionReplacer(cst.CSTTransformer):
|
|||
def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef:
    """Substitute the optimized version of a function when leaving its definition.

    Two cases are handled:

    * The function is one of the explicitly modified functions: its body and
      decorators are replaced with the optimized ones.
    * The function is an ``__init__`` of a class that has a modified
      ``__init__``: the two are merged via ``merge_init_functions``.

    In both cases, when the code content (as reported by
    ``get_only_code_content``) of the original and the replacement is equal,
    the original node is returned untouched.

    Args:
        original_node: The function definition as it appears in the source tree.
        updated_node: The same definition after its children were transformed.

    Returns:
        The node that should take the place of the function definition.
    """
    function_name = original_node.name.value
    lookup_key = (self.current_class, function_name)

    if lookup_key in self.modified_functions:
        replacement = self.modified_functions[lookup_key]
        original_content = get_only_code_content(cst_to_code(original_node))
        if original_content == get_only_code_content(cst_to_code(replacement)):
            return original_node  # Code was unchanged, so don't modify docstrings / comments
        return updated_node.with_changes(body=replacement.body, decorators=replacement.decorators)

    # Anything that is not an __init__ of a class with a modified __init__ passes through.
    if function_name != "__init__" or self.current_class not in self.modified_init_functions:
        return updated_node

    replacement_init = self.modified_init_functions[self.current_class]
    original_content = get_only_code_content(cst_to_code(original_node))
    if original_content == get_only_code_content(cst_to_code(replacement_init)):
        return original_node  # Code was unchanged, so don't modify docstrings / comments
    return merge_init_functions(updated_node, replacement_init)
|
||||
|
||||
|
|
@ -145,6 +162,97 @@ class OptimFunctionReplacer(cst.CSTTransformer):
|
|||
return node
|
||||
|
||||
|
||||
class AttributeCollector(cst.CSTVisitor):
    """Visitor that gathers the names of attributes accessed directly on ``self``."""

    def __init__(self):
        super().__init__()
        # Names seen in the form `self.<name>` anywhere in the visited tree.
        self.attributes: set[str] = set()

    def visit_Attribute(self, node: cst.Attribute) -> bool:
        """Remember the attribute name when the access is of the form ``self.<name>``.

        Always returns True so that nested attribute nodes are visited as well.
        """
        receiver = node.value
        if isinstance(receiver, cst.Name) and receiver.value == "self":
            self.attributes.add(node.attr.value)
        return True
|
||||
|
||||
|
||||
class AssignmentCollector(cst.CSTVisitor):
    """Visitor that gathers the names of ``self`` attributes that are assigned to.

    Plain (``self.x = ...``), annotated (``self.x: T = ...``) and augmented
    (``self.x += ...``) assignments are recognised. Only targets that are
    directly an attribute of the name ``self`` are recorded.
    """

    def __init__(self):
        super().__init__()
        # Attribute names that appeared as an assignment target on `self`.
        self.assigned_attrs: set[str] = set()

    def _record_self_target(self, target: cst.BaseExpression) -> None:
        """Add ``target``'s attribute name if it has the shape ``self.<name>``."""
        if not isinstance(target, cst.Attribute):
            return
        owner = target.value
        if isinstance(owner, cst.Name) and owner.value == "self":
            self.assigned_attrs.add(target.attr.value)

    def visit_Assign(self, node: cst.Assign) -> bool:
        """Check regular assignments like self.x = ..."""
        # A single statement may carry several targets (a = b = value).
        for assign_target in node.targets:
            self._record_self_target(assign_target.target)
        return True

    def visit_AnnAssign(self, node: cst.AnnAssign) -> bool:
        """Check annotated assignments like self.x: str = ..."""
        self._record_self_target(node.target)
        return True

    def visit_AugAssign(self, node: cst.AugAssign) -> bool:
        """Check augmented assignments like self.x += ..."""
        self._record_self_target(node.target)
        return True
|
||||
|
||||
|
||||
def merge_init_functions(original_init: cst.FunctionDef, new_init: cst.FunctionDef) -> cst.FunctionDef:
    """Append the compatible statements of ``new_init`` to ``original_init``.

    Every ``self.<attr>`` mentioned anywhere in the original ``__init__`` is
    treated as already owned by it. A statement from the new ``__init__`` is
    appended only when it is not textually identical to a top-level statement
    of the original body and does not assign to any of those attributes
    (reading them is fine).

    Args:
        original_init: The original __init__ function, kept in full.
        new_init: The new __init__ function whose body is filtered and appended.

    Returns:
        A FunctionDef based on ``original_init`` with the surviving statements
        of ``new_init`` added at the end of its body.

    """
    # Attributes the original __init__ already touches.
    attribute_visitor = AttributeCollector()
    original_init.visit(attribute_visitor)
    owned_attrs = attribute_visitor.attributes

    # Source text of each original statement, used to drop exact duplicates.
    # This should just be in terms of code, not comments?
    original_sources = {cst.Module([statement]).code for statement in original_init.body.body}

    appended = []
    for candidate in new_init.body.body:
        if cst.Module([candidate]).code in original_sources:
            continue

        assignment_visitor = AssignmentCollector()
        candidate.visit(assignment_visitor)

        # Skip anything that would overwrite an attribute the original already handles.
        if assignment_visitor.assigned_attrs.isdisjoint(owned_attrs):
            appended.append(candidate)

    merged_body = original_init.body.with_changes(body=original_init.body.body + tuple(appended))
    return original_init.with_changes(body=merged_body)
|
||||
|
||||
|
||||
def replace_functions_in_file(
|
||||
source_code: str,
|
||||
original_function_names: list[str],
|
||||
|
|
@ -173,6 +281,7 @@ def replace_functions_in_file(
|
|||
modified_functions=visitor.modified_functions,
|
||||
new_functions=visitor.new_functions,
|
||||
new_class_functions=visitor.new_class_functions,
|
||||
modified_init_functions=visitor.modified_init_functions,
|
||||
)
|
||||
original_module = cst.parse_module(source_code)
|
||||
modified_tree = original_module.visit(transformer)
|
||||
|
|
@ -191,7 +300,7 @@ def replace_functions_and_add_imports(
|
|||
return add_needed_imports_from_module(
|
||||
optimized_code,
|
||||
replace_functions_in_file(source_code, function_names, optimized_code, preexisting_objects),
|
||||
file_path_of_module_with_function_to_optimize,
|
||||
module_abspath,
|
||||
module_abspath,
|
||||
project_root_path,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -7,9 +7,37 @@ from functools import lru_cache
|
|||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
|
||||
import libcst as cst
|
||||
|
||||
from codeflash.cli_cmds.console import logger
|
||||
|
||||
|
||||
def cst_to_code(node: cst.CSTNode) -> str:
    """Render a single CST node as source text, with surrounding whitespace stripped."""
    # Wrapping the node in a throwaway Module gives access to code generation.
    wrapper = cst.Module([node])
    return wrapper.code.strip()
|
||||
|
||||
|
||||
def get_only_code_content(code: str) -> str:
    """Extract just the code content from code, ignoring comments and function docstrings.

    The code is parsed to an AST (which discards comments), the leading
    docstring of every function is dropped, and the tree is unparsed again.
    Two snippets that differ only in comments or function docstrings therefore
    yield the same string.

    Args:
        code: Source code as a string

    Returns:
        String of code with comments and function docstrings removed

    Raises:
        SyntaxError: If ``code`` cannot be parsed.

    """
    # Parse into AST - this automatically strips comments
    tree = ast.parse(code)

    for node in ast.walk(tree):
        # Cover both plain and async functions.
        if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) or not node.body:
            continue
        first_stmt = node.body[0]
        # Detect the docstring structurally rather than via the truthiness of
        # ast.get_docstring(), which is falsy for an empty-string docstring.
        if (
            isinstance(first_stmt, ast.Expr)
            and isinstance(first_stmt.value, ast.Constant)
            and isinstance(first_stmt.value.value, str)
        ):
            node.body = node.body[1:]

    # Unparse back to source code for comparison
    return ast.unparse(tree)
|
||||
|
||||
|
||||
def get_qualified_name(module_name: str, full_qualified_name: str) -> str:
|
||||
if not full_qualified_name:
|
||||
raise ValueError("full_qualified_name cannot be empty")
|
||||
|
|
@ -80,6 +108,7 @@ def get_all_function_names(code: str) -> tuple[bool, list[str]]:
|
|||
def get_run_tmp_file(file_path: Path) -> Path:
    """Return ``file_path`` located inside this run's codeflash temporary directory.

    The temporary directory is created on the first call, stored as an
    attribute on this function, and reused by every later call.
    """
    directory_exists = hasattr(get_run_tmp_file, "tmpdir")
    if not directory_exists:
        get_run_tmp_file.tmpdir = TemporaryDirectory(prefix="codeflash_")
        logger.info(f"Created new temp directory for codeflash: {Path(get_run_tmp_file.tmpdir.name)!s}")
    tmp_root = Path(get_run_tmp_file.tmpdir.name)
    return tmp_root / file_path
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -26,9 +26,14 @@ def get_code_optimization_context(
|
|||
helpers_of_fto, helpers_of_fto_fqn, helpers_of_fto_obj_list = get_file_path_to_helper_functions_dict(
|
||||
{function_to_optimize.file_path: {function_to_optimize.qualified_name}}, project_root_path
|
||||
)
|
||||
|
||||
helpers_of_helpers, helpers_of_helpers_fqn, _ = get_file_path_to_helper_functions_dict(
|
||||
helpers_of_fto, project_root_path
|
||||
)
|
||||
print("helpers_of_fto")
|
||||
print(helpers_of_fto)
|
||||
print("helpers_of_helpers")
|
||||
print(helpers_of_helpers)
|
||||
# Add function to optimize
|
||||
helpers_of_fto[function_to_optimize.file_path].add(function_to_optimize.qualified_name)
|
||||
helpers_of_fto_fqn[function_to_optimize.file_path].add(
|
||||
|
|
@ -45,13 +50,15 @@ def get_code_optimization_context(
|
|||
project_root_path,
|
||||
remove_docstrings=False,
|
||||
)
|
||||
testgen_context_code = get_all_testgen_context(helpers_of_fto, helpers_of_fto_fqn, project_root_path)
|
||||
|
||||
# Handle token limits
|
||||
tokenizer = tiktoken.encoding_for_model("gpt-4o")
|
||||
final_read_writable_tokens = len(tokenizer.encode(final_read_writable_code))
|
||||
if final_read_writable_tokens > token_limit:
|
||||
raise ValueError("Read-writable code has exceeded token limit, cannot proceed")
|
||||
|
||||
# if len(tokenizer.encode(testgen_context_code.code)) > token_limit:
|
||||
# raise ValueError("Testgen context has exceeded token limit, cannot proceed")
|
||||
# Setup preexisting objects for code replacer TODO: should remove duplicates
|
||||
preexisting_objects = list(
|
||||
chain(
|
||||
|
|
@ -68,6 +75,7 @@ def get_code_optimization_context(
|
|||
read_only_context_code=read_only_code_markdown.markdown,
|
||||
helper_functions=helpers_of_fto_obj_list,
|
||||
preexisting_objects=preexisting_objects,
|
||||
testgen_context_code=testgen_context_code.code,
|
||||
)
|
||||
|
||||
logger.debug("Code context has exceeded token limit, removing docstrings from read-only code")
|
||||
|
|
@ -90,6 +98,7 @@ def get_code_optimization_context(
|
|||
read_only_context_code=read_only_code_no_docstring_markdown.markdown,
|
||||
helper_functions=helpers_of_fto_obj_list,
|
||||
preexisting_objects=preexisting_objects,
|
||||
testgen_context_code=testgen_context_code.code,
|
||||
)
|
||||
|
||||
logger.debug("Code context has exceeded token limit, removing read-only code")
|
||||
|
|
@ -99,9 +108,40 @@ def get_code_optimization_context(
|
|||
read_only_context_code="",
|
||||
helper_functions=helpers_of_fto_obj_list,
|
||||
preexisting_objects=preexisting_objects,
|
||||
testgen_context_code=testgen_context_code.code,
|
||||
)
|
||||
|
||||
|
||||
def get_all_testgen_context(
    helpers_of_fto: dict[Path, set[str]], helpers_of_fto_fqn: dict[Path, set[str]], project_root_path: Path
) -> CodeString:
    """Build the test-generation context for the function to optimize and its helpers.

    For each file in ``helpers_of_fto`` the testgen context is extracted with
    ``get_testgen_context``, appended to the accumulated context, and the
    imports it needs are then added from the originating module. Files that
    cannot be read, or in which no target function is found, are skipped.

    Args:
        helpers_of_fto: Maps a file path to the qualified function names to keep from it.
        helpers_of_fto_fqn: Maps a file path to the fully qualified helper names,
            passed on to ``add_needed_imports_from_module``.
        project_root_path: Project root, passed on to ``add_needed_imports_from_module``.

    Returns:
        A CodeString holding the combined testgen context (empty code if nothing was extracted).
    """
    final_testgen_context = ""
    # Extract code from file paths that contain fto and first degree helpers. helpers of helpers may also be included if they are in the same files
    for file_path, qualified_function_names in helpers_of_fto.items():
        try:
            original_code = file_path.read_text("utf8")
        except Exception as e:
            logger.exception(f"Error while parsing {file_path}: {e}")
            continue
        try:
            testgen_context_code = get_testgen_context(original_code, qualified_function_names)
        except ValueError as e:
            # Raised when none of the target functions are present in this file.
            logger.debug(f"Error while getting testgen context code: {e}")
            continue

        if testgen_context_code:
            final_testgen_context += f"\n{testgen_context_code}"
            # Imports are resolved against the file that the latest chunk came from.
            final_testgen_context = add_needed_imports_from_module(
                src_module_code=original_code,
                dst_module_code=final_testgen_context,
                src_path=file_path,
                dst_path=file_path,
                project_root=project_root_path,
                helper_functions_fqn=helpers_of_fto_fqn[file_path],
            )
    return CodeString(code=final_testgen_context)
|
||||
|
||||
|
||||
def get_all_read_writable_code(
|
||||
helpers_of_fto: dict[Path, set[str]], helpers_of_fto_fqn: dict[Path, set[str]], project_root_path: Path
|
||||
) -> CodeString:
|
||||
|
|
@ -167,19 +207,18 @@ def get_all_read_only_code_context(
|
|||
except ValueError as e:
|
||||
logger.debug(f"Error while getting read-only code: {e}")
|
||||
continue
|
||||
|
||||
read_only_code_with_imports = CodeString(
|
||||
code=add_needed_imports_from_module(
|
||||
src_module_code=original_code,
|
||||
dst_module_code=read_only_code,
|
||||
src_path=file_path,
|
||||
dst_path=file_path,
|
||||
project_root=project_root_path,
|
||||
helper_functions_fqn=helpers_of_fto_fqn[file_path] | helpers_of_helpers_fqn[file_path],
|
||||
),
|
||||
file_path=file_path.relative_to(project_root_path),
|
||||
)
|
||||
if read_only_code_with_imports.code:
|
||||
if read_only_code.strip():
|
||||
read_only_code_with_imports = CodeString(
|
||||
code=add_needed_imports_from_module(
|
||||
src_module_code=original_code,
|
||||
dst_module_code=read_only_code,
|
||||
src_path=file_path,
|
||||
dst_path=file_path,
|
||||
project_root=project_root_path,
|
||||
helper_functions_fqn=helpers_of_fto_fqn[file_path] | helpers_of_helpers_fqn[file_path],
|
||||
),
|
||||
file_path=file_path.relative_to(project_root_path),
|
||||
)
|
||||
read_only_code_markdown.code_strings.append(read_only_code_with_imports)
|
||||
|
||||
# Extract code from file paths containing helpers of helpers
|
||||
|
|
@ -197,18 +236,18 @@ def get_all_read_only_code_context(
|
|||
logger.debug(f"Error while getting read-only code: {e}")
|
||||
continue
|
||||
|
||||
read_only_code_with_imports = CodeString(
|
||||
code=add_needed_imports_from_module(
|
||||
src_module_code=original_code,
|
||||
dst_module_code=read_only_code,
|
||||
src_path=file_path,
|
||||
dst_path=file_path,
|
||||
project_root=project_root_path,
|
||||
helper_functions_fqn=helpers_of_helpers_no_overlap_fqn[file_path],
|
||||
),
|
||||
file_path=file_path.relative_to(project_root_path),
|
||||
)
|
||||
if read_only_code_with_imports.code:
|
||||
if read_only_code.strip():
|
||||
read_only_code_with_imports = CodeString(
|
||||
code=add_needed_imports_from_module(
|
||||
src_module_code=original_code,
|
||||
dst_module_code=read_only_code,
|
||||
src_path=file_path,
|
||||
dst_path=file_path,
|
||||
project_root=project_root_path,
|
||||
helper_functions_fqn=helpers_of_helpers_no_overlap_fqn[file_path],
|
||||
),
|
||||
file_path=file_path.relative_to(project_root_path),
|
||||
)
|
||||
read_only_code_markdown.code_strings.append(read_only_code_with_imports)
|
||||
return read_only_code_markdown
|
||||
|
||||
|
|
@ -327,9 +366,10 @@ def prune_cst_for_read_writable_code(
|
|||
if qualified_name in target_functions:
|
||||
new_body.append(stmt)
|
||||
found_target = True
|
||||
|
||||
elif stmt.name.value == "__init__":
|
||||
new_body.append(stmt) # enable __init__ optimizations
|
||||
# If no target functions found, remove the class entirely
|
||||
if not new_body:
|
||||
if not new_body or not found_target:
|
||||
return None, False
|
||||
|
||||
return node.with_changes(body=cst.IndentedBlock(body=new_body)), found_target
|
||||
|
|
@ -408,7 +448,7 @@ def prune_cst_for_read_only_code(
|
|||
if qualified_name in target_functions:
|
||||
return None, True
|
||||
# Keep only dunder methods
|
||||
if is_dunder_method(node.name.value):
|
||||
if is_dunder_method(node.name.value) and node.name.value != "__init__":
|
||||
if remove_docstrings and isinstance(node.body, cst.IndentedBlock):
|
||||
new_body = remove_docstring_from_body(node.body)
|
||||
return node.with_changes(body=new_body), False
|
||||
|
|
@ -501,3 +541,109 @@ def get_read_only_code(
|
|||
if filtered_node and isinstance(filtered_node, cst.Module):
|
||||
return str(filtered_node.code)
|
||||
return ""
|
||||
|
||||
|
||||
def prune_cst_for_testgen_context(
    node: cst.CSTNode, target_functions: set[str], prefix: str = "", remove_docstrings: bool = False
) -> tuple[cst.CSTNode | None, bool]:
    """Recursively filter the node for the testgen context.

    Filtering rules, as implemented below:
      * Import statements are dropped.
      * A function whose qualified name is in ``target_functions`` is kept as-is.
      * Any other function is kept only if it is a dunder method (optionally
        with its docstring removed); everything else is dropped.
      * Nested classes are dropped. A top-level class is kept, with a filtered
        body, only if it contains a target function.
      * Any other node is kept, with its child sections filtered recursively.

    Args:
        node: The CST node to filter.
        target_functions: Qualified names of the functions to keep.
        prefix: Qualified name of the enclosing class, or "" at module level.
        remove_docstrings: Whether to strip docstrings from kept dunder methods and classes.

    Returns:
        (filtered_node, found_target):
          filtered_node: The modified CST node or None if it should be removed.
          found_target: True if a target function was found in this node's subtree.

    Raises:
        ValueError: If a class body is not an IndentedBlock.

    """
    if isinstance(node, (cst.Import, cst.ImportFrom)):
        return None, False

    if isinstance(node, cst.FunctionDef):
        qualified_name = f"{prefix}.{node.name.value}" if prefix else node.name.value
        # Unlike other functions, target functions are kept in full.
        if qualified_name in target_functions:
            return node, True
        # Keep only dunder methods
        if is_dunder_method(node.name.value):
            if remove_docstrings and isinstance(node.body, cst.IndentedBlock):
                new_body = remove_docstring_from_body(node.body)
                return node.with_changes(body=new_body), False
            return node, False
        return None, False

    if isinstance(node, cst.ClassDef):
        # Do not recurse into nested classes
        if prefix:
            return None, False
        # Assuming always an IndentedBlock
        if not isinstance(node.body, cst.IndentedBlock):
            raise ValueError("ClassDef body is not an IndentedBlock")

        class_prefix = f"{prefix}.{node.name.value}" if prefix else node.name.value

        # First pass: detect if there is a target function in the class
        found_in_class = False
        new_class_body: list[CSTNode] = []
        for stmt in node.body.body:
            filtered, found_target = prune_cst_for_testgen_context(
                stmt, target_functions, class_prefix, remove_docstrings=remove_docstrings
            )
            found_in_class |= found_target
            if filtered:
                new_class_body.append(filtered)

        # A class without any target function is dropped entirely.
        if not found_in_class:
            return None, False

        if remove_docstrings:
            return node.with_changes(
                body=remove_docstring_from_body(node.body.with_changes(body=new_class_body))
            ) if new_class_body else None, True
        return node.with_changes(body=node.body.with_changes(body=new_class_body)) if new_class_body else None, True

    # For other nodes, keep the node and recursively filter children
    section_names = get_section_names(node)
    if not section_names:
        return node, False

    # Maps a section (attribute) name to its filtered replacement.
    updates: dict[str, list[cst.CSTNode] | cst.CSTNode] = {}
    found_any_target = False

    for section in section_names:
        original_content = getattr(node, section, None)
        if isinstance(original_content, (list, tuple)):
            new_children = []
            section_found_target = False
            for child in original_content:
                filtered, found_target = prune_cst_for_testgen_context(
                    child, target_functions, prefix, remove_docstrings=remove_docstrings
                )
                if filtered:
                    new_children.append(filtered)
                section_found_target |= found_target

            if section_found_target or new_children:
                found_any_target |= section_found_target
                updates[section] = new_children
        elif original_content is not None:
            filtered, found_target = prune_cst_for_testgen_context(
                original_content, target_functions, prefix, remove_docstrings=remove_docstrings
            )
            found_any_target |= found_target
            if filtered:
                updates[section] = filtered
    if updates:
        return (node.with_changes(**updates), found_any_target)

    return None, False
|
||||
|
||||
|
||||
def get_testgen_context(code: str, target_functions: set[str], remove_docstrings: bool = False) -> str:
    """Return the testgen context for ``code``.

    Works like get_read_only_code, except that the target functions themselves
    are part of the result.

    Args:
        code: Source code of a module.
        target_functions: Qualified names of the functions to keep.
        remove_docstrings: Forwarded to ``prune_cst_for_testgen_context``.

    Returns:
        The filtered module source, or "" when the filtered result is not a module.

    Raises:
        ValueError: If none of the target functions are found in ``code``.
    """
    parsed_module = cst.parse_module(code)
    pruned, has_target = prune_cst_for_testgen_context(
        parsed_module, target_functions, remove_docstrings=remove_docstrings
    )
    if not has_target:
        raise ValueError("No target functions found in the provided code")
    if not pruned or not isinstance(pruned, cst.Module):
        return ""
    return str(pruned.code)
|
||||
|
|
|
|||
|
|
@ -71,6 +71,78 @@ from codeflash.code_utils.code_utils import path_belongs_to_site_packages
|
|||
from codeflash.discovery.functions_to_optimize import FunctionParent, FunctionToOptimize
|
||||
|
||||
|
||||
def belongs_to_class(name: Name, class_name: str) -> bool:
|
||||
"""Check if the given name belongs to the specified class."""
|
||||
if name.full_name and name.full_name.startswith(f"{name.module_name}.{class_name}."):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def belongs_to_function(name: Name, function_name: str) -> bool:
|
||||
"""Check if the given name belongs to the specified function"""
|
||||
if name.full_name and name.full_name.startswith(name.module_name):
|
||||
subname: str = name.full_name.replace(name.module_name, "", 1)
|
||||
else:
|
||||
return False
|
||||
# The name is defined inside the function or is the function itself
|
||||
return f".{function_name}." in subname or f".{function_name}" == subname
|
||||
|
||||
|
||||
@dataclass(frozen=True, config={"arbitrary_types_allowed": True})
|
||||
class Source:
|
||||
full_name: str
|
||||
definition: Name
|
||||
source_code: str
|
||||
'''
|
||||
|
||||
dst_module = '''def belongs_to_function(name: Name, function_name: str) -> bool:
|
||||
"""Check if the given name belongs to the specified function"""
|
||||
if name.full_name and name.full_name.startswith(name.module_name):
|
||||
subname: str = name.full_name.replace(name.module_name, "", 1)
|
||||
else:
|
||||
return False
|
||||
# The name is defined inside the function or is the function itself
|
||||
return f".{function_name}." in subname or f".{function_name}" == subname
|
||||
'''
|
||||
|
||||
expected = '''from jedi.api.classes import Name
|
||||
|
||||
def belongs_to_function(name: Name, function_name: str) -> bool:
|
||||
"""Check if the given name belongs to the specified function"""
|
||||
if name.full_name and name.full_name.startswith(name.module_name):
|
||||
subname: str = name.full_name.replace(name.module_name, "", 1)
|
||||
else:
|
||||
return False
|
||||
# The name is defined inside the function or is the function itself
|
||||
return f".{function_name}." in subname or f".{function_name}" == subname
|
||||
'''
|
||||
src_path = Path("/home/roger/repos/codeflash/cli/codeflash/optimization/function_context.py")
|
||||
dst_path = Path("/home/roger/repos/codeflash/cli/codeflash/optimization/function_context.py")
|
||||
project_root = Path("/home/roger/repos/codeflash")
|
||||
new_module = add_needed_imports_from_module(src_module, dst_module, src_path, dst_path, project_root)
|
||||
assert new_module == expected
|
||||
|
||||
|
||||
def test_add_needed_imports_from_module_with_if() -> None:
|
||||
src_module = '''import ast
|
||||
import logging
|
||||
import os
|
||||
from typing import Union
|
||||
|
||||
if sys.version_info >= (3, 10):
|
||||
from typing import TypeGuard
|
||||
else:
|
||||
from typing_extensions import TypeGuard
|
||||
import jedi
|
||||
import tiktoken
|
||||
from jedi.api.classes import Name
|
||||
from pydantic.dataclasses import dataclass
|
||||
|
||||
from codeflash.code_utils.code_extractor import get_code, get_code_no_skeleton
|
||||
from codeflash.code_utils.code_utils import path_belongs_to_site_packages
|
||||
from codeflash.discovery.functions_to_optimize import FunctionParent, FunctionToOptimize
|
||||
|
||||
|
||||
def belongs_to_class(name: Name, class_name: str) -> bool:
|
||||
"""Check if the given name belongs to the specified class."""
|
||||
if name.full_name and name.full_name.startswith(f"{name.module_name}.{class_name}."):
|
||||
|
|
|
|||
|
|
@ -73,38 +73,53 @@ def test_code_replacement10() -> None:
|
|||
)
|
||||
|
||||
code_ctx = get_code_optimization_context(function_to_optimize=func_top_optimize, project_root_path=file_path.parent)
|
||||
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
|
||||
read_write_context, read_only_context, testgen_context = (
|
||||
code_ctx.read_writable_code,
|
||||
code_ctx.read_only_context_code,
|
||||
code_ctx.testgen_context_code,
|
||||
)
|
||||
|
||||
expected_read_write_context = """
|
||||
from __future__ import annotations
|
||||
|
||||
class HelperClass:
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
|
||||
def helper_method(self):
|
||||
return self.name
|
||||
|
||||
|
||||
class MainClass:
|
||||
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
|
||||
def main_method(self):
|
||||
return HelperClass(self.name).helper_method()
|
||||
"""
|
||||
expected_read_only_context = f"""
|
||||
```python:{file_path.relative_to(file_path.parent)}
|
||||
expected_read_only_context = """
|
||||
"""
|
||||
expected_testgen_context = """
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
class HelperClass:
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
|
||||
def helper_method(self):
|
||||
return self.name
|
||||
|
||||
|
||||
class MainClass:
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
```
|
||||
|
||||
def main_method(self):
|
||||
return HelperClass(self.name).helper_method()
|
||||
"""
|
||||
assert read_write_context.strip() == dedent(expected_read_write_context).strip()
|
||||
assert read_only_context.strip() == dedent(expected_read_only_context).strip()
|
||||
assert testgen_context.strip() == dedent(expected_testgen_context).strip()
|
||||
|
||||
|
||||
def test_class_method_dependencies() -> None:
|
||||
|
|
@ -122,8 +137,12 @@ def test_class_method_dependencies() -> None:
|
|||
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
|
||||
expected_read_write_context = """
|
||||
from __future__ import annotations
|
||||
from collections import defaultdict
|
||||
|
||||
class Graph:
|
||||
def __init__(self, vertices):
|
||||
self.graph = defaultdict(list)
|
||||
self.V = vertices # No. of vertices
|
||||
|
||||
def topologicalSortUtil(self, v, visited, stack):
|
||||
visited[v] = True
|
||||
|
|
@ -146,17 +165,7 @@ class Graph:
|
|||
return stack
|
||||
|
||||
"""
|
||||
expected_read_only_context = f"""
|
||||
```python:{file_path.relative_to(file_path.parent)}
|
||||
from __future__ import annotations
|
||||
from collections import defaultdict
|
||||
|
||||
class Graph:
|
||||
def __init__(self, vertices):
|
||||
self.graph = defaultdict(list)
|
||||
self.V = vertices # No. of vertices
|
||||
```
|
||||
"""
|
||||
expected_read_only_context = ""
|
||||
assert read_write_context.strip() == dedent(expected_read_write_context).strip()
|
||||
assert read_only_context.strip() == dedent(expected_read_only_context).strip()
|
||||
|
||||
|
|
@ -398,6 +407,8 @@ class _PersistentCache(Generic[_P, _R, _CacheBackendT]):
|
|||
expected_read_write_context = """
|
||||
class AbstractCacheBackend(CacheBackend, Protocol[_KEY_T, _STORE_T]):
|
||||
|
||||
def __init__(self) -> None: ...
|
||||
|
||||
def get_cache_or_call(
|
||||
self,
|
||||
*,
|
||||
|
|
@ -455,6 +466,16 @@ class _PersistentCache(Generic[_P, _R, _CacheBackendT]):
|
|||
|
||||
class _PersistentCache(Generic[_P, _R, _CacheBackendT]):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
func: Callable[_P, _R],
|
||||
duration: datetime.timedelta,
|
||||
) -> None:
|
||||
self.__wrapped__ = func
|
||||
self.__duration__ = duration
|
||||
self.__backend__ = AbstractCacheBackend()
|
||||
functools.update_wrapper(self, func)
|
||||
|
||||
def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R:
|
||||
\"\"\"
|
||||
Calls the wrapped function, either using the cache or bypassing it based on environment variables.
|
||||
|
|
@ -487,8 +508,6 @@ _STORE_T = TypeVar("_STORE_T")
|
|||
class AbstractCacheBackend(CacheBackend, Protocol[_KEY_T, _STORE_T]):
|
||||
"""Interface for cache backends used by the persistent cache decorator."""
|
||||
|
||||
def __init__(self) -> None: ...
|
||||
|
||||
def hash_key(
|
||||
self,
|
||||
*,
|
||||
|
|
@ -535,16 +554,6 @@ class _PersistentCache(Generic[_P, _R, _CacheBackendT]):
|
|||
__wrapped__: Callable[_P, _R]
|
||||
__duration__: datetime.timedelta
|
||||
__backend__: _CacheBackendT
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
func: Callable[_P, _R],
|
||||
duration: datetime.timedelta,
|
||||
) -> None:
|
||||
self.__wrapped__ = func
|
||||
self.__duration__ = duration
|
||||
self.__backend__ = AbstractCacheBackend()
|
||||
functools.update_wrapper(self, func)
|
||||
```
|
||||
'''
|
||||
assert read_write_context.strip() == dedent(expected_read_write_context).strip()
|
||||
|
|
@ -598,10 +607,15 @@ class HelperClass:
|
|||
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
|
||||
expected_read_write_context = """
|
||||
class MyClass:
|
||||
def __init__(self):
|
||||
self.x = 1
|
||||
def target_method(self):
|
||||
y = HelperClass().helper_method()
|
||||
|
||||
class HelperClass:
|
||||
def __init__(self):
|
||||
\"\"\"Initialize the HelperClass.\"\"\"
|
||||
self.x = 1
|
||||
def helper_method(self):
|
||||
return self.x
|
||||
"""
|
||||
|
|
@ -609,14 +623,9 @@ class HelperClass:
|
|||
```python:{file_path.relative_to(opt.args.project_root)}
|
||||
class MyClass:
|
||||
\"\"\"A class with a helper method.\"\"\"
|
||||
def __init__(self):
|
||||
self.x = 1
|
||||
|
||||
class HelperClass:
|
||||
\"\"\"A helper class for MyClass.\"\"\"
|
||||
def __init__(self):
|
||||
\"\"\"Initialize the HelperClass.\"\"\"
|
||||
self.x = 1
|
||||
def __repr__(self):
|
||||
\"\"\"Return a string representation of the HelperClass.\"\"\"
|
||||
return "HelperClass" + str(self.x)
|
||||
|
|
@ -679,23 +688,25 @@ class HelperClass:
|
|||
# In this scenario, the read-only code context is too long, so the read-only docstrings are removed.
|
||||
expected_read_write_context = """
|
||||
class MyClass:
|
||||
def __init__(self):
|
||||
self.x = 1
|
||||
def target_method(self):
|
||||
\"\"\"Docstring for target method\"\"\"
|
||||
y = HelperClass().helper_method()
|
||||
|
||||
class HelperClass:
|
||||
def __init__(self):
|
||||
\"\"\"Initialize the HelperClass.\"\"\"
|
||||
self.x = 1
|
||||
def helper_method(self):
|
||||
return self.x
|
||||
"""
|
||||
expected_read_only_context = f"""
|
||||
```python:{file_path.relative_to(opt.args.project_root)}
|
||||
class MyClass:
|
||||
def __init__(self):
|
||||
self.x = 1
|
||||
pass
|
||||
|
||||
class HelperClass:
|
||||
def __init__(self):
|
||||
self.x = 1
|
||||
def __repr__(self):
|
||||
return "HelperClass" + str(self.x)
|
||||
```
|
||||
|
|
@ -756,15 +767,20 @@ class HelperClass:
|
|||
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
|
||||
# In this scenario, the read-only code context is too long even after removing docstrings, hence we remove it completely.
|
||||
expected_read_write_context = """
|
||||
class MyClass:
|
||||
def target_method(self):
|
||||
\"\"\"Docstring for target method\"\"\"
|
||||
y = HelperClass().helper_method()
|
||||
class MyClass:
|
||||
def __init__(self):
|
||||
self.x = 1
|
||||
def target_method(self):
|
||||
\"\"\"Docstring for target method\"\"\"
|
||||
y = HelperClass().helper_method()
|
||||
|
||||
class HelperClass:
|
||||
def helper_method(self):
|
||||
return self.x
|
||||
"""
|
||||
class HelperClass:
|
||||
def __init__(self):
|
||||
\"\"\"Initialize the HelperClass.\"\"\"
|
||||
self.x = 1
|
||||
def helper_method(self):
|
||||
return self.x
|
||||
"""
|
||||
expected_read_only_context = ""
|
||||
assert read_write_context.strip() == dedent(expected_read_write_context).strip()
|
||||
assert read_only_context.strip() == dedent(expected_read_only_context).strip()
|
||||
|
|
@ -836,12 +852,18 @@ def test_repo_helper() -> None:
|
|||
code_ctx = get_code_optimization_context(function_to_optimize, project_root)
|
||||
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
|
||||
expected_read_write_context = """
|
||||
import requests
|
||||
import math
|
||||
import requests
|
||||
from globals import API_URL
|
||||
from utils import DataProcessor
|
||||
|
||||
class DataProcessor:
|
||||
|
||||
def __init__(self, default_prefix: str = "PREFIX_"):
|
||||
\"\"\"Initialize the DataProcessor with a default prefix.\"\"\"
|
||||
self.default_prefix = default_prefix
|
||||
self.number += math.log(self.number)
|
||||
|
||||
def process_data(self, raw_data: str) -> str:
|
||||
\"\"\"Process raw data by converting it to uppercase.\"\"\"
|
||||
return raw_data.upper()
|
||||
|
|
@ -869,8 +891,6 @@ def fetch_and_process_data():
|
|||
"""
|
||||
expected_read_only_context = f"""
|
||||
```python:{path_to_utils.relative_to(project_root)}
|
||||
import math
|
||||
|
||||
GLOBAL_VAR = 10
|
||||
|
||||
|
||||
|
|
@ -879,11 +899,6 @@ class DataProcessor:
|
|||
|
||||
number = 1
|
||||
|
||||
def __init__(self, default_prefix: str = "PREFIX_"):
|
||||
\"\"\"Initialize the DataProcessor with a default prefix.\"\"\"
|
||||
self.default_prefix = default_prefix
|
||||
self.number += math.log(self.number)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
\"\"\"Return a string representation of the DataProcessor.\"\"\"
|
||||
return f"DataProcessor(default_prefix={{self.default_prefix!r}})"
|
||||
|
|
@ -914,6 +929,7 @@ def test_repo_helper_of_helper() -> None:
|
|||
code_ctx = get_code_optimization_context(function_to_optimize, project_root)
|
||||
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
|
||||
expected_read_write_context = """
|
||||
import math
|
||||
from transform_utils import DataTransformer
|
||||
import requests
|
||||
from globals import API_URL
|
||||
|
|
@ -921,6 +937,11 @@ from utils import DataProcessor
|
|||
|
||||
class DataProcessor:
|
||||
|
||||
def __init__(self, default_prefix: str = "PREFIX_"):
|
||||
\"\"\"Initialize the DataProcessor with a default prefix.\"\"\"
|
||||
self.default_prefix = default_prefix
|
||||
self.number += math.log(self.number)
|
||||
|
||||
def process_data(self, raw_data: str) -> str:
|
||||
\"\"\"Process raw data by converting it to uppercase.\"\"\"
|
||||
return raw_data.upper()
|
||||
|
|
@ -947,8 +968,6 @@ def fetch_and_transform_data():
|
|||
"""
|
||||
expected_read_only_context = f"""
|
||||
```python:{path_to_utils.relative_to(project_root)}
|
||||
import math
|
||||
|
||||
GLOBAL_VAR = 10
|
||||
|
||||
|
||||
|
|
@ -957,11 +976,6 @@ class DataProcessor:
|
|||
|
||||
number = 1
|
||||
|
||||
def __init__(self, default_prefix: str = "PREFIX_"):
|
||||
\"\"\"Initialize the DataProcessor with a default prefix.\"\"\"
|
||||
self.default_prefix = default_prefix
|
||||
self.number += math.log(self.number)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
\"\"\"Return a string representation of the DataProcessor.\"\"\"
|
||||
return f"DataProcessor(default_prefix={{self.default_prefix!r}})"
|
||||
|
|
@ -973,8 +987,6 @@ if __name__ == "__main__":
|
|||
```
|
||||
```python:{path_to_transform_utils.relative_to(project_root)}
|
||||
class DataTransformer:
|
||||
def __init__(self):
|
||||
self.data = None
|
||||
|
||||
def transform(self, data):
|
||||
self.data = data
|
||||
|
|
@ -1001,9 +1013,12 @@ def test_repo_helper_of_helper_same_class() -> None:
|
|||
code_ctx = get_code_optimization_context(function_to_optimize, project_root)
|
||||
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
|
||||
expected_read_write_context = """
|
||||
import math
|
||||
from transform_utils import DataTransformer
|
||||
|
||||
class DataTransformer:
|
||||
def __init__(self):
|
||||
self.data = None
|
||||
|
||||
def transform_using_own_method(self, data):
|
||||
return self.transform(data)
|
||||
|
|
@ -1012,6 +1027,11 @@ class DataTransformer:
|
|||
|
||||
class DataProcessor:
|
||||
|
||||
def __init__(self, default_prefix: str = "PREFIX_"):
|
||||
\"\"\"Initialize the DataProcessor with a default prefix.\"\"\"
|
||||
self.default_prefix = default_prefix
|
||||
self.number += math.log(self.number)
|
||||
|
||||
def transform_data_own_method(self, data: str) -> str:
|
||||
\"\"\"Transform the processed data using own method\"\"\"
|
||||
return DataTransformer().transform_using_own_method(data)
|
||||
|
|
@ -1020,16 +1040,12 @@ class DataProcessor:
|
|||
expected_read_only_context = f"""
|
||||
```python:{path_to_transform_utils.relative_to(project_root)}
|
||||
class DataTransformer:
|
||||
def __init__(self):
|
||||
self.data = None
|
||||
|
||||
def transform(self, data):
|
||||
self.data = data
|
||||
return self.data
|
||||
```
|
||||
```python:{path_to_utils.relative_to(project_root)}
|
||||
import math
|
||||
|
||||
GLOBAL_VAR = 10
|
||||
|
||||
|
||||
|
|
@ -1038,11 +1054,6 @@ class DataProcessor:
|
|||
|
||||
number = 1
|
||||
|
||||
def __init__(self, default_prefix: str = "PREFIX_"):
|
||||
\"\"\"Initialize the DataProcessor with a default prefix.\"\"\"
|
||||
self.default_prefix = default_prefix
|
||||
self.number += math.log(self.number)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
\"\"\"Return a string representation of the DataProcessor.\"\"\"
|
||||
return f"DataProcessor(default_prefix={{self.default_prefix!r}})"
|
||||
|
|
@ -1069,9 +1080,12 @@ def test_repo_helper_of_helper_same_file() -> None:
|
|||
code_ctx = get_code_optimization_context(function_to_optimize, project_root)
|
||||
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
|
||||
expected_read_write_context = """
|
||||
import math
|
||||
from transform_utils import DataTransformer
|
||||
|
||||
class DataTransformer:
|
||||
def __init__(self):
|
||||
self.data = None
|
||||
|
||||
def transform_using_same_file_function(self, data):
|
||||
return update_data(data)
|
||||
|
|
@ -1080,23 +1094,21 @@ class DataTransformer:
|
|||
|
||||
class DataProcessor:
|
||||
|
||||
def __init__(self, default_prefix: str = "PREFIX_"):
|
||||
\"\"\"Initialize the DataProcessor with a default prefix.\"\"\"
|
||||
self.default_prefix = default_prefix
|
||||
self.number += math.log(self.number)
|
||||
|
||||
def transform_data_same_file_function(self, data: str) -> str:
|
||||
\"\"\"Transform the processed data using a function from the same file\"\"\"
|
||||
return DataTransformer().transform_using_same_file_function(data)
|
||||
"""
|
||||
expected_read_only_context = f"""
|
||||
```python:{path_to_transform_utils.relative_to(project_root)}
|
||||
class DataTransformer:
|
||||
def __init__(self):
|
||||
self.data = None
|
||||
|
||||
|
||||
def update_data(data):
|
||||
return data + " updated"
|
||||
```
|
||||
```python:{path_to_utils.relative_to(project_root)}
|
||||
import math
|
||||
|
||||
GLOBAL_VAR = 10
|
||||
|
||||
|
||||
|
|
@ -1105,11 +1117,6 @@ class DataProcessor:
|
|||
|
||||
number = 1
|
||||
|
||||
def __init__(self, default_prefix: str = "PREFIX_"):
|
||||
\"\"\"Initialize the DataProcessor with a default prefix.\"\"\"
|
||||
self.default_prefix = default_prefix
|
||||
self.number += math.log(self.number)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
\"\"\"Return a string representation of the DataProcessor.\"\"\"
|
||||
return f"DataProcessor(default_prefix={{self.default_prefix!r}})"
|
||||
|
|
@ -1135,6 +1142,8 @@ def test_repo_helper_all_same_file() -> None:
|
|||
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
|
||||
expected_read_write_context = """
|
||||
class DataTransformer:
|
||||
def __init__(self):
|
||||
self.data = None
|
||||
|
||||
def transform_using_own_method(self, data):
|
||||
return self.transform(data)
|
||||
|
|
@ -1150,9 +1159,7 @@ def update_data(data):
|
|||
expected_read_only_context = f"""
|
||||
```python:{path_to_transform_utils.relative_to(project_root)}
|
||||
class DataTransformer:
|
||||
def __init__(self):
|
||||
self.data = None
|
||||
|
||||
|
||||
def transform(self, data):
|
||||
self.data = data
|
||||
return self.data
|
||||
|
|
@ -1179,11 +1186,17 @@ def test_repo_helper_circular_dependency() -> None:
|
|||
code_ctx = get_code_optimization_context(function_to_optimize, project_root)
|
||||
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
|
||||
expected_read_write_context = """
|
||||
import math
|
||||
from transform_utils import DataTransformer
|
||||
from code_to_optimize.code_directories.retriever.utils import DataProcessor
|
||||
|
||||
class DataProcessor:
|
||||
|
||||
def __init__(self, default_prefix: str = "PREFIX_"):
|
||||
\"\"\"Initialize the DataProcessor with a default prefix.\"\"\"
|
||||
self.default_prefix = default_prefix
|
||||
self.number += math.log(self.number)
|
||||
|
||||
def circular_dependency(self, data: str) -> str:
|
||||
\"\"\"Test circular dependency\"\"\"
|
||||
return DataTransformer().circular_dependency(data)
|
||||
|
|
@ -1191,6 +1204,8 @@ class DataProcessor:
|
|||
|
||||
|
||||
class DataTransformer:
|
||||
def __init__(self):
|
||||
self.data = None
|
||||
|
||||
def circular_dependency(self, data):
|
||||
return DataProcessor().circular_dependency(data)
|
||||
|
|
@ -1199,8 +1214,6 @@ class DataTransformer:
|
|||
"""
|
||||
expected_read_only_context = f"""
|
||||
```python:{path_to_utils.relative_to(project_root)}
|
||||
import math
|
||||
|
||||
GLOBAL_VAR = 10
|
||||
|
||||
|
||||
|
|
@ -1209,20 +1222,10 @@ class DataProcessor:
|
|||
|
||||
number = 1
|
||||
|
||||
def __init__(self, default_prefix: str = "PREFIX_"):
|
||||
\"\"\"Initialize the DataProcessor with a default prefix.\"\"\"
|
||||
self.default_prefix = default_prefix
|
||||
self.number += math.log(self.number)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
\"\"\"Return a string representation of the DataProcessor.\"\"\"
|
||||
return f"DataProcessor(default_prefix={{self.default_prefix!r}})"
|
||||
```
|
||||
```python:{path_to_transform_utils.relative_to(project_root)}
|
||||
class DataTransformer:
|
||||
def __init__(self):
|
||||
self.data = None
|
||||
```
|
||||
|
||||
"""
|
||||
|
||||
|
|
|
|||
|
|
@ -894,6 +894,7 @@ def test_test_libcst_code_replacement13() -> None:
|
|||
original_code = """class NewClass:
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
self.new_attribute = "Sorry i modified a dunder method"
|
||||
def new_function(self, value):
|
||||
return other_function(self.name)
|
||||
def new_function2(value):
|
||||
|
|
|
|||
|
|
@ -3,13 +3,13 @@ import site
|
|||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from codeflash.code_utils.code_utils import (
|
||||
cleanup_paths,
|
||||
file_name_from_test_module_name,
|
||||
file_path_from_module_name,
|
||||
get_all_function_names,
|
||||
get_imports_from_file,
|
||||
get_only_code_content,
|
||||
get_qualified_name,
|
||||
get_run_tmp_file,
|
||||
is_class_defined_in_file,
|
||||
|
|
@ -26,6 +26,7 @@ def multiple_existing_and_non_existing_files(tmp_path: Path) -> list[Path]:
|
|||
file.touch()
|
||||
return existing_files + non_existing_files
|
||||
|
||||
|
||||
def test_get_qualified_name_valid() -> None:
|
||||
module_name = "codeflash"
|
||||
full_qualified_name = "codeflash.utils.module"
|
||||
|
|
@ -47,6 +48,7 @@ def test_get_qualified_name_same_name() -> None:
|
|||
with pytest.raises(ValueError, match="is the same as codeflash"):
|
||||
get_qualified_name(module_name, full_qualified_name)
|
||||
|
||||
|
||||
# tests for module_name_from_file_path
|
||||
def test_module_name_from_file_path() -> None:
|
||||
project_root_path = Path("/Users/codeflashuser/PycharmProjects/codeflash")
|
||||
|
|
@ -91,6 +93,8 @@ def test_get_imports_from_file_with_file_path(tmp_path: Path) -> None:
|
|||
assert imports[0].names[0].name == "os"
|
||||
assert imports[1].module == "sys"
|
||||
assert imports[1].names[0].name == "path"
|
||||
|
||||
|
||||
def test_get_imports_from_file_with_file_string() -> None:
|
||||
file_string = "import os\nfrom sys import path\n"
|
||||
|
||||
|
|
@ -102,6 +106,7 @@ def test_get_imports_from_file_with_file_string() -> None:
|
|||
assert imports[1].module == "sys"
|
||||
assert imports[1].names[0].name == "path"
|
||||
|
||||
|
||||
def test_get_imports_from_file_with_file_ast() -> None:
|
||||
file_string = "import os\nfrom sys import path\n"
|
||||
file_ast = ast.parse(file_string)
|
||||
|
|
@ -114,6 +119,7 @@ def test_get_imports_from_file_with_file_ast() -> None:
|
|||
assert imports[1].module == "sys"
|
||||
assert imports[1].names[0].name == "path"
|
||||
|
||||
|
||||
def test_get_imports_from_file_with_syntax_error(caplog: pytest.LogCaptureFixture) -> None:
|
||||
file_string = "import os\nfrom sys import path\ninvalid syntax"
|
||||
|
||||
|
|
@ -173,6 +179,7 @@ async def bar():
|
|||
assert success is True
|
||||
assert function_names == ["foo", "bar"]
|
||||
|
||||
|
||||
def test_get_all_function_names_with_syntax_error(caplog: pytest.LogCaptureFixture) -> None:
|
||||
code = """
|
||||
def foo():
|
||||
|
|
@ -234,6 +241,7 @@ def test_get_run_tmp_file_reuses_temp_directory() -> None:
|
|||
assert tmp_file_path1.parent.name.startswith("codeflash_")
|
||||
assert tmp_file_path1.parent.exists()
|
||||
|
||||
|
||||
def test_path_belongs_to_site_packages_with_site_package_path(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
site_packages = [Path("/usr/local/lib/python3.9/site-packages")]
|
||||
monkeypatch.setattr(site, "getsitepackages", lambda: site_packages)
|
||||
|
|
@ -241,6 +249,7 @@ def test_path_belongs_to_site_packages_with_site_package_path(monkeypatch: pytes
|
|||
file_path = Path("/usr/local/lib/python3.9/site-packages/some_package")
|
||||
assert path_belongs_to_site_packages(file_path) is True
|
||||
|
||||
|
||||
def test_path_belongs_to_site_packages_with_non_site_package_path(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
site_packages = [Path("/usr/local/lib/python3.9/site-packages")]
|
||||
monkeypatch.setattr(site, "getsitepackages", lambda: site_packages)
|
||||
|
|
@ -248,6 +257,7 @@ def test_path_belongs_to_site_packages_with_non_site_package_path(monkeypatch: p
|
|||
file_path = Path("/usr/local/lib/python3.9/other_directory/some_package")
|
||||
assert path_belongs_to_site_packages(file_path) is False
|
||||
|
||||
|
||||
def test_path_belongs_to_site_packages_with_relative_path(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
site_packages = [Path("/usr/local/lib/python3.9/site-packages")]
|
||||
monkeypatch.setattr(site, "getsitepackages", lambda: site_packages)
|
||||
|
|
@ -332,3 +342,51 @@ def test_cleanup_paths(multiple_existing_and_non_existing_files: list[Path]) ->
|
|||
cleanup_paths(multiple_existing_and_non_existing_files)
|
||||
for file in multiple_existing_and_non_existing_files:
|
||||
assert not file.exists()
|
||||
|
||||
|
||||
def test_get_only_code_content_with_docstring() -> None:
    """Check that a function docstring is dropped while the code is kept."""
    # NOTE(review): indentation inside the literals was reconstructed from a
    # whitespace-stripped view of this test -- confirm against the real file.
    source = '''def foo():
    """This is a docstring."""
    return 42'''
    stripped = """def foo():
    return 42"""
    actual = get_only_code_content(source)
    assert actual == stripped
|
||||
|
||||
|
||||
def test_get_only_code_content_with_comments() -> None:
    """Check that full-line and trailing comments are both removed."""
    # NOTE(review): indentation inside the literals was reconstructed from a
    # whitespace-stripped view of this test -- confirm against the real file.
    source = """def foo():
    # This is a comment
    return 42 # Another comment"""
    stripped = """def foo():
    return 42"""
    actual = get_only_code_content(source)
    assert actual == stripped
|
||||
|
||||
|
||||
def test_get_only_code_content_with_docstring_and_comments() -> None:
    """Check that a docstring and comments are stripped in the same pass."""
    # NOTE(review): indentation inside the literals was reconstructed from a
    # whitespace-stripped view of this test -- confirm against the real file.
    source = '''def foo():
    """This is a docstring."""
    # This is a comment
    return 42 # Another comment'''

    stripped = """def foo():
    return 42"""
    actual = get_only_code_content(source)
    assert actual == stripped
|
||||
|
||||
|
||||
def test_get_only_code_content_nested_functions() -> None:
    """Check that docstrings are removed from both an outer and an inner function."""
    # NOTE(review): indentation inside the literals (and whether the blank
    # line in the expected text carries whitespace) was reconstructed from a
    # whitespace-stripped view of this test -- confirm against the real file.
    source = '''def outer():
    """Outer docstring."""
    def inner():
        """Inner docstring."""
        return 42
    return inner()'''
    stripped = """def outer():

    def inner():
        return 42
    return inner()"""
    actual = get_only_code_content(source)
    assert actual == stripped
|
||||
|
|
|
|||
|
|
@ -40,8 +40,6 @@ def test_dunder_methods() -> None:
|
|||
|
||||
expected = """
|
||||
class TestClass:
|
||||
def __init__(self):
|
||||
self.x = 42
|
||||
|
||||
def __str__(self):
|
||||
return f"Value: {self.x}"
|
||||
|
|
@ -68,9 +66,7 @@ def test_dunder_methods_remove_docstring() -> None:
|
|||
|
||||
expected = """
|
||||
class TestClass:
|
||||
def __init__(self):
|
||||
self.x = 42
|
||||
|
||||
|
||||
def __str__(self):
|
||||
return f"Value: {self.x}"
|
||||
"""
|
||||
|
|
@ -95,9 +91,7 @@ def test_class_remove_docstring() -> None:
|
|||
|
||||
expected = """
|
||||
class TestClass:
|
||||
def __init__(self):
|
||||
self.x = 42
|
||||
|
||||
|
||||
def __str__(self):
|
||||
return f"Value: {self.x}"
|
||||
"""
|
||||
|
|
@ -124,9 +118,7 @@ def test_mixed_remove_docstring() -> None:
|
|||
|
||||
expected = """
|
||||
class TestClass:
|
||||
def __init__(self):
|
||||
self.x = 42
|
||||
|
||||
|
||||
def __str__(self):
|
||||
return f"Value: {self.x}"
|
||||
"""
|
||||
|
|
@ -208,10 +200,6 @@ def test_multiple_top_level_targets() -> None:
|
|||
"""
|
||||
|
||||
expected = """
|
||||
class TestClass:
|
||||
|
||||
def __init__(self):
|
||||
self.x = 42
|
||||
"""
|
||||
|
||||
output = get_read_only_code(dedent(code), {"TestClass.target1", "TestClass.target2"}, set())
|
||||
|
|
@ -659,11 +647,6 @@ def test_simplified_complete_implementation() -> None:
|
|||
class DataProcessor:
|
||||
\"\"\"A simple data processing class.\"\"\"
|
||||
|
||||
def __init__(self, data: Dict[str, Any]) -> None:
|
||||
self.data = data
|
||||
self._processed = False
|
||||
self.result = None
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"DataProcessor(processed={self._processed})"
|
||||
|
||||
|
|
@ -672,17 +655,12 @@ def test_simplified_complete_implementation() -> None:
|
|||
processor = DataProcessor(sample_data)
|
||||
|
||||
class ResultHandler:
|
||||
def __init__(self, processor: DataProcessor):
|
||||
self.processor = processor
|
||||
self.cache = {}
|
||||
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"ResultHandler(cache_size={len(self.cache)})"
|
||||
|
||||
except Exception as e:
|
||||
class ResultHandler:
|
||||
def __init__(self):
|
||||
self.error = str(e)
|
||||
pass
|
||||
"""
|
||||
|
||||
output = get_read_only_code(dedent(code), {"DataProcessor.target_method", "ResultHandler.target_method"}, set())
|
||||
|
|
@ -693,12 +671,6 @@ def test_simplified_complete_implementation_no_docstring() -> None:
|
|||
code = """
|
||||
class DataProcessor:
|
||||
\"\"\"A simple data processing class.\"\"\"
|
||||
|
||||
def __init__(self, data: Dict[str, Any]) -> None:
|
||||
self.data = data
|
||||
self._processed = False
|
||||
self.result = None
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"DataProcessor(processed={self._processed})"
|
||||
|
||||
|
|
@ -732,9 +704,6 @@ def test_simplified_complete_implementation_no_docstring() -> None:
|
|||
processor = DataProcessor(sample_data)
|
||||
|
||||
class ResultHandler:
|
||||
def __init__(self, processor: DataProcessor):
|
||||
self.processor = processor
|
||||
self.cache = {}
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"ResultHandler(cache_size={len(self.cache)})"
|
||||
|
|
@ -758,8 +727,6 @@ def test_simplified_complete_implementation_no_docstring() -> None:
|
|||
|
||||
except Exception as e:
|
||||
class ResultHandler:
|
||||
def __init__(self):
|
||||
self.error = str(e)
|
||||
|
||||
def target_method(self, key: str) -> None:
|
||||
raise RuntimeError(f"Failed to initialize: {self.error}")
|
||||
|
|
@ -767,12 +734,6 @@ def test_simplified_complete_implementation_no_docstring() -> None:
|
|||
|
||||
expected = """
|
||||
class DataProcessor:
|
||||
|
||||
def __init__(self, data: Dict[str, Any]) -> None:
|
||||
self.data = data
|
||||
self._processed = False
|
||||
self.result = None
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"DataProcessor(processed={self._processed})"
|
||||
|
||||
|
|
@ -781,17 +742,12 @@ def test_simplified_complete_implementation_no_docstring() -> None:
|
|||
processor = DataProcessor(sample_data)
|
||||
|
||||
class ResultHandler:
|
||||
def __init__(self, processor: DataProcessor):
|
||||
self.processor = processor
|
||||
self.cache = {}
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"ResultHandler(cache_size={len(self.cache)})"
|
||||
|
||||
except Exception as e:
|
||||
class ResultHandler:
|
||||
def __init__(self):
|
||||
self.error = str(e)
|
||||
pass
|
||||
"""
|
||||
|
||||
output = get_read_only_code(
|
||||
|
|
|
|||
|
|
@ -161,7 +161,7 @@ def test_try_except_structure() -> None:
|
|||
assert result.strip() == expected.strip()
|
||||
|
||||
|
||||
def test_dunder_method() -> None:
|
||||
def test_init_method() -> None:
|
||||
code = """
|
||||
class MyClass:
|
||||
def __init__(self):
|
||||
|
|
@ -175,6 +175,30 @@ def test_dunder_method() -> None:
|
|||
"""
|
||||
result = get_read_writable_code(dedent(code), {"MyClass.target_method"})
|
||||
|
||||
expected = dedent("""
|
||||
class MyClass:
|
||||
def __init__(self):
|
||||
self.x = 1
|
||||
|
||||
def target_method(self):
|
||||
return f"Value: {self.x}"
|
||||
""")
|
||||
assert result.strip() == expected.strip()
|
||||
|
||||
def test_dunder_method() -> None:
|
||||
code = """
|
||||
class MyClass:
|
||||
def __repr__(self):
|
||||
return "MyClass"
|
||||
|
||||
def other_method(self):
|
||||
return "other"
|
||||
|
||||
def target_method(self):
|
||||
return f"Value: {self.x}"
|
||||
"""
|
||||
result = get_read_writable_code(dedent(code), {"MyClass.target_method"})
|
||||
|
||||
expected = dedent("""
|
||||
class MyClass:
|
||||
|
||||
|
|
@ -183,7 +207,6 @@ def test_dunder_method() -> None:
|
|||
""")
|
||||
assert result.strip() == expected.strip()
|
||||
|
||||
|
||||
def test_no_targets_found() -> None:
|
||||
code = """
|
||||
class MyClass:
|
||||
|
|
|
|||
235
tests/test_init_optimization.py
Normal file
235
tests/test_init_optimization.py
Normal file
|
|
@ -0,0 +1,235 @@
|
|||
from textwrap import dedent
|
||||
|
||||
import libcst as cst
|
||||
from codeflash.code_utils.code_replacer import merge_init_functions, replace_functions_in_file
|
||||
from codeflash.models.models import FunctionParent
|
||||
|
||||
|
||||
def test_basic_merge() -> None:
    """Merge two ``__init__`` bodies that share no statements.

    The expected output keeps the original signature and lists the original
    statements followed by the new ones.
    """
    # NOTE(review): indentation inside the literals was reconstructed from a
    # whitespace-stripped view of this test -- confirm against the real file.
    original_src = """
    class MyClass:
        def __init__(self, a, b):
            self.a = a
            self.b = b
    """
    new_src = """
    class MyClass:
        def __init__(self, x, y):
            self.x = x
            self.y = y
    """
    # The first statement of the first class body is the ``__init__`` def.
    original_init = cst.parse_module(dedent(original_src)).body[0].body.body[0]
    new_init = cst.parse_module(dedent(new_src)).body[0].body.body[0]
    merged = merge_init_functions(original_init, new_init)

    expected_src = """
    def __init__(self, a, b):
        self.a = a
        self.b = b
        self.x = x
        self.y = y
    """
    assert cst.Module([merged]).code.strip() == dedent(expected_src).strip()
|
||||
|
||||
|
||||
def test_prevent_duplication() -> None:
    """Merge inits that share statements; shared statements must appear once.

    Only ``self.y = 2`` is new, so the expected output is the original body
    with that single statement appended.
    """
    # NOTE(review): indentation inside the literals was reconstructed from a
    # whitespace-stripped view of this test -- confirm against the real file.
    original_src = """
    class MyClass:
        def __init__(self):
            self.x = 1
            print("init")
            self.setup()
    """
    new_src = """
    class MyClass:
        def __init__(self):
            print("init")
            self.setup()
            self.y = 2
    """
    # The first statement of the first class body is the ``__init__`` def.
    original_init = cst.parse_module(dedent(original_src)).body[0].body.body[0]
    new_init = cst.parse_module(dedent(new_src)).body[0].body.body[0]
    merged = merge_init_functions(original_init, new_init)

    expected_src = """
    def __init__(self):
        self.x = 1
        print("init")
        self.setup()
        self.y = 2
    """
    assert cst.Module([merged]).code.strip() == dedent(expected_src).strip()
|
||||
|
||||
|
||||
def test_prevent_overwrite() -> None:
    """Merge an init that reassigns an attribute the original already sets.

    The expected output is the untouched original body: the new
    ``self.x = 2`` must not replace ``self.x = 1``.
    """
    # NOTE(review): indentation inside the literals was reconstructed from a
    # whitespace-stripped view of this test -- confirm against the real file.
    original_src = """
    class MyClass:
        def __init__(self):
            self.x = 1
            self.y = 2
    """
    new_src = """
    class MyClass:
        def __init__(self):
            self.x = 2
    """
    # The first statement of the first class body is the ``__init__`` def.
    original_init = cst.parse_module(dedent(original_src)).body[0].body.body[0]
    new_init = cst.parse_module(dedent(new_src)).body[0].body.body[0]
    merged = merge_init_functions(original_init, new_init)

    expected_src = """
    def __init__(self):
        self.x = 1
        self.y = 2
    """
    assert cst.Module([merged]).code.strip() == dedent(expected_src).strip()
|
||||
|
||||
|
||||
def test_complex_control_flow() -> None:
    """Merge inits whose bodies are compound statements (with / if / try).

    The expected output is the original compound statements followed by the
    new ``try``/``except`` block.
    """
    # NOTE(review): indentation inside the literals was reconstructed from a
    # whitespace-stripped view of this test. In particular, ``if self.debug:``
    # is assumed to be a sibling of the ``with`` block rather than nested in
    # it -- confirm against the real file.
    original_src = """
    class MyClass:
        def __init__(self):
            with self.lock:
                self.setup()
            if self.debug:
                self.enable_logging()
    """
    new_src = """
    class MyClass:
        def __init__(self):
            try:
                self.connect()
            except ConnectionError:
                self.fallback()
    """
    # The first statement of the first class body is the ``__init__`` def.
    original_init = cst.parse_module(dedent(original_src)).body[0].body.body[0]
    new_init = cst.parse_module(dedent(new_src)).body[0].body.body[0]
    merged = merge_init_functions(original_init, new_init)

    expected_src = """
    def __init__(self):
        with self.lock:
            self.setup()
        if self.debug:
            self.enable_logging()
        try:
            self.connect()
        except ConnectionError:
            self.fallback()
    """
    assert cst.Module([merged]).code.strip() == dedent(expected_src).strip()
|
||||
|
||||
|
||||
def test_docstrings_and_comments() -> None:
    """Merge inits that each carry a docstring and comments.

    As pinned here, the new init's docstring and comments are appended after
    the original body rather than merged or dropped.
    """
    # NOTE(review): indentation inside the literals was reconstructed from a
    # whitespace-stripped view of this test -- confirm against the real file.
    original_src = """
    class MyClass:
        def __init__(self):
            \"\"\"Original docstring.\"\"\"
            # Setup configuration
            self.config = {} # Empty config
    """
    new_src = """
    class MyClass:
        def __init__(self):
            \"\"\"New docstring.\"\"\"
            # Initialize database
            self.db = None # Database connection
    """
    # The first statement of the first class body is the ``__init__`` def.
    original_init = cst.parse_module(dedent(original_src)).body[0].body.body[0]
    new_init = cst.parse_module(dedent(new_src)).body[0].body.body[0]
    merged = merge_init_functions(original_init, new_init)
    # TODO: handle docstrings differently
    expected_src = """
    def __init__(self):
        \"\"\"Original docstring.\"\"\"
        # Setup configuration
        self.config = {} # Empty config
        \"\"\"New docstring.\"\"\"
        # Initialize database
        self.db = None # Database connection
    """
    assert cst.Module([merged]).code.strip() == dedent(expected_src).strip()
|
||||
|
||||
|
||||
def test_type_annotations() -> None:
    """Merge inits that use annotated assignments.

    The expected output keeps the original ``-> None`` signature and the
    original ``self.y`` value, and appends only the new ``self.z`` line.
    """
    # NOTE(review): indentation inside the literals was reconstructed from a
    # whitespace-stripped view of this test -- confirm against the real file.
    original_src = """
    class MyClass:
        def __init__(self) -> None:
            self.x: int = 1
            self.y: str = "hello"
    """
    new_src = """
    class MyClass:
        def __init__(self):
            self.y: str = "new hello"
            self.z: float = 2.0
    """
    # The first statement of the first class body is the ``__init__`` def.
    original_init = cst.parse_module(dedent(original_src)).body[0].body.body[0]
    new_init = cst.parse_module(dedent(new_src)).body[0].body.body[0]
    merged = merge_init_functions(original_init, new_init)

    expected_src = """
    def __init__(self) -> None:
        self.x: int = 1
        self.y: str = "hello"
        self.z: float = 2.0
    """
    assert cst.Module([merged]).code.strip() == dedent(expected_src).strip()
|
||||
|
||||
|
||||
# Tests for code replacement with init
|
||||
def test_merge_init_methods() -> None:
    """Replace functions in a file when ``__init__`` is not a requested target.

    With no original function names given, the expected output keeps the
    original ``__init__`` statements and appends only the new ``self.z = 3``.
    """
    # NOTE(review): indentation inside the literals was reconstructed from a
    # whitespace-stripped view of this test -- confirm against the real file.
    optimized_src = """class MyClass:
    def __init__(self):
        self.y = 2
        self.z = 3
"""

    original_src = """class MyClass:
    def __init__(self):
        self.y = 1
        self.setup()
"""

    expected_src = """class MyClass:
    def __init__(self):
        self.y = 1
        self.setup()
        self.z = 3
"""

    known_objects = [("__init__", [FunctionParent(name="MyClass", type="ClassDef")])]
    actual = replace_functions_in_file(
        source_code=original_src,
        original_function_names=[],
        optimized_code=optimized_src,
        preexisting_objects=known_objects,
    )
    assert actual == expected_src
|
||||
|
||||
|
||||
def test_init_is_function_to_optimize() -> None:
    """Replace functions in a file when ``MyClass.__init__`` is the target.

    Because ``__init__`` is named in ``original_function_names``, the expected
    output is the optimized ``__init__`` verbatim, not a merge.
    """
    # NOTE(review): indentation inside the literals was reconstructed from a
    # whitespace-stripped view of this test -- confirm against the real file.
    optimized_src = """class MyClass:
    def __init__(self):
        self.y = 2
        self.z = 3
"""

    original_src = """class MyClass:
    def __init__(self):
        self.y = 1
        self.setup()
"""

    expected_src = """class MyClass:
    def __init__(self):
        self.y = 2
        self.z = 3
"""
    # In this scenario, we leave the mutation check to the usual FTO behaviour check.
    known_objects = [("__init__", [FunctionParent(name="MyClass", type="ClassDef")])]
    actual = replace_functions_in_file(
        source_code=original_src,
        original_function_names=["MyClass.__init__"],
        optimized_code=optimized_src,
        preexisting_objects=known_objects,
    )
    assert actual == expected_src
|
||||
Loading…
Reference in a new issue