Draft PR for init caching; instrumentation checks are not implemented yet.

This commit is contained in:
Alvin Ryanputra 2025-01-13 17:01:52 -08:00
parent d8ac58c5bb
commit 8de9cebe90
10 changed files with 814 additions and 182 deletions

View file

@ -9,6 +9,7 @@ import libcst as cst
from codeflash.cli_cmds.console import logger
from codeflash.code_utils.code_extractor import add_needed_imports_from_module
from codeflash.code_utils.code_utils import cst_to_code, get_only_code_content
from codeflash.models.models import FunctionParent
if TYPE_CHECKING:
@ -51,10 +52,13 @@ class OptimFunctionCollector(cst.CSTVisitor):
self.new_functions: list[cst.FunctionDef] = []
self.new_class_functions: dict[str, list[cst.FunctionDef]] = defaultdict(list)
self.current_class = None
self.modified_init_functions: dict[str, cst.FunctionDef] = {}
def visit_FunctionDef(self, node: cst.FunctionDef) -> bool:
if (self.current_class, node.name.value) in self.function_names:
self.modified_functions[(self.current_class, node.name.value)] = node
elif self.current_class and node.name.value == "__init__":
self.modified_init_functions[self.current_class] = node
elif (
self.preexisting_objects
and (node.name.value, []) not in self.preexisting_objects
@ -76,6 +80,7 @@ class OptimFunctionCollector(cst.CSTVisitor):
and (child_node.name.value, parents) not in self.preexisting_objects
):
self.new_class_functions[node.name.value].append(child_node)
return True
def leave_ClassDef(self, node: cst.ClassDef) -> None:
@ -89,11 +94,15 @@ class OptimFunctionReplacer(cst.CSTTransformer):
modified_functions: dict[tuple[str | None, str], cst.FunctionDef] = None,
new_functions: list[cst.FunctionDef] = None,
new_class_functions: dict[str, list[cst.FunctionDef]] = None,
modified_init_functions: dict[str, cst.FunctionDef] = None,
) -> None:
super().__init__()
self.modified_functions = modified_functions if modified_functions is not None else {}
self.new_functions = new_functions if new_functions is not None else []
self.new_class_functions = new_class_functions if new_class_functions is not None else defaultdict(list)
self.modified_init_functions: dict[str, cst.FunctionDef] = (
modified_init_functions if modified_init_functions is not None else {}
)
self.current_class = None
def visit_FunctionDef(self, node: cst.FunctionDef) -> bool:
@ -102,7 +111,15 @@ class OptimFunctionReplacer(cst.CSTTransformer):
def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef:
    """Replace a function body with its optimized version on the way out of the node.

    Two cases are handled:
      1. The (class, name) pair is in ``self.modified_functions`` — swap in the
         optimized body and decorators.
      2. The function is an ``__init__`` of a class with a pending optimized
         init — merge the two inits via ``merge_init_functions``.
    In both cases, if the optimized code is identical to the original once
    comments/docstrings are stripped, the original node is returned untouched
    so existing comments and docstrings are preserved.
    """
    if (self.current_class, original_node.name.value) in self.modified_functions:
        node = self.modified_functions[(self.current_class, original_node.name.value)]
        if get_only_code_content(cst_to_code(original_node)) == get_only_code_content(cst_to_code(node)):
            return original_node  # Code was unchanged, so don't modify docstrings / comments
        return updated_node.with_changes(body=node.body, decorators=node.decorators)
    if original_node.name.value == "__init__" and self.current_class in self.modified_init_functions:
        if get_only_code_content(cst_to_code(original_node)) == get_only_code_content(
            cst_to_code(self.modified_init_functions[self.current_class])
        ):
            return original_node  # Code was unchanged, so don't modify docstrings / comments
        return merge_init_functions(updated_node, self.modified_init_functions[self.current_class])
    return updated_node
@ -145,6 +162,97 @@ class OptimFunctionReplacer(cst.CSTTransformer):
return node
class AttributeCollector(cst.CSTVisitor):
    """Gathers the name of every ``self.<attr>`` access found while visiting a CST."""

    def __init__(self) -> None:
        super().__init__()
        # Attribute names seen on `self`, e.g. {"name", "x"}.
        self.attributes: set[str] = set()

    def visit_Attribute(self, node: cst.Attribute) -> bool:
        """Remember the attribute name whenever the object being accessed is `self`."""
        base = node.value
        if isinstance(base, cst.Name) and base.value == "self":
            self.attributes.add(node.attr.value)
        return True
class AssignmentCollector(cst.CSTVisitor):
    """Collects the names of `self` attributes that are assigned to in a CST.

    Covers plain assignments (``self.x = ...``), annotated assignments
    (``self.x: str = ...``) and augmented assignments (``self.x += ...``).
    """

    def __init__(self) -> None:
        super().__init__()
        # Attribute names that appear as assignment targets on `self`.
        self.assigned_attrs: set[str] = set()

    def _record_self_attr_target(self, target: cst.BaseExpression) -> None:
        """Record *target* if it is an attribute access on `self` (e.g. ``self.x``)."""
        # Shared by all three visit_* hooks; previously this check was
        # duplicated verbatim in each of them.
        if (
            isinstance(target, cst.Attribute)
            and isinstance(target.value, cst.Name)
            and target.value.value == "self"
        ):
            self.assigned_attrs.add(target.attr.value)

    def visit_Assign(self, node: cst.Assign) -> bool:
        """Check regular assignments like ``self.x = ...`` (including chained targets)."""
        for target in node.targets:
            self._record_self_attr_target(target.target)
        return True

    def visit_AnnAssign(self, node: cst.AnnAssign) -> bool:
        """Check annotated assignments like ``self.x: str = ...``."""
        self._record_self_attr_target(node.target)
        return True

    def visit_AugAssign(self, node: cst.AugAssign) -> bool:
        """Check augmented assignments like ``self.x += ...``."""
        self._record_self_attr_target(node.target)
        return True
def merge_init_functions(original_init: cst.FunctionDef, new_init: cst.FunctionDef) -> cst.FunctionDef:
    """Merge two ``__init__`` definitions into a single one.

    Every ``self.<attr>`` mentioned anywhere in *original_init* is treated as
    owned by the original. Statements from *new_init* are dropped when they are
    textual duplicates of an original statement or when they assign to one of
    those owned attributes (reads are allowed); the survivors are appended
    after the original body.

    Args:
        original_init: The original ``__init__`` function to preserve.
        new_init: The new ``__init__`` whose body will be filtered and appended.

    Returns:
        A merged FunctionDef based on *original_init*.
    """
    # All self.<attr> mentions (reads and writes) in the original init.
    attr_collector = AttributeCollector()
    original_init.visit(attr_collector)
    existing_attrs = attr_collector.attributes

    # Textual fingerprints of the original statements. # This should just be in terms of code, not comments?
    seen_stmts = {cst.Module([s]).code for s in original_init.body.body}

    def assigned_attrs_of(stmt: cst.CSTNode) -> set[str]:
        # Attributes this single statement writes via self.
        collector = AssignmentCollector()
        stmt.visit(collector)
        return collector.assigned_attrs

    appended = [
        stmt
        for stmt in new_init.body.body
        if cst.Module([stmt]).code not in seen_stmts
        and not (assigned_attrs_of(stmt) & existing_attrs)
    ]

    merged_body = original_init.body.with_changes(body=original_init.body.body + tuple(appended))
    return original_init.with_changes(body=merged_body)
def replace_functions_in_file(
source_code: str,
original_function_names: list[str],
@ -173,6 +281,7 @@ def replace_functions_in_file(
modified_functions=visitor.modified_functions,
new_functions=visitor.new_functions,
new_class_functions=visitor.new_class_functions,
modified_init_functions=visitor.modified_init_functions,
)
original_module = cst.parse_module(source_code)
modified_tree = original_module.visit(transformer)
@ -191,7 +300,7 @@ def replace_functions_and_add_imports(
return add_needed_imports_from_module(
optimized_code,
replace_functions_in_file(source_code, function_names, optimized_code, preexisting_objects),
file_path_of_module_with_function_to_optimize,
module_abspath,
module_abspath,
project_root_path,
)

View file

@ -7,9 +7,37 @@ from functools import lru_cache
from pathlib import Path
from tempfile import TemporaryDirectory
import libcst as cst
from codeflash.cli_cmds.console import logger
def cst_to_code(node: cst.CSTNode) -> str:
    """Render a single CST node back to source text, trimmed of surrounding whitespace."""
    rendered = cst.Module([node]).code
    return rendered.strip()
def get_only_code_content(code: str) -> str:
    """Extract just the code content from code, ignoring comments and docstrings.

    Parsing to an AST drops comments automatically; docstrings are then removed
    from the module, every function (sync and async), and every class before
    unparsing back to normalized source for comparison. (The original version
    only handled ``ast.FunctionDef``, missing async defs, classes, and the
    module docstring, and produced an empty — invalid — body for a def whose
    only statement was its docstring.)

    Args:
        code: Source code as a string

    Returns:
        String of code with comments and docstrings removed

    Raises:
        SyntaxError: If ``code`` is not valid Python.
    """
    # Parse into AST - this automatically strips comments
    tree = ast.parse(code)
    # Remove docstrings from the module, functions, and classes
    for node in ast.walk(tree):
        # These are exactly the node types ast.get_docstring accepts.
        if isinstance(node, (ast.Module, ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)) and ast.get_docstring(
            node
        ):
            remainder = node.body[1:]
            if not remainder and not isinstance(node, ast.Module):
                # A def/class needs at least one statement to stay valid.
                remainder = [ast.Pass()]
            node.body = remainder
    # Unparse back to source code for comparison
    return ast.unparse(tree)
def get_qualified_name(module_name: str, full_qualified_name: str) -> str:
if not full_qualified_name:
raise ValueError("full_qualified_name cannot be empty")
@ -80,6 +108,7 @@ def get_all_function_names(code: str) -> tuple[bool, list[str]]:
def get_run_tmp_file(file_path: Path) -> Path:
    """Return *file_path* anchored inside a process-wide codeflash temp directory.

    The TemporaryDirectory is created lazily on the first call and cached as a
    function attribute, so every subsequent call reuses the same directory.
    """
    tmpdir = getattr(get_run_tmp_file, "tmpdir", None)
    if tmpdir is None:
        tmpdir = TemporaryDirectory(prefix="codeflash_")
        get_run_tmp_file.tmpdir = tmpdir
        logger.info(f"Created new temp directory for codeflash: {Path(tmpdir.name)!s}")
    return Path(tmpdir.name) / file_path

View file

@ -26,9 +26,14 @@ def get_code_optimization_context(
helpers_of_fto, helpers_of_fto_fqn, helpers_of_fto_obj_list = get_file_path_to_helper_functions_dict(
{function_to_optimize.file_path: {function_to_optimize.qualified_name}}, project_root_path
)
helpers_of_helpers, helpers_of_helpers_fqn, _ = get_file_path_to_helper_functions_dict(
helpers_of_fto, project_root_path
)
print("helpers_of_fto")
print(helpers_of_fto)
print("helpers_of_helpers")
print(helpers_of_helpers)
# Add function to optimize
helpers_of_fto[function_to_optimize.file_path].add(function_to_optimize.qualified_name)
helpers_of_fto_fqn[function_to_optimize.file_path].add(
@ -45,13 +50,15 @@ def get_code_optimization_context(
project_root_path,
remove_docstrings=False,
)
testgen_context_code = get_all_testgen_context(helpers_of_fto, helpers_of_fto_fqn, project_root_path)
# Handle token limits
tokenizer = tiktoken.encoding_for_model("gpt-4o")
final_read_writable_tokens = len(tokenizer.encode(final_read_writable_code))
if final_read_writable_tokens > token_limit:
raise ValueError("Read-writable code has exceeded token limit, cannot proceed")
# if len(tokenizer.encode(testgen_context_code.code)) > token_limit:
# raise ValueError("Testgen context has exceeded token limit, cannot proceed")
# Setup preexisting objects for code replacer TODO: should remove duplicates
preexisting_objects = list(
chain(
@ -68,6 +75,7 @@ def get_code_optimization_context(
read_only_context_code=read_only_code_markdown.markdown,
helper_functions=helpers_of_fto_obj_list,
preexisting_objects=preexisting_objects,
testgen_context_code=testgen_context_code.code,
)
logger.debug("Code context has exceeded token limit, removing docstrings from read-only code")
@ -90,6 +98,7 @@ def get_code_optimization_context(
read_only_context_code=read_only_code_no_docstring_markdown.markdown,
helper_functions=helpers_of_fto_obj_list,
preexisting_objects=preexisting_objects,
testgen_context_code=testgen_context_code.code,
)
logger.debug("Code context has exceeded token limit, removing read-only code")
@ -99,9 +108,40 @@ def get_code_optimization_context(
read_only_context_code="",
helper_functions=helpers_of_fto_obj_list,
preexisting_objects=preexisting_objects,
testgen_context_code=testgen_context_code.code,
)
def get_all_testgen_context(
    helpers_of_fto: dict[Path, set[str]], helpers_of_fto_fqn: dict[Path, set[str]], project_root_path: Path
) -> CodeString:
    """Build the combined test-generation context across every relevant file.

    For each file containing the function-to-optimize or a first-degree helper,
    extracts the testgen context (target functions plus kept dunders) and
    appends it to one accumulated string, resolving needed imports as it goes.
    Files that fail to read or that contain no target functions are skipped.

    Args:
        helpers_of_fto: Map of file path -> qualified function names to keep.
        helpers_of_fto_fqn: Map of file path -> fully-qualified helper names,
            used for import resolution.
        project_root_path: Root of the project, for import resolution.

    Returns:
        A CodeString holding the concatenated testgen context.
    """
    final_testgen_context = ""
    # Extract code from file paths that contain fto and first degree helpers. helpers of helpers may also be included if they are in the same files
    for file_path, qualified_function_names in helpers_of_fto.items():
        try:
            original_code = file_path.read_text("utf8")
        except Exception as e:
            # Best-effort: unreadable files are logged and skipped.
            logger.exception(f"Error while parsing {file_path}: {e}")
            continue
        try:
            testgen_context_code = get_testgen_context(original_code, qualified_function_names)
        except ValueError as e:
            # Raised when none of the target functions appear in this file.
            logger.debug(f"Error while getting read-only code: {e}")
            continue
        if testgen_context_code:
            final_testgen_context += f"\n{testgen_context_code}"
        # NOTE(review): imports are re-resolved against the WHOLE accumulated
        # context on every iteration, using the current file as the source —
        # confirm this is intended (vs. resolving once per appended chunk, or
        # only when testgen_context_code is non-empty).
        final_testgen_context = add_needed_imports_from_module(
            src_module_code=original_code,
            dst_module_code=final_testgen_context,
            src_path=file_path,
            dst_path=file_path,
            project_root=project_root_path,
            helper_functions_fqn=helpers_of_fto_fqn[file_path],
        )
    return CodeString(code=final_testgen_context)
def get_all_read_writable_code(
helpers_of_fto: dict[Path, set[str]], helpers_of_fto_fqn: dict[Path, set[str]], project_root_path: Path
) -> CodeString:
@ -167,19 +207,18 @@ def get_all_read_only_code_context(
except ValueError as e:
logger.debug(f"Error while getting read-only code: {e}")
continue
read_only_code_with_imports = CodeString(
code=add_needed_imports_from_module(
src_module_code=original_code,
dst_module_code=read_only_code,
src_path=file_path,
dst_path=file_path,
project_root=project_root_path,
helper_functions_fqn=helpers_of_fto_fqn[file_path] | helpers_of_helpers_fqn[file_path],
),
file_path=file_path.relative_to(project_root_path),
)
if read_only_code_with_imports.code:
if read_only_code.strip():
read_only_code_with_imports = CodeString(
code=add_needed_imports_from_module(
src_module_code=original_code,
dst_module_code=read_only_code,
src_path=file_path,
dst_path=file_path,
project_root=project_root_path,
helper_functions_fqn=helpers_of_fto_fqn[file_path] | helpers_of_helpers_fqn[file_path],
),
file_path=file_path.relative_to(project_root_path),
)
read_only_code_markdown.code_strings.append(read_only_code_with_imports)
# Extract code from file paths containing helpers of helpers
@ -197,18 +236,18 @@ def get_all_read_only_code_context(
logger.debug(f"Error while getting read-only code: {e}")
continue
read_only_code_with_imports = CodeString(
code=add_needed_imports_from_module(
src_module_code=original_code,
dst_module_code=read_only_code,
src_path=file_path,
dst_path=file_path,
project_root=project_root_path,
helper_functions_fqn=helpers_of_helpers_no_overlap_fqn[file_path],
),
file_path=file_path.relative_to(project_root_path),
)
if read_only_code_with_imports.code:
if read_only_code.strip():
read_only_code_with_imports = CodeString(
code=add_needed_imports_from_module(
src_module_code=original_code,
dst_module_code=read_only_code,
src_path=file_path,
dst_path=file_path,
project_root=project_root_path,
helper_functions_fqn=helpers_of_helpers_no_overlap_fqn[file_path],
),
file_path=file_path.relative_to(project_root_path),
)
read_only_code_markdown.code_strings.append(read_only_code_with_imports)
return read_only_code_markdown
@ -327,9 +366,10 @@ def prune_cst_for_read_writable_code(
if qualified_name in target_functions:
new_body.append(stmt)
found_target = True
elif stmt.name.value == "__init__":
new_body.append(stmt) # enable __init__ optimizations
# If no target functions found, remove the class entirely
if not new_body:
if not new_body or not found_target:
return None, False
return node.with_changes(body=cst.IndentedBlock(body=new_body)), found_target
@ -408,7 +448,7 @@ def prune_cst_for_read_only_code(
if qualified_name in target_functions:
return None, True
# Keep only dunder methods
if is_dunder_method(node.name.value):
if is_dunder_method(node.name.value) and node.name.value != "__init__":
if remove_docstrings and isinstance(node.body, cst.IndentedBlock):
new_body = remove_docstring_from_body(node.body)
return node.with_changes(body=new_body), False
@ -501,3 +541,109 @@ def get_read_only_code(
if filtered_node and isinstance(filtered_node, cst.Module):
return str(filtered_node.code)
return ""
def prune_cst_for_testgen_context(
    node: cst.CSTNode, target_functions: set[str], prefix: str = "", remove_docstrings: bool = False
) -> tuple[cst.CSTNode | None, bool]:
    """Recursively filter the node for the test-generation context.

    Like the read-only pruning, except target functions themselves are KEPT
    (included verbatim) rather than removed.

    Args:
        node: CST node to filter.
        target_functions: Qualified names (``Class.method`` or ``function``) to keep.
        prefix: Dotted qualified-name prefix of the enclosing scope ("" at module level).
        remove_docstrings: Whether to strip docstrings from kept dunders/classes.

    Returns:
        (filtered_node, found_target):
            filtered_node: The modified CST node or None if it should be removed.
            found_target: True if a target function was found in this node's subtree.
    """
    # Imports are resolved separately, so drop them from the pruned tree.
    if isinstance(node, (cst.Import, cst.ImportFrom)):
        return None, False
    if isinstance(node, cst.FunctionDef):
        qualified_name = f"{prefix}.{node.name.value}" if prefix else node.name.value
        if qualified_name in target_functions:
            # Keep target functions verbatim — this is the key difference from
            # the read-only pruning.
            return node, True
        # Keep only dunder methods
        if is_dunder_method(node.name.value):
            if remove_docstrings and isinstance(node.body, cst.IndentedBlock):
                new_body = remove_docstring_from_body(node.body)
                return node.with_changes(body=new_body), False
            return node, False
        return None, False
    if isinstance(node, cst.ClassDef):
        # Do not recurse into nested classes
        if prefix:
            return None, False
        # Assuming always an IndentedBlock
        if not isinstance(node.body, cst.IndentedBlock):
            raise ValueError("ClassDef body is not an IndentedBlock")
        class_prefix = f"{prefix}.{node.name.value}" if prefix else node.name.value
        # First pass: detect if there is a target function in the class
        found_in_class = False
        new_class_body: list[CSTNode] = []
        for stmt in node.body.body:
            filtered, found_target = prune_cst_for_testgen_context(
                stmt, target_functions, class_prefix, remove_docstrings=remove_docstrings
            )
            found_in_class |= found_target
            if filtered:
                new_class_body.append(filtered)
        # Classes containing no target methods are dropped entirely.
        if not found_in_class:
            return None, False
        if remove_docstrings:
            return node.with_changes(
                body=remove_docstring_from_body(node.body.with_changes(body=new_class_body))
            ) if new_class_body else None, True
        return node.with_changes(body=node.body.with_changes(body=new_class_body)) if new_class_body else None, True
    # For other nodes, keep the node and recursively filter children
    section_names = get_section_names(node)
    if not section_names:
        # Leaf node with no filterable sections: keep as-is.
        return node, False
    updates: dict[str, list[cst.CSTNode] | cst.CSTNode] = {}
    found_any_target = False
    for section in section_names:
        original_content = getattr(node, section, None)
        if isinstance(original_content, (list, tuple)):
            new_children = []
            section_found_target = False
            for child in original_content:
                filtered, found_target = prune_cst_for_testgen_context(
                    child, target_functions, prefix, remove_docstrings=remove_docstrings
                )
                if filtered:
                    new_children.append(filtered)
                section_found_target |= found_target
            # Keep the section if it either contains a target or any surviving child.
            if section_found_target or new_children:
                found_any_target |= section_found_target
                updates[section] = new_children
        elif original_content is not None:
            filtered, found_target = prune_cst_for_testgen_context(
                original_content, target_functions, prefix, remove_docstrings=remove_docstrings
            )
            found_any_target |= found_target
            if filtered:
                updates[section] = filtered
    if updates:
        return (node.with_changes(**updates), found_any_target)
    # Nothing survived in any section: remove this node.
    return None, False
def get_testgen_context(code: str, target_functions: set[str], remove_docstrings: bool = False) -> str:
    """Create the testgen context for *code*.

    Similar to ``get_read_only_code``, except the target functions themselves
    are included in the output.

    Raises:
        ValueError: If none of *target_functions* appears in ``code``.
    """
    module = cst.parse_module(code)
    pruned, found = prune_cst_for_testgen_context(module, target_functions, remove_docstrings=remove_docstrings)
    if not found:
        raise ValueError("No target functions found in the provided code")
    if not (pruned and isinstance(pruned, cst.Module)):
        return ""
    return str(pruned.code)

View file

@ -71,6 +71,78 @@ from codeflash.code_utils.code_utils import path_belongs_to_site_packages
from codeflash.discovery.functions_to_optimize import FunctionParent, FunctionToOptimize
def belongs_to_class(name: Name, class_name: str) -> bool:
"""Check if the given name belongs to the specified class."""
if name.full_name and name.full_name.startswith(f"{name.module_name}.{class_name}."):
return True
return False
def belongs_to_function(name: Name, function_name: str) -> bool:
"""Check if the given name belongs to the specified function"""
if name.full_name and name.full_name.startswith(name.module_name):
subname: str = name.full_name.replace(name.module_name, "", 1)
else:
return False
# The name is defined inside the function or is the function itself
return f".{function_name}." in subname or f".{function_name}" == subname
@dataclass(frozen=True, config={"arbitrary_types_allowed": True})
class Source:
full_name: str
definition: Name
source_code: str
'''
dst_module = '''def belongs_to_function(name: Name, function_name: str) -> bool:
"""Check if the given name belongs to the specified function"""
if name.full_name and name.full_name.startswith(name.module_name):
subname: str = name.full_name.replace(name.module_name, "", 1)
else:
return False
# The name is defined inside the function or is the function itself
return f".{function_name}." in subname or f".{function_name}" == subname
'''
expected = '''from jedi.api.classes import Name
def belongs_to_function(name: Name, function_name: str) -> bool:
"""Check if the given name belongs to the specified function"""
if name.full_name and name.full_name.startswith(name.module_name):
subname: str = name.full_name.replace(name.module_name, "", 1)
else:
return False
# The name is defined inside the function or is the function itself
return f".{function_name}." in subname or f".{function_name}" == subname
'''
src_path = Path("/home/roger/repos/codeflash/cli/codeflash/optimization/function_context.py")
dst_path = Path("/home/roger/repos/codeflash/cli/codeflash/optimization/function_context.py")
project_root = Path("/home/roger/repos/codeflash")
new_module = add_needed_imports_from_module(src_module, dst_module, src_path, dst_path, project_root)
assert new_module == expected
def test_add_needed_imports_from_module_with_if() -> None:
src_module = '''import ast
import logging
import os
from typing import Union
if sys.version_info >= (3, 10):
from typing import TypeGuard
else:
from typing_extensions import TypeGuard
import jedi
import tiktoken
from jedi.api.classes import Name
from pydantic.dataclasses import dataclass
from codeflash.code_utils.code_extractor import get_code, get_code_no_skeleton
from codeflash.code_utils.code_utils import path_belongs_to_site_packages
from codeflash.discovery.functions_to_optimize import FunctionParent, FunctionToOptimize
def belongs_to_class(name: Name, class_name: str) -> bool:
"""Check if the given name belongs to the specified class."""
if name.full_name and name.full_name.startswith(f"{name.module_name}.{class_name}."):

View file

@ -73,38 +73,53 @@ def test_code_replacement10() -> None:
)
code_ctx = get_code_optimization_context(function_to_optimize=func_top_optimize, project_root_path=file_path.parent)
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
read_write_context, read_only_context, testgen_context = (
code_ctx.read_writable_code,
code_ctx.read_only_context_code,
code_ctx.testgen_context_code,
)
expected_read_write_context = """
from __future__ import annotations
class HelperClass:
def __init__(self, name):
self.name = name
def helper_method(self):
return self.name
class MainClass:
def __init__(self, name):
self.name = name
def main_method(self):
return HelperClass(self.name).helper_method()
"""
expected_read_only_context = f"""
```python:{file_path.relative_to(file_path.parent)}
expected_read_only_context = """
"""
expected_testgen_context = """
from __future__ import annotations
class HelperClass:
def __init__(self, name):
self.name = name
def helper_method(self):
return self.name
class MainClass:
def __init__(self, name):
self.name = name
```
def main_method(self):
return HelperClass(self.name).helper_method()
"""
assert read_write_context.strip() == dedent(expected_read_write_context).strip()
assert read_only_context.strip() == dedent(expected_read_only_context).strip()
assert testgen_context.strip() == dedent(expected_testgen_context).strip()
def test_class_method_dependencies() -> None:
@ -122,8 +137,12 @@ def test_class_method_dependencies() -> None:
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
expected_read_write_context = """
from __future__ import annotations
from collections import defaultdict
class Graph:
def __init__(self, vertices):
self.graph = defaultdict(list)
self.V = vertices # No. of vertices
def topologicalSortUtil(self, v, visited, stack):
visited[v] = True
@ -146,17 +165,7 @@ class Graph:
return stack
"""
expected_read_only_context = f"""
```python:{file_path.relative_to(file_path.parent)}
from __future__ import annotations
from collections import defaultdict
class Graph:
def __init__(self, vertices):
self.graph = defaultdict(list)
self.V = vertices # No. of vertices
```
"""
expected_read_only_context = ""
assert read_write_context.strip() == dedent(expected_read_write_context).strip()
assert read_only_context.strip() == dedent(expected_read_only_context).strip()
@ -398,6 +407,8 @@ class _PersistentCache(Generic[_P, _R, _CacheBackendT]):
expected_read_write_context = """
class AbstractCacheBackend(CacheBackend, Protocol[_KEY_T, _STORE_T]):
def __init__(self) -> None: ...
def get_cache_or_call(
self,
*,
@ -455,6 +466,16 @@ class _PersistentCache(Generic[_P, _R, _CacheBackendT]):
class _PersistentCache(Generic[_P, _R, _CacheBackendT]):
def __init__(
self,
func: Callable[_P, _R],
duration: datetime.timedelta,
) -> None:
self.__wrapped__ = func
self.__duration__ = duration
self.__backend__ = AbstractCacheBackend()
functools.update_wrapper(self, func)
def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R:
\"\"\"
Calls the wrapped function, either using the cache or bypassing it based on environment variables.
@ -487,8 +508,6 @@ _STORE_T = TypeVar("_STORE_T")
class AbstractCacheBackend(CacheBackend, Protocol[_KEY_T, _STORE_T]):
"""Interface for cache backends used by the persistent cache decorator."""
def __init__(self) -> None: ...
def hash_key(
self,
*,
@ -535,16 +554,6 @@ class _PersistentCache(Generic[_P, _R, _CacheBackendT]):
__wrapped__: Callable[_P, _R]
__duration__: datetime.timedelta
__backend__: _CacheBackendT
def __init__(
self,
func: Callable[_P, _R],
duration: datetime.timedelta,
) -> None:
self.__wrapped__ = func
self.__duration__ = duration
self.__backend__ = AbstractCacheBackend()
functools.update_wrapper(self, func)
```
'''
assert read_write_context.strip() == dedent(expected_read_write_context).strip()
@ -598,10 +607,15 @@ class HelperClass:
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
expected_read_write_context = """
class MyClass:
def __init__(self):
self.x = 1
def target_method(self):
y = HelperClass().helper_method()
class HelperClass:
def __init__(self):
\"\"\"Initialize the HelperClass.\"\"\"
self.x = 1
def helper_method(self):
return self.x
"""
@ -609,14 +623,9 @@ class HelperClass:
```python:{file_path.relative_to(opt.args.project_root)}
class MyClass:
\"\"\"A class with a helper method.\"\"\"
def __init__(self):
self.x = 1
class HelperClass:
\"\"\"A helper class for MyClass.\"\"\"
def __init__(self):
\"\"\"Initialize the HelperClass.\"\"\"
self.x = 1
def __repr__(self):
\"\"\"Return a string representation of the HelperClass.\"\"\"
return "HelperClass" + str(self.x)
@ -679,23 +688,25 @@ class HelperClass:
# In this scenario, the read-only code context is too long, so the read-only docstrings are removed.
expected_read_write_context = """
class MyClass:
def __init__(self):
self.x = 1
def target_method(self):
\"\"\"Docstring for target method\"\"\"
y = HelperClass().helper_method()
class HelperClass:
def __init__(self):
\"\"\"Initialize the HelperClass.\"\"\"
self.x = 1
def helper_method(self):
return self.x
"""
expected_read_only_context = f"""
```python:{file_path.relative_to(opt.args.project_root)}
class MyClass:
def __init__(self):
self.x = 1
pass
class HelperClass:
def __init__(self):
self.x = 1
def __repr__(self):
return "HelperClass" + str(self.x)
```
@ -756,15 +767,20 @@ class HelperClass:
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
# In this scenario, the read-only code context is too long even after removing docstrings, hence we remove it completely.
expected_read_write_context = """
class MyClass:
def target_method(self):
\"\"\"Docstring for target method\"\"\"
y = HelperClass().helper_method()
class MyClass:
def __init__(self):
self.x = 1
def target_method(self):
\"\"\"Docstring for target method\"\"\"
y = HelperClass().helper_method()
class HelperClass:
def helper_method(self):
return self.x
"""
class HelperClass:
def __init__(self):
\"\"\"Initialize the HelperClass.\"\"\"
self.x = 1
def helper_method(self):
return self.x
"""
expected_read_only_context = ""
assert read_write_context.strip() == dedent(expected_read_write_context).strip()
assert read_only_context.strip() == dedent(expected_read_only_context).strip()
@ -836,12 +852,18 @@ def test_repo_helper() -> None:
code_ctx = get_code_optimization_context(function_to_optimize, project_root)
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
expected_read_write_context = """
import requests
import math
import requests
from globals import API_URL
from utils import DataProcessor
class DataProcessor:
def __init__(self, default_prefix: str = "PREFIX_"):
\"\"\"Initialize the DataProcessor with a default prefix.\"\"\"
self.default_prefix = default_prefix
self.number += math.log(self.number)
def process_data(self, raw_data: str) -> str:
\"\"\"Process raw data by converting it to uppercase.\"\"\"
return raw_data.upper()
@ -869,8 +891,6 @@ def fetch_and_process_data():
"""
expected_read_only_context = f"""
```python:{path_to_utils.relative_to(project_root)}
import math
GLOBAL_VAR = 10
@ -879,11 +899,6 @@ class DataProcessor:
number = 1
def __init__(self, default_prefix: str = "PREFIX_"):
\"\"\"Initialize the DataProcessor with a default prefix.\"\"\"
self.default_prefix = default_prefix
self.number += math.log(self.number)
def __repr__(self) -> str:
\"\"\"Return a string representation of the DataProcessor.\"\"\"
return f"DataProcessor(default_prefix={{self.default_prefix!r}})"
@ -914,6 +929,7 @@ def test_repo_helper_of_helper() -> None:
code_ctx = get_code_optimization_context(function_to_optimize, project_root)
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
expected_read_write_context = """
import math
from transform_utils import DataTransformer
import requests
from globals import API_URL
@ -921,6 +937,11 @@ from utils import DataProcessor
class DataProcessor:
def __init__(self, default_prefix: str = "PREFIX_"):
\"\"\"Initialize the DataProcessor with a default prefix.\"\"\"
self.default_prefix = default_prefix
self.number += math.log(self.number)
def process_data(self, raw_data: str) -> str:
\"\"\"Process raw data by converting it to uppercase.\"\"\"
return raw_data.upper()
@ -947,8 +968,6 @@ def fetch_and_transform_data():
"""
expected_read_only_context = f"""
```python:{path_to_utils.relative_to(project_root)}
import math
GLOBAL_VAR = 10
@ -957,11 +976,6 @@ class DataProcessor:
number = 1
def __init__(self, default_prefix: str = "PREFIX_"):
\"\"\"Initialize the DataProcessor with a default prefix.\"\"\"
self.default_prefix = default_prefix
self.number += math.log(self.number)
def __repr__(self) -> str:
\"\"\"Return a string representation of the DataProcessor.\"\"\"
return f"DataProcessor(default_prefix={{self.default_prefix!r}})"
@ -973,8 +987,6 @@ if __name__ == "__main__":
```
```python:{path_to_transform_utils.relative_to(project_root)}
class DataTransformer:
def __init__(self):
self.data = None
def transform(self, data):
self.data = data
@ -1001,9 +1013,12 @@ def test_repo_helper_of_helper_same_class() -> None:
code_ctx = get_code_optimization_context(function_to_optimize, project_root)
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
expected_read_write_context = """
import math
from transform_utils import DataTransformer
class DataTransformer:
def __init__(self):
self.data = None
def transform_using_own_method(self, data):
return self.transform(data)
@ -1012,6 +1027,11 @@ class DataTransformer:
class DataProcessor:
def __init__(self, default_prefix: str = "PREFIX_"):
\"\"\"Initialize the DataProcessor with a default prefix.\"\"\"
self.default_prefix = default_prefix
self.number += math.log(self.number)
def transform_data_own_method(self, data: str) -> str:
\"\"\"Transform the processed data using own method\"\"\"
return DataTransformer().transform_using_own_method(data)
@ -1020,16 +1040,12 @@ class DataProcessor:
expected_read_only_context = f"""
```python:{path_to_transform_utils.relative_to(project_root)}
class DataTransformer:
def __init__(self):
self.data = None
def transform(self, data):
self.data = data
return self.data
```
```python:{path_to_utils.relative_to(project_root)}
import math
GLOBAL_VAR = 10
@ -1038,11 +1054,6 @@ class DataProcessor:
number = 1
def __init__(self, default_prefix: str = "PREFIX_"):
\"\"\"Initialize the DataProcessor with a default prefix.\"\"\"
self.default_prefix = default_prefix
self.number += math.log(self.number)
def __repr__(self) -> str:
\"\"\"Return a string representation of the DataProcessor.\"\"\"
return f"DataProcessor(default_prefix={{self.default_prefix!r}})"
@ -1069,9 +1080,12 @@ def test_repo_helper_of_helper_same_file() -> None:
code_ctx = get_code_optimization_context(function_to_optimize, project_root)
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
expected_read_write_context = """
import math
from transform_utils import DataTransformer
class DataTransformer:
def __init__(self):
self.data = None
def transform_using_same_file_function(self, data):
return update_data(data)
@ -1080,23 +1094,21 @@ class DataTransformer:
class DataProcessor:
def __init__(self, default_prefix: str = "PREFIX_"):
\"\"\"Initialize the DataProcessor with a default prefix.\"\"\"
self.default_prefix = default_prefix
self.number += math.log(self.number)
def transform_data_same_file_function(self, data: str) -> str:
\"\"\"Transform the processed data using a function from the same file\"\"\"
return DataTransformer().transform_using_same_file_function(data)
"""
expected_read_only_context = f"""
```python:{path_to_transform_utils.relative_to(project_root)}
class DataTransformer:
def __init__(self):
self.data = None
def update_data(data):
return data + " updated"
```
```python:{path_to_utils.relative_to(project_root)}
import math
GLOBAL_VAR = 10
@ -1105,11 +1117,6 @@ class DataProcessor:
number = 1
def __init__(self, default_prefix: str = "PREFIX_"):
\"\"\"Initialize the DataProcessor with a default prefix.\"\"\"
self.default_prefix = default_prefix
self.number += math.log(self.number)
def __repr__(self) -> str:
\"\"\"Return a string representation of the DataProcessor.\"\"\"
return f"DataProcessor(default_prefix={{self.default_prefix!r}})"
@ -1135,6 +1142,8 @@ def test_repo_helper_all_same_file() -> None:
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
expected_read_write_context = """
class DataTransformer:
def __init__(self):
self.data = None
def transform_using_own_method(self, data):
return self.transform(data)
@ -1150,9 +1159,7 @@ def update_data(data):
expected_read_only_context = f"""
```python:{path_to_transform_utils.relative_to(project_root)}
class DataTransformer:
def __init__(self):
self.data = None
def transform(self, data):
self.data = data
return self.data
@ -1179,11 +1186,17 @@ def test_repo_helper_circular_dependency() -> None:
code_ctx = get_code_optimization_context(function_to_optimize, project_root)
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
expected_read_write_context = """
import math
from transform_utils import DataTransformer
from code_to_optimize.code_directories.retriever.utils import DataProcessor
class DataProcessor:
def __init__(self, default_prefix: str = "PREFIX_"):
\"\"\"Initialize the DataProcessor with a default prefix.\"\"\"
self.default_prefix = default_prefix
self.number += math.log(self.number)
def circular_dependency(self, data: str) -> str:
\"\"\"Test circular dependency\"\"\"
return DataTransformer().circular_dependency(data)
@ -1191,6 +1204,8 @@ class DataProcessor:
class DataTransformer:
def __init__(self):
self.data = None
def circular_dependency(self, data):
return DataProcessor().circular_dependency(data)
@ -1199,8 +1214,6 @@ class DataTransformer:
"""
expected_read_only_context = f"""
```python:{path_to_utils.relative_to(project_root)}
import math
GLOBAL_VAR = 10
@ -1209,20 +1222,10 @@ class DataProcessor:
number = 1
def __init__(self, default_prefix: str = "PREFIX_"):
\"\"\"Initialize the DataProcessor with a default prefix.\"\"\"
self.default_prefix = default_prefix
self.number += math.log(self.number)
def __repr__(self) -> str:
\"\"\"Return a string representation of the DataProcessor.\"\"\"
return f"DataProcessor(default_prefix={{self.default_prefix!r}})"
```
```python:{path_to_transform_utils.relative_to(project_root)}
class DataTransformer:
def __init__(self):
self.data = None
```
"""

View file

@ -894,6 +894,7 @@ def test_test_libcst_code_replacement13() -> None:
original_code = """class NewClass:
def __init__(self, name):
self.name = name
self.new_attribute = "Sorry i modified a dunder method"
def new_function(self, value):
return other_function(self.name)
def new_function2(value):

View file

@ -3,13 +3,13 @@ import site
from pathlib import Path
import pytest
from codeflash.code_utils.code_utils import (
cleanup_paths,
file_name_from_test_module_name,
file_path_from_module_name,
get_all_function_names,
get_imports_from_file,
get_only_code_content,
get_qualified_name,
get_run_tmp_file,
is_class_defined_in_file,
@ -26,6 +26,7 @@ def multiple_existing_and_non_existing_files(tmp_path: Path) -> list[Path]:
file.touch()
return existing_files + non_existing_files
def test_get_qualified_name_valid() -> None:
module_name = "codeflash"
full_qualified_name = "codeflash.utils.module"
@ -47,6 +48,7 @@ def test_get_qualified_name_same_name() -> None:
with pytest.raises(ValueError, match="is the same as codeflash"):
get_qualified_name(module_name, full_qualified_name)
# tests for module_name_from_file_path
def test_module_name_from_file_path() -> None:
project_root_path = Path("/Users/codeflashuser/PycharmProjects/codeflash")
@ -91,6 +93,8 @@ def test_get_imports_from_file_with_file_path(tmp_path: Path) -> None:
assert imports[0].names[0].name == "os"
assert imports[1].module == "sys"
assert imports[1].names[0].name == "path"
def test_get_imports_from_file_with_file_string() -> None:
file_string = "import os\nfrom sys import path\n"
@ -102,6 +106,7 @@ def test_get_imports_from_file_with_file_string() -> None:
assert imports[1].module == "sys"
assert imports[1].names[0].name == "path"
def test_get_imports_from_file_with_file_ast() -> None:
file_string = "import os\nfrom sys import path\n"
file_ast = ast.parse(file_string)
@ -114,6 +119,7 @@ def test_get_imports_from_file_with_file_ast() -> None:
assert imports[1].module == "sys"
assert imports[1].names[0].name == "path"
def test_get_imports_from_file_with_syntax_error(caplog: pytest.LogCaptureFixture) -> None:
file_string = "import os\nfrom sys import path\ninvalid syntax"
@ -173,6 +179,7 @@ async def bar():
assert success is True
assert function_names == ["foo", "bar"]
def test_get_all_function_names_with_syntax_error(caplog: pytest.LogCaptureFixture) -> None:
code = """
def foo():
@ -234,6 +241,7 @@ def test_get_run_tmp_file_reuses_temp_directory() -> None:
assert tmp_file_path1.parent.name.startswith("codeflash_")
assert tmp_file_path1.parent.exists()
def test_path_belongs_to_site_packages_with_site_package_path(monkeypatch: pytest.MonkeyPatch) -> None:
site_packages = [Path("/usr/local/lib/python3.9/site-packages")]
monkeypatch.setattr(site, "getsitepackages", lambda: site_packages)
@ -241,6 +249,7 @@ def test_path_belongs_to_site_packages_with_site_package_path(monkeypatch: pytes
file_path = Path("/usr/local/lib/python3.9/site-packages/some_package")
assert path_belongs_to_site_packages(file_path) is True
def test_path_belongs_to_site_packages_with_non_site_package_path(monkeypatch: pytest.MonkeyPatch) -> None:
site_packages = [Path("/usr/local/lib/python3.9/site-packages")]
monkeypatch.setattr(site, "getsitepackages", lambda: site_packages)
@ -248,6 +257,7 @@ def test_path_belongs_to_site_packages_with_non_site_package_path(monkeypatch: p
file_path = Path("/usr/local/lib/python3.9/other_directory/some_package")
assert path_belongs_to_site_packages(file_path) is False
def test_path_belongs_to_site_packages_with_relative_path(monkeypatch: pytest.MonkeyPatch) -> None:
site_packages = [Path("/usr/local/lib/python3.9/site-packages")]
monkeypatch.setattr(site, "getsitepackages", lambda: site_packages)
@ -332,3 +342,51 @@ def test_cleanup_paths(multiple_existing_and_non_existing_files: list[Path]) ->
cleanup_paths(multiple_existing_and_non_existing_files)
for file in multiple_existing_and_non_existing_files:
assert not file.exists()
def test_get_only_code_content_with_docstring() -> None:
    """A function whose body starts with a docstring has the docstring stripped."""
    source = '''def foo():
    """This is a docstring."""
    return 42'''
    stripped = """def foo():
    return 42"""
    assert get_only_code_content(source) == stripped
def test_get_only_code_content_with_comments() -> None:
    """Standalone and trailing comments are both removed from the body."""
    source = """def foo():
    # This is a comment
    return 42 # Another comment"""
    stripped = """def foo():
    return 42"""
    assert get_only_code_content(source) == stripped
def test_get_only_code_content_with_docstring_and_comments() -> None:
    """Docstring and comments are stripped together, leaving only executable code."""
    source = '''def foo():
    """This is a docstring."""
    # This is a comment
    return 42 # Another comment'''
    stripped = """def foo():
    return 42"""
    assert get_only_code_content(source) == stripped
def test_get_only_code_content_nested_functions() -> None:
    """Docstrings are stripped at every nesting level, not just the top one."""
    source = '''def outer():
    """Outer docstring."""
    def inner():
        """Inner docstring."""
        return 42
    return inner()'''
    stripped = """def outer():
    def inner():
        return 42
    return inner()"""
    assert get_only_code_content(source) == stripped

View file

@ -40,8 +40,6 @@ def test_dunder_methods() -> None:
expected = """
class TestClass:
def __init__(self):
self.x = 42
def __str__(self):
return f"Value: {self.x}"
@ -68,9 +66,7 @@ def test_dunder_methods_remove_docstring() -> None:
expected = """
class TestClass:
def __init__(self):
self.x = 42
def __str__(self):
return f"Value: {self.x}"
"""
@ -95,9 +91,7 @@ def test_class_remove_docstring() -> None:
expected = """
class TestClass:
def __init__(self):
self.x = 42
def __str__(self):
return f"Value: {self.x}"
"""
@ -124,9 +118,7 @@ def test_mixed_remove_docstring() -> None:
expected = """
class TestClass:
def __init__(self):
self.x = 42
def __str__(self):
return f"Value: {self.x}"
"""
@ -208,10 +200,6 @@ def test_multiple_top_level_targets() -> None:
"""
expected = """
class TestClass:
def __init__(self):
self.x = 42
"""
output = get_read_only_code(dedent(code), {"TestClass.target1", "TestClass.target2"}, set())
@ -659,11 +647,6 @@ def test_simplified_complete_implementation() -> None:
class DataProcessor:
\"\"\"A simple data processing class.\"\"\"
def __init__(self, data: Dict[str, Any]) -> None:
self.data = data
self._processed = False
self.result = None
def __repr__(self) -> str:
return f"DataProcessor(processed={self._processed})"
@ -672,17 +655,12 @@ def test_simplified_complete_implementation() -> None:
processor = DataProcessor(sample_data)
class ResultHandler:
def __init__(self, processor: DataProcessor):
self.processor = processor
self.cache = {}
def __str__(self) -> str:
return f"ResultHandler(cache_size={len(self.cache)})"
except Exception as e:
class ResultHandler:
def __init__(self):
self.error = str(e)
pass
"""
output = get_read_only_code(dedent(code), {"DataProcessor.target_method", "ResultHandler.target_method"}, set())
@ -693,12 +671,6 @@ def test_simplified_complete_implementation_no_docstring() -> None:
code = """
class DataProcessor:
\"\"\"A simple data processing class.\"\"\"
def __init__(self, data: Dict[str, Any]) -> None:
self.data = data
self._processed = False
self.result = None
def __repr__(self) -> str:
return f"DataProcessor(processed={self._processed})"
@ -732,9 +704,6 @@ def test_simplified_complete_implementation_no_docstring() -> None:
processor = DataProcessor(sample_data)
class ResultHandler:
def __init__(self, processor: DataProcessor):
self.processor = processor
self.cache = {}
def __str__(self) -> str:
return f"ResultHandler(cache_size={len(self.cache)})"
@ -758,8 +727,6 @@ def test_simplified_complete_implementation_no_docstring() -> None:
except Exception as e:
class ResultHandler:
def __init__(self):
self.error = str(e)
def target_method(self, key: str) -> None:
raise RuntimeError(f"Failed to initialize: {self.error}")
@ -767,12 +734,6 @@ def test_simplified_complete_implementation_no_docstring() -> None:
expected = """
class DataProcessor:
def __init__(self, data: Dict[str, Any]) -> None:
self.data = data
self._processed = False
self.result = None
def __repr__(self) -> str:
return f"DataProcessor(processed={self._processed})"
@ -781,17 +742,12 @@ def test_simplified_complete_implementation_no_docstring() -> None:
processor = DataProcessor(sample_data)
class ResultHandler:
def __init__(self, processor: DataProcessor):
self.processor = processor
self.cache = {}
def __str__(self) -> str:
return f"ResultHandler(cache_size={len(self.cache)})"
except Exception as e:
class ResultHandler:
def __init__(self):
self.error = str(e)
pass
"""
output = get_read_only_code(

View file

@ -161,7 +161,7 @@ def test_try_except_structure() -> None:
assert result.strip() == expected.strip()
def test_dunder_method() -> None:
def test_init_method() -> None:
code = """
class MyClass:
def __init__(self):
@ -175,6 +175,30 @@ def test_dunder_method() -> None:
"""
result = get_read_writable_code(dedent(code), {"MyClass.target_method"})
expected = dedent("""
class MyClass:
def __init__(self):
self.x = 1
def target_method(self):
return f"Value: {self.x}"
""")
assert result.strip() == expected.strip()
def test_dunder_method() -> None:
code = """
class MyClass:
def __repr__(self):
return "MyClass"
def other_method(self):
return "other"
def target_method(self):
return f"Value: {self.x}"
"""
result = get_read_writable_code(dedent(code), {"MyClass.target_method"})
expected = dedent("""
class MyClass:
@ -183,7 +207,6 @@ def test_dunder_method() -> None:
""")
assert result.strip() == expected.strip()
def test_no_targets_found() -> None:
code = """
class MyClass:

View file

@ -0,0 +1,235 @@
from textwrap import dedent
import libcst as cst
from codeflash.code_utils.code_replacer import merge_init_functions, replace_functions_in_file
from codeflash.models.models import FunctionParent
def test_basic_merge() -> None:
    """Disjoint assignments from both __init__ bodies are concatenated in order."""

    def first_method(src: str) -> cst.FunctionDef:
        # module -> ClassDef -> first statement of the class body (the __init__).
        return cst.parse_module(dedent(src)).body[0].body.body[0]

    original = """
class MyClass:
    def __init__(self, a, b):
        self.a = a
        self.b = b
"""
    new = """
class MyClass:
    def __init__(self, x, y):
        self.x = x
        self.y = y
"""
    merged = merge_init_functions(first_method(original), first_method(new))
    expected = """
def __init__(self, a, b):
    self.a = a
    self.b = b
    self.x = x
    self.y = y
"""
    assert cst.Module([merged]).code.strip() == dedent(expected).strip()
def test_prevent_duplication() -> None:
    """Statements already present in the original __init__ are not repeated."""

    def first_method(src: str) -> cst.FunctionDef:
        # module -> ClassDef -> first statement of the class body (the __init__).
        return cst.parse_module(dedent(src)).body[0].body.body[0]

    original = """
class MyClass:
    def __init__(self):
        self.x = 1
        print("init")
        self.setup()
"""
    new = """
class MyClass:
    def __init__(self):
        print("init")
        self.setup()
        self.y = 2
"""
    merged = merge_init_functions(first_method(original), first_method(new))
    # Only the genuinely new statement (self.y = 2) should be appended.
    expected = """
def __init__(self):
    self.x = 1
    print("init")
    self.setup()
    self.y = 2
"""
    assert cst.Module([merged]).code.strip() == dedent(expected).strip()
def test_prevent_overwrite() -> None:
    """An attribute already set in the original __init__ keeps its original value."""

    def first_method(src: str) -> cst.FunctionDef:
        # module -> ClassDef -> first statement of the class body (the __init__).
        return cst.parse_module(dedent(src)).body[0].body.body[0]

    original = """
class MyClass:
    def __init__(self):
        self.x = 1
        self.y = 2
"""
    new = """
class MyClass:
    def __init__(self):
        self.x = 2
"""
    merged = merge_init_functions(first_method(original), first_method(new))
    # self.x = 2 from the new __init__ must NOT clobber the original self.x = 1.
    expected = """
def __init__(self):
    self.x = 1
    self.y = 2
"""
    assert cst.Module([merged]).code.strip() == dedent(expected).strip()
def test_complex_control_flow() -> None:
    """Compound statements (with/if/try) from both __init__s survive the merge intact."""

    def first_method(src: str) -> cst.FunctionDef:
        # module -> ClassDef -> first statement of the class body (the __init__).
        return cst.parse_module(dedent(src)).body[0].body.body[0]

    original = """
class MyClass:
    def __init__(self):
        with self.lock:
            self.setup()
        if self.debug:
            self.enable_logging()
"""
    new = """
class MyClass:
    def __init__(self):
        try:
            self.connect()
        except ConnectionError:
            self.fallback()
"""
    merged = merge_init_functions(first_method(original), first_method(new))
    expected = """
def __init__(self):
    with self.lock:
        self.setup()
    if self.debug:
        self.enable_logging()
    try:
        self.connect()
    except ConnectionError:
        self.fallback()
"""
    assert cst.Module([merged]).code.strip() == dedent(expected).strip()
def test_docstrings_and_comments() -> None:
    """Docstrings and comments from both __init__s are currently kept verbatim."""

    def first_method(src: str) -> cst.FunctionDef:
        # module -> ClassDef -> first statement of the class body (the __init__).
        return cst.parse_module(dedent(src)).body[0].body.body[0]

    original = """
class MyClass:
    def __init__(self):
        \"\"\"Original docstring.\"\"\"
        # Setup configuration
        self.config = {} # Empty config
"""
    new = """
class MyClass:
    def __init__(self):
        \"\"\"New docstring.\"\"\"
        # Initialize database
        self.db = None # Database connection
"""
    merged = merge_init_functions(first_method(original), first_method(new))
    # TODO: handle docstrings differently — the new docstring currently ends up
    # as a stray string statement in the merged body.
    expected = """
def __init__(self):
    \"\"\"Original docstring.\"\"\"
    # Setup configuration
    self.config = {} # Empty config
    \"\"\"New docstring.\"\"\"
    # Initialize database
    self.db = None # Database connection
"""
    assert cst.Module([merged]).code.strip() == dedent(expected).strip()
def test_type_annotations() -> None:
    """Annotated assignments merge like plain ones; the original wins on conflicts."""

    def first_method(src: str) -> cst.FunctionDef:
        # module -> ClassDef -> first statement of the class body (the __init__).
        return cst.parse_module(dedent(src)).body[0].body.body[0]

    original = """
class MyClass:
    def __init__(self) -> None:
        self.x: int = 1
        self.y: str = "hello"
"""
    new = """
class MyClass:
    def __init__(self):
        self.y: str = "new hello"
        self.z: float = 2.0
"""
    merged = merge_init_functions(first_method(original), first_method(new))
    # self.y keeps its original value; only self.z is new. The original's
    # return annotation (-> None) is preserved on the merged signature.
    expected = """
def __init__(self) -> None:
    self.x: int = 1
    self.y: str = "hello"
    self.z: float = 2.0
"""
    assert cst.Module([merged]).code.strip() == dedent(expected).strip()
# Tests for code replacement with init
def test_merge_init_methods() -> None:
    """When __init__ is not a target function, replacement merges rather than overwrites it."""
    optim_code = """class MyClass:
    def __init__(self):
        self.y = 2
        self.z = 3
"""
    original_code = """class MyClass:
    def __init__(self):
        self.y = 1
        self.setup()
"""
    # self.y stays at its original value; only self.z = 3 is appended.
    expected = """class MyClass:
    def __init__(self):
        self.y = 1
        self.setup()
        self.z = 3
"""
    merged = replace_functions_in_file(
        source_code=original_code,
        original_function_names=[],
        optimized_code=optim_code,
        preexisting_objects=[("__init__", [FunctionParent(name="MyClass", type="ClassDef")])],
    )
    assert merged == expected
def test_init_is_function_to_optimize() -> None:
    """When __init__ itself is the optimization target, it is replaced wholesale."""
    optim_code = """class MyClass:
    def __init__(self):
        self.y = 2
        self.z = 3
"""
    original_code = """class MyClass:
    def __init__(self):
        self.y = 1
        self.setup()
"""
    # The optimized body wins outright here: mutation safety is left to the
    # usual function-to-optimize behaviour check rather than to merging.
    expected = """class MyClass:
    def __init__(self):
        self.y = 2
        self.z = 3
"""
    replaced = replace_functions_in_file(
        source_code=original_code,
        original_function_names=["MyClass.__init__"],
        optimized_code=optim_code,
        preexisting_objects=[("__init__", [FunctionParent(name="MyClass", type="ClassDef")])],
    )
    assert replaced == expected