mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
First part of a refactor. cst_context and retriever implements logic to retrieve helper functions and build up the necessary CST trees/code as read_write and read_only context. In this commit is also a bunch of test files.
This commit is contained in:
parent
91ecf113aa
commit
8b3ce0e9b9
17 changed files with 2380 additions and 11 deletions
0
code_to_optimize/code_directories/retriever/__init__.py
Normal file
0
code_to_optimize/code_directories/retriever/__init__.py
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
from bubble_sort_with_math import sorter
|
||||
|
||||
|
||||
def sort_from_another_file(arr):
|
||||
sorted_arr = sorter(arr)
|
||||
return sorted_arr
|
||||
|
|
@ -0,0 +1,8 @@
|
|||
import math
|
||||
|
||||
|
||||
def sorter(arr):
|
||||
arr.sort()
|
||||
x = math.sqrt(2)
|
||||
print(x)
|
||||
return arr
|
||||
2
code_to_optimize/code_directories/retriever/globals.py
Normal file
2
code_to_optimize/code_directories/retriever/globals.py
Normal file
|
|
@ -0,0 +1,2 @@
|
|||
# Define a global variable
|
||||
API_URL = "https://api.example.com/data"
|
||||
23
code_to_optimize/code_directories/retriever/main.py
Normal file
23
code_to_optimize/code_directories/retriever/main.py
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
import requests # Third-party library
|
||||
from globals import API_URL # Global variable defined in another file
|
||||
from utils import DataProcessor
|
||||
|
||||
|
||||
def fetch_and_process_data():
|
||||
# Use the global variable for the request
|
||||
response = requests.get(API_URL)
|
||||
response.raise_for_status()
|
||||
|
||||
raw_data = response.text
|
||||
|
||||
# Use code from another file (utils.py)
|
||||
processor = DataProcessor()
|
||||
processed = processor.process_data(raw_data)
|
||||
processed = processor.add_prefix(processed)
|
||||
|
||||
return processed
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
result = fetch_and_process_data()
|
||||
print("Processed data:", result)
|
||||
27
code_to_optimize/code_directories/retriever/utils.py
Normal file
27
code_to_optimize/code_directories/retriever/utils.py
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
import math
|
||||
|
||||
|
||||
class DataProcessor:
|
||||
"""A class for processing data."""
|
||||
|
||||
number = 1
|
||||
|
||||
def __init__(self, default_prefix: str = "PREFIX_"):
|
||||
"""Initialize the DataProcessor with a default prefix."""
|
||||
self.default_prefix = default_prefix
|
||||
self.number += math.log(self.number)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""Return a string representation of the DataProcessor."""
|
||||
return f"DataProcessor(default_prefix={self.default_prefix!r})"
|
||||
|
||||
def process_data(self, raw_data: str) -> str:
|
||||
"""Process raw data by converting it to uppercase."""
|
||||
return raw_data.upper()
|
||||
|
||||
def add_prefix(self, data: str, prefix: str = "PREFIX_") -> str:
|
||||
"""Add a prefix to the processed data."""
|
||||
return prefix + data
|
||||
|
||||
def do_something(self):
|
||||
print("something")
|
||||
|
|
@ -74,6 +74,7 @@ class AiServiceClient:
|
|||
def optimize_python_code(
|
||||
self,
|
||||
source_code: str,
|
||||
dependency_code: str,
|
||||
trace_id: str,
|
||||
num_candidates: int = 10,
|
||||
experiment_metadata: ExperimentMetadata | None = None,
|
||||
|
|
@ -83,15 +84,20 @@ class AiServiceClient:
|
|||
Parameters
|
||||
----------
|
||||
- source_code (str): The python code to optimize.
|
||||
- num_variants (int): Number of optimization variants to generate. Default is 10.
|
||||
- read_write_context (str) : The python code to be used as read-write context.
|
||||
- read_only_context (str): The python code to be used as read-only context.
|
||||
- trace_id (str): Trace id of optimization run
|
||||
- num_candidates (int): Number of optimization variants to generate. Default is 10.
|
||||
- experiment_metadata (Optional[ExperimentalMetadata, None]): Any available experiment metadata for this optimization
|
||||
|
||||
Returns
|
||||
-------
|
||||
- List[Optimization]: A list of Optimization objects.
|
||||
- List[OptimizationCandidate]: A list of Optimization Candidates.
|
||||
|
||||
"""
|
||||
payload = {
|
||||
"source_code": source_code,
|
||||
"dependency_code": dependency_code,
|
||||
"num_variants": num_candidates,
|
||||
"trace_id": trace_id,
|
||||
"python_version": platform.python_version(),
|
||||
|
|
|
|||
|
|
@ -104,6 +104,72 @@ def add_needed_imports_from_module(
|
|||
return dst_module_code
|
||||
|
||||
|
||||
def add_needed_imports_from_module_2(
|
||||
src_module_code: str,
|
||||
dst_module_code: str,
|
||||
src_path: Path,
|
||||
dst_path: Path,
|
||||
project_root: Path,
|
||||
helper_functions_fully_qualified_names: list[str] | None = None,
|
||||
) -> str:
|
||||
"""Copy of add_needed_imports_from_module. will remove in a future refactor. This function simply changes the 'helper_functions' argument"""
|
||||
src_module_code = delete___future___aliased_imports(src_module_code)
|
||||
if helper_functions_fully_qualified_names is None:
|
||||
helper_functions_fully_qualified_names = []
|
||||
|
||||
src_module_and_package: ModuleNameAndPackage = calculate_module_and_package(project_root, src_path)
|
||||
dst_module_and_package: ModuleNameAndPackage = calculate_module_and_package(project_root, dst_path)
|
||||
|
||||
dst_context: CodemodContext = CodemodContext(
|
||||
filename=src_path.name,
|
||||
full_module_name=dst_module_and_package.name,
|
||||
full_package_name=dst_module_and_package.package,
|
||||
)
|
||||
gatherer: GatherImportsVisitor = GatherImportsVisitor(
|
||||
CodemodContext(
|
||||
filename=src_path.name,
|
||||
full_module_name=src_module_and_package.name,
|
||||
full_package_name=src_module_and_package.package,
|
||||
)
|
||||
)
|
||||
cst.parse_module(src_module_code).visit(gatherer)
|
||||
try:
|
||||
for mod in gatherer.module_imports:
|
||||
AddImportsVisitor.add_needed_import(dst_context, mod)
|
||||
RemoveImportsVisitor.remove_unused_import(dst_context, mod)
|
||||
for mod, obj_seq in gatherer.object_mapping.items():
|
||||
for obj in obj_seq:
|
||||
if f"{mod}.{obj}" in helper_functions_fully_qualified_names:
|
||||
continue # Skip adding imports for helper functions already in the context
|
||||
AddImportsVisitor.add_needed_import(dst_context, mod, obj)
|
||||
RemoveImportsVisitor.remove_unused_import(dst_context, mod, obj)
|
||||
except Exception as e:
|
||||
logger.exception(f"Error adding imports to destination module code: {e}")
|
||||
return dst_module_code
|
||||
for mod, asname in gatherer.module_aliases.items():
|
||||
AddImportsVisitor.add_needed_import(dst_context, mod, asname=asname)
|
||||
RemoveImportsVisitor.remove_unused_import(dst_context, mod, asname=asname)
|
||||
for mod, alias_pairs in gatherer.alias_mapping.items():
|
||||
for alias_pair in alias_pairs:
|
||||
if f"{mod}.{alias_pair[0]}" in helper_functions_fully_qualified_names:
|
||||
continue
|
||||
AddImportsVisitor.add_needed_import(dst_context, mod, alias_pair[0], asname=alias_pair[1])
|
||||
RemoveImportsVisitor.remove_unused_import(dst_context, mod, alias_pair[0], asname=alias_pair[1])
|
||||
|
||||
try:
|
||||
parsed_module = cst.parse_module(dst_module_code)
|
||||
except cst.ParserSyntaxError as e:
|
||||
logger.exception(f"Syntax error in destination module code: {e}")
|
||||
return dst_module_code # Return the original code if there's a syntax error
|
||||
try:
|
||||
transformed_module = AddImportsVisitor(dst_context).transform_module(parsed_module)
|
||||
transformed_module = RemoveImportsVisitor(dst_context).transform_module(transformed_module)
|
||||
return transformed_module.code.lstrip("\n")
|
||||
except Exception as e:
|
||||
logger.exception(f"Error adding imports to destination module code: {e}")
|
||||
return dst_module_code
|
||||
|
||||
|
||||
def get_code(functions_to_optimize: list[FunctionToOptimize]) -> tuple[str | None, set[tuple[str, str]]]:
|
||||
"""Return the code for a function or methods in a Python module. functions_to_optimize is either a singleton
|
||||
FunctionToOptimize instance, which represents either a function at the module level or a method of a class at the
|
||||
|
|
|
|||
|
|
@ -54,6 +54,8 @@ class BestOptimization(BaseModel):
|
|||
|
||||
class CodeOptimizationContext(BaseModel):
|
||||
code_to_optimize_with_helpers: str
|
||||
read_write_context_code: str = ""
|
||||
read_only_context_code: str = ""
|
||||
contextual_dunder_methods: set[tuple[str, str]]
|
||||
helper_functions: list[FunctionSource]
|
||||
preexisting_objects: list[tuple[str, list[FunctionParent]]]
|
||||
|
|
|
|||
260
codeflash/optimization/cst_context.py
Normal file
260
codeflash/optimization/cst_context.py
Normal file
|
|
@ -0,0 +1,260 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import libcst as cst
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class CSTContextNode(BaseModel):
|
||||
"""A node in the context tree, representing a single node in the CST. This context tree is a part of the main CST tree that contains nodes that lead to a target function.
|
||||
The corresponding cst_node is stored so that the cst can be rebuilt flexibly, based on whatever information is needed.
|
||||
In the future, this tree can be used as the code replacer, since we know where the target functions are located in the tree.
|
||||
"""
|
||||
|
||||
cst_node: cst.CSTNode | None
|
||||
children: dict[str, list[CSTContextNode] | CSTContextNode] = Field(default_factory=dict)
|
||||
is_target_function: bool = False
|
||||
target_functions: set[str] = Field(default_factory=set)
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def add_child(self, section: str, child: CSTContextNode):
|
||||
"""Add a child node to the specified section. sections are either body, orelse, finalbody, or handlers."""
|
||||
original_section = getattr(self.cst_node, section, None)
|
||||
if isinstance(original_section, (list, tuple)):
|
||||
if section not in self.children:
|
||||
self.children[section] = []
|
||||
self.children[section].append(child)
|
||||
else:
|
||||
self.children[section] = child
|
||||
|
||||
|
||||
def build_context_tree(context_node: CSTContextNode, prefix: str = "") -> bool:
|
||||
"""Recursively builds a context tree from a CST, tracking target functions and their containing structures.
|
||||
|
||||
Args:
|
||||
context_node: Current node in the context tree
|
||||
prefix: Prefix to add to the target function names to create qualified names
|
||||
|
||||
Returns:
|
||||
bool: True if a target function was found in this branch
|
||||
|
||||
"""
|
||||
|
||||
def process_node(node: cst.CSTNode, section: str) -> bool:
|
||||
if isinstance(node, cst.ClassDef):
|
||||
if prefix: # Don't go into nested classes
|
||||
return False
|
||||
class_node = CSTContextNode(cst_node=node, target_functions=context_node.target_functions)
|
||||
if build_context_tree(class_node, f"{prefix}.{node.name.value}" if prefix else node.name.value):
|
||||
context_node.add_child(section, class_node)
|
||||
return True
|
||||
|
||||
elif isinstance(node, cst.FunctionDef):
|
||||
qualified_name = f"{prefix}.{node.name.value}" if prefix else node.name.value
|
||||
if qualified_name in context_node.target_functions:
|
||||
func_node = CSTContextNode(
|
||||
cst_node=node, is_target_function=True, target_functions=context_node.target_functions
|
||||
)
|
||||
context_node.add_child(section, func_node)
|
||||
return True
|
||||
|
||||
elif isinstance(node, cst.CSTNode):
|
||||
other_node = CSTContextNode(cst_node=node, target_functions=context_node.target_functions)
|
||||
if build_context_tree(other_node, prefix):
|
||||
context_node.add_child(section, other_node)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
has_target = False
|
||||
node = context_node.cst_node
|
||||
|
||||
# Check each section directly
|
||||
for section_name in ["body", "orelse", "finalbody", "handlers"]:
|
||||
section_content = getattr(node, section_name, None)
|
||||
if section_content is not None:
|
||||
if isinstance(section_content, (list, tuple)):
|
||||
has_target_list = [process_node(section_node, section_name) for section_node in section_content]
|
||||
has_target |= any(has_target_list)
|
||||
else:
|
||||
has_target |= process_node(section_content, section_name)
|
||||
|
||||
return has_target
|
||||
|
||||
|
||||
def find_containing_classes(code: str, target_functions: set[str]) -> CSTContextNode:
|
||||
"""Parse the code and find all class definitions containing the target functions."""
|
||||
root = CSTContextNode(cst_node=cst.parse_module(code), target_functions=target_functions)
|
||||
if not build_context_tree(root):
|
||||
raise ValueError("No target functions found in the provided code")
|
||||
return root
|
||||
|
||||
|
||||
def create_read_write_context(context_node: CSTContextNode) -> str:
|
||||
"""Rebuilds a CST tree to create our read-write context"""
|
||||
|
||||
def rebuild_node(node: CSTContextNode) -> cst.CSTNode:
|
||||
if node.is_target_function:
|
||||
return node.cst_node
|
||||
|
||||
updates = {}
|
||||
for section_name, children in node.children.items():
|
||||
if isinstance(children, list):
|
||||
updates[section_name] = [rebuild_node(child) for child in children]
|
||||
else:
|
||||
updates[section_name] = rebuild_node(children)
|
||||
|
||||
return node.cst_node.with_changes(**updates)
|
||||
|
||||
if not isinstance(context_node.cst_node, cst.Module):
|
||||
raise ValueError("Root of context tree must be a Module node")
|
||||
|
||||
rebuilt_module = rebuild_node(context_node)
|
||||
return rebuilt_module.code
|
||||
|
||||
|
||||
def create_read_only_context(context_node: CSTContextNode) -> str:
|
||||
"""Creates a read-only version of the context tree where:
|
||||
- Global variables and typing information are preserved
|
||||
- Class definitions preserve all information (variables, docstrings, etc.)
|
||||
- Only dunder methods are preserved
|
||||
- All other methods (including target functions) are removed completely
|
||||
"""
|
||||
|
||||
def rebuild_non_context_node(node: cst.CSTNode) -> cst.CSTNode | None:
|
||||
"""Recursively rebuild non-context nodes, preserving all statements except class and function definitions."""
|
||||
if isinstance(node, (cst.ClassDef, cst.FunctionDef, cst.Import, cst.ImportFrom)):
|
||||
return None
|
||||
|
||||
updates = {}
|
||||
for section_name in ["body", "orelse", "finalbody", "handlers"]:
|
||||
section_content = getattr(node, section_name, None)
|
||||
if section_content is not None:
|
||||
if isinstance(section_content, (list, tuple)):
|
||||
rebuilt_children = []
|
||||
for child in section_content:
|
||||
rebuilt_child = rebuild_non_context_node(child)
|
||||
if rebuilt_child is not None:
|
||||
rebuilt_children.append(rebuilt_child)
|
||||
if rebuilt_children:
|
||||
updates[section_name] = rebuilt_children
|
||||
else:
|
||||
rebuilt_child = rebuild_non_context_node(section_content)
|
||||
if rebuilt_child is not None:
|
||||
updates[section_name] = rebuilt_child
|
||||
|
||||
if not updates:
|
||||
return node
|
||||
|
||||
return node.with_changes(**updates)
|
||||
|
||||
def rebuild_node(node: CSTContextNode) -> cst.CSTNode | None:
|
||||
# Remove target functions completely
|
||||
if node.is_target_function:
|
||||
return None
|
||||
|
||||
# For regular functions, remove them
|
||||
if isinstance(node.cst_node, cst.FunctionDef):
|
||||
return None
|
||||
|
||||
# For class definitions, preserve structure but remove non-dunder methods
|
||||
if isinstance(node.cst_node, cst.ClassDef):
|
||||
return rebuild_class(node)
|
||||
|
||||
updates = {}
|
||||
for section_name in ["body", "orelse", "finalbody", "handlers"]:
|
||||
section_content = getattr(node.cst_node, section_name, None)
|
||||
if section_content is not None:
|
||||
context_children = node.children.get(section_name, [])
|
||||
context_children = [context_children] if not isinstance(context_children, list) else context_children
|
||||
|
||||
if isinstance(section_content, (list, tuple)):
|
||||
rebuilt_children = []
|
||||
# Create a map of context children positions
|
||||
context_positions = {id(child.cst_node): i for i, child in enumerate(context_children)}
|
||||
|
||||
for i, child_node in enumerate(section_content):
|
||||
# If this node is in context, use rebuild_node
|
||||
if id(child_node) in context_positions:
|
||||
context_idx = context_positions[id(child_node)]
|
||||
rebuilt_child = rebuild_node(context_children[context_idx])
|
||||
else:
|
||||
# Otherwise use rebuild_non_context_node
|
||||
rebuilt_child = rebuild_non_context_node(child_node)
|
||||
|
||||
if rebuilt_child is not None:
|
||||
rebuilt_children.append(rebuilt_child)
|
||||
|
||||
if rebuilt_children:
|
||||
updates[section_name] = rebuilt_children
|
||||
else:
|
||||
# Single node case
|
||||
if context_children: # If we have a context child
|
||||
rebuilt_child = rebuild_node(context_children[0])
|
||||
else:
|
||||
rebuilt_child = rebuild_non_context_node(section_content)
|
||||
|
||||
if rebuilt_child is not None:
|
||||
updates[section_name] = rebuilt_child
|
||||
|
||||
if not updates:
|
||||
return None
|
||||
|
||||
return node.cst_node.with_changes(**updates)
|
||||
|
||||
def is_dunder_method(func_node: cst.FunctionDef) -> bool:
|
||||
"""Check if a function is a dunder method."""
|
||||
name = func_node.name.value
|
||||
return name.startswith("__") and name.endswith("__")
|
||||
|
||||
def rebuild_class(node: CSTContextNode) -> cst.ClassDef:
|
||||
"""Rebuilds a class definition, preserving only structure, variables, and dunder methods."""
|
||||
class_node = node.cst_node
|
||||
new_body = []
|
||||
|
||||
if isinstance(class_node.body, cst.IndentedBlock):
|
||||
body_statements = class_node.body.body
|
||||
else:
|
||||
body_statements = [class_node.body]
|
||||
|
||||
for stmt in body_statements:
|
||||
if isinstance(stmt, cst.FunctionDef):
|
||||
if is_dunder_method(stmt):
|
||||
if f"{class_node.name.value}.{stmt.name.value}" in node.target_functions:
|
||||
# target function is a dunder method already shown in read-write context
|
||||
continue
|
||||
# Preserve only dunder methods
|
||||
new_body.append(stmt)
|
||||
else:
|
||||
# Keep all other class contents (variables, docstrings, etc.)
|
||||
new_body.append(stmt)
|
||||
|
||||
if not new_body:
|
||||
return None
|
||||
return class_node.with_changes(body=cst.IndentedBlock(new_body))
|
||||
|
||||
if not isinstance(context_node.cst_node, cst.Module):
|
||||
raise ValueError("Root of context tree must be a Module node")
|
||||
|
||||
rebuilt_module = rebuild_node(context_node)
|
||||
if rebuilt_module is None:
|
||||
return ""
|
||||
return rebuilt_module.code
|
||||
|
||||
|
||||
def print_tree(node: CSTContextNode, level: int = 0):
|
||||
"""Helper function to visualize the full CST node structure recursively"""
|
||||
indent = " " * level
|
||||
print(f"\n{indent}CSTContextNode:")
|
||||
print(f"{indent} is_target_function: {node.is_target_function}")
|
||||
print(f"{indent} cst_node type: {type(node.cst_node)}")
|
||||
print(f"{indent} children:")
|
||||
|
||||
for section_name, children in node.children.items():
|
||||
print(f"{indent} {section_name}:")
|
||||
if isinstance(children, list):
|
||||
for child in children:
|
||||
print_tree(child, level + 3)
|
||||
else:
|
||||
print_tree(children, level + 3)
|
||||
|
|
@ -78,6 +78,8 @@ if TYPE_CHECKING:
|
|||
|
||||
from codeflash.models.models import CoverageData, FunctionCalledInTest, FunctionSource, OptimizedCandidate
|
||||
|
||||
from codeflash.optimization import retriever
|
||||
|
||||
|
||||
class Optimizer:
|
||||
def __init__(self, args: Namespace) -> None:
|
||||
|
|
@ -272,13 +274,15 @@ class Optimizer:
|
|||
f"Generating new tests and optimizations for function {function_to_optimize.function_name}", transient=True
|
||||
):
|
||||
generated_results = self.generate_tests_and_optimizations(
|
||||
code_context.code_to_optimize_with_helpers,
|
||||
function_to_optimize,
|
||||
code_context.helper_functions,
|
||||
Path(original_module_path),
|
||||
function_trace_id,
|
||||
generated_test_paths,
|
||||
function_to_optimize_ast,
|
||||
code_to_optimize_with_helpers=code_context.code_to_optimize_with_helpers,
|
||||
read_write_context_code=code_context.read_write_context_code,
|
||||
read_only_context_code=code_context.read_only_context_code,
|
||||
function_to_optimize=function_to_optimize,
|
||||
helper_functions=code_context.helper_functions,
|
||||
module_path=Path(original_module_path),
|
||||
function_trace_id=function_trace_id,
|
||||
generated_test_paths=generated_test_paths,
|
||||
function_to_optimize_ast=function_to_optimize_ast,
|
||||
run_experiment=should_run_experiment,
|
||||
)
|
||||
|
||||
|
|
@ -709,9 +713,16 @@ class Optimizer:
|
|||
)
|
||||
preexisting_objects = find_preexisting_objects(code_to_optimize_with_helpers)
|
||||
contextual_dunder_methods.update(helper_dunder_methods)
|
||||
|
||||
# Will eventually refactor to use this function instead of the above
|
||||
read_write_context, read_only_context = retriever.get_code_optimization_context(
|
||||
function_to_optimize, project_root, original_source_code
|
||||
)
|
||||
return Success(
|
||||
CodeOptimizationContext(
|
||||
code_to_optimize_with_helpers=code_to_optimize_with_helpers_and_imports,
|
||||
read_write_context_code=read_write_context,
|
||||
read_only_context_code=read_only_context,
|
||||
contextual_dunder_methods=contextual_dunder_methods,
|
||||
helper_functions=helper_functions,
|
||||
preexisting_objects=preexisting_objects,
|
||||
|
|
@ -795,6 +806,8 @@ class Optimizer:
|
|||
def generate_tests_and_optimizations(
|
||||
self,
|
||||
code_to_optimize_with_helpers: str,
|
||||
read_write_context_code: str,
|
||||
read_only_context_code: str,
|
||||
function_to_optimize: FunctionToOptimize,
|
||||
helper_functions: list[FunctionSource],
|
||||
module_path: Path,
|
||||
|
|
@ -819,7 +832,8 @@ class Optimizer:
|
|||
)
|
||||
future_optimization_candidates = executor.submit(
|
||||
self.aiservice_client.optimize_python_code,
|
||||
code_to_optimize_with_helpers,
|
||||
read_write_context_code,
|
||||
read_only_context_code,
|
||||
function_trace_id[:-4] + "EXP0" if run_experiment else function_trace_id,
|
||||
N_CANDIDATES,
|
||||
ExperimentMetadata(id=self.experiment_id, group="control") if run_experiment else None,
|
||||
|
|
@ -833,7 +847,8 @@ class Optimizer:
|
|||
if run_experiment:
|
||||
future_candidates_exp = executor.submit(
|
||||
self.local_aiservice_client.optimize_python_code,
|
||||
code_to_optimize_with_helpers,
|
||||
read_write_context_code,
|
||||
read_only_context_code,
|
||||
function_trace_id[:-4] + "EXP1",
|
||||
N_CANDIDATES,
|
||||
ExperimentMetadata(id=self.experiment_id, group="experiment"),
|
||||
|
|
|
|||
115
codeflash/optimization/retriever.py
Normal file
115
codeflash/optimization/retriever.py
Normal file
|
|
@ -0,0 +1,115 @@
|
|||
import os
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
|
||||
import jedi
|
||||
import libcst as cst
|
||||
from jedi.api.classes import Name
|
||||
from returns.result import Result
|
||||
|
||||
from codeflash.cli_cmds.console import logger
|
||||
from codeflash.code_utils.code_extractor import add_needed_imports_from_module_2
|
||||
from codeflash.code_utils.code_utils import get_qualified_name, path_belongs_to_site_packages
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.models.models import CodeOptimizationContext
|
||||
from codeflash.optimization.cst_context import (
|
||||
CSTContextNode,
|
||||
build_context_tree,
|
||||
create_read_only_context,
|
||||
create_read_write_context,
|
||||
)
|
||||
from codeflash.optimization.function_context import belongs_to_class, belongs_to_function
|
||||
|
||||
|
||||
def get_code_optimization_context(
|
||||
function_to_optimize: FunctionToOptimize, project_root_path: Path, original_source_code: str
|
||||
) -> Result[CodeOptimizationContext, str]:
|
||||
function_name = function_to_optimize.function_name
|
||||
file_path = function_to_optimize.file_path
|
||||
script = jedi.Script(path=file_path, project=jedi.Project(path=project_root_path))
|
||||
file_path_to_qualified_function_names = defaultdict(set)
|
||||
file_path_to_qualified_function_names[file_path].add(function_to_optimize.qualified_name)
|
||||
read_write_list = []
|
||||
read_only_list = []
|
||||
read_write_string = ""
|
||||
read_only_string = ""
|
||||
names = []
|
||||
for ref in script.get_names(all_scopes=True, definitions=False, references=True):
|
||||
if ref.full_name:
|
||||
if function_to_optimize.parents:
|
||||
# Check if the reference belongs to the specified class when FunctionParent is provided
|
||||
if belongs_to_class(ref, function_to_optimize.parents[-1].name) and belongs_to_function(
|
||||
ref, function_name
|
||||
):
|
||||
names.append(ref)
|
||||
elif belongs_to_function(ref, function_name):
|
||||
names.append(ref)
|
||||
|
||||
for name in names:
|
||||
try:
|
||||
definitions: list[Name] = name.goto(follow_imports=True, follow_builtin_imports=False)
|
||||
except Exception as e:
|
||||
try:
|
||||
logger.exception(f"Error while getting definition for {name.full_name}: {e}")
|
||||
except Exception as e:
|
||||
# name.full_name can also throw exceptions sometimes
|
||||
logger.exception(f"Error while getting definition: {e}")
|
||||
definitions = []
|
||||
if definitions:
|
||||
# TODO: there can be multiple definitions, see how to handle such cases
|
||||
definition = definitions[0]
|
||||
definition_path = definition.module_path
|
||||
|
||||
# The definition is part of this project and not defined within the original function
|
||||
if (
|
||||
str(definition_path).startswith(str(project_root_path) + os.sep)
|
||||
and not path_belongs_to_site_packages(definition_path)
|
||||
and definition.full_name
|
||||
and not belongs_to_function(definition, function_name)
|
||||
):
|
||||
file_path_to_qualified_function_names[definition_path].add(
|
||||
get_qualified_name(definition.module_name, definition.full_name)
|
||||
)
|
||||
for file_path, qualified_function_names in file_path_to_qualified_function_names.items():
|
||||
try:
|
||||
og_code_containing_helpers = file_path.read_text("utf8")
|
||||
context_tree_root = cst.parse_module(og_code_containing_helpers)
|
||||
except Exception as e:
|
||||
logger.exception(f"Error while parsing {file_path}: {e}")
|
||||
continue
|
||||
context_tree_root = CSTContextNode(cst_node=context_tree_root, target_functions=qualified_function_names)
|
||||
if not build_context_tree(context_tree_root, ""):
|
||||
logger.debug(
|
||||
f"{qualified_function_names} was not found in {file_path} when retrieving code optimization context"
|
||||
)
|
||||
continue
|
||||
|
||||
read_write_context_string = f"{create_read_write_context(context_tree_root)}"
|
||||
if read_write_context_string:
|
||||
read_write_string += f"\n{read_write_context_string}"
|
||||
read_write_string = add_needed_imports_from_module_2(
|
||||
og_code_containing_helpers,
|
||||
read_write_string,
|
||||
file_path,
|
||||
file_path,
|
||||
project_root_path,
|
||||
list(qualified_function_names),
|
||||
)
|
||||
|
||||
read_only_context_string = f"{create_read_only_context(context_tree_root)}\n"
|
||||
read_only_context_string_with_imports = add_needed_imports_from_module_2(
|
||||
og_code_containing_helpers,
|
||||
read_only_context_string,
|
||||
file_path,
|
||||
file_path,
|
||||
project_root_path,
|
||||
list(qualified_function_names),
|
||||
)
|
||||
if read_only_context_string_with_imports:
|
||||
read_only_list.append(f"```python:{file_path}\n{read_only_context_string_with_imports}```")
|
||||
|
||||
read_only_string = "\n".join(read_only_list)
|
||||
print("read_write_string: \n\n", read_write_string)
|
||||
|
||||
print("read_only_string: \n\n", read_only_string)
|
||||
return read_write_string, read_only_string
|
||||
261
tests/test_build_context_tree.py
Normal file
261
tests/test_build_context_tree.py
Normal file
|
|
@ -0,0 +1,261 @@
|
|||
from textwrap import dedent
|
||||
|
||||
import libcst as cst
|
||||
from codeflash.optimization.cst_context import find_containing_classes, print_tree
|
||||
|
||||
|
||||
def test_basic_function_identification():
|
||||
"""Test that target functions are correctly identified."""
|
||||
code = """
|
||||
def target_func():
|
||||
pass
|
||||
|
||||
def other_func():
|
||||
pass
|
||||
"""
|
||||
|
||||
result = find_containing_classes(dedent(code), {"target_func"})
|
||||
# Should find exactly one function
|
||||
assert len(result.children["body"]) == 1
|
||||
assert result.children["body"][0].is_target_function
|
||||
assert isinstance(result.children["body"][0].cst_node, cst.FunctionDef)
|
||||
assert result.children["body"][0].cst_node.name.value == "target_func"
|
||||
|
||||
|
||||
def test_class_with_target_function():
|
||||
"""Test that classes containing target functions are included with only the class definition."""
|
||||
code = """
|
||||
class TestClass:
|
||||
def target_method(self):
|
||||
pass
|
||||
|
||||
def other_method(self):
|
||||
pass
|
||||
"""
|
||||
|
||||
result = find_containing_classes(dedent(code), {"TestClass.target_method"})
|
||||
print_tree(result)
|
||||
assert len(result.children["body"]) == 1
|
||||
class_node = result.children["body"][0]
|
||||
assert isinstance(class_node.cst_node, cst.ClassDef)
|
||||
assert class_node.cst_node.name.value == "TestClass"
|
||||
|
||||
assert len(class_node.children["body"]) == 1
|
||||
assert class_node.children["body"][0].is_target_function
|
||||
assert class_node.children["body"][0].cst_node.name.value == "target_method"
|
||||
|
||||
|
||||
def test_class_without_target_function():
|
||||
"""Test that classes without target functions are not included."""
|
||||
code = """
|
||||
class TestClass:
|
||||
def method1(self):
|
||||
pass
|
||||
|
||||
def method2(self):
|
||||
pass
|
||||
"""
|
||||
|
||||
result = find_containing_classes(dedent(code), {"other_func"})
|
||||
|
||||
# Should have no children since no target functions were found
|
||||
assert "body" not in result.children or not result.children["body"]
|
||||
|
||||
|
||||
def test_control_flow_structures():
|
||||
"""Test handling of control flow structures containing target functions."""
|
||||
code = """
|
||||
if True:
|
||||
def target_func():
|
||||
pass
|
||||
else:
|
||||
def other_func():
|
||||
pass
|
||||
|
||||
try:
|
||||
def another_target():
|
||||
pass
|
||||
except Exception:
|
||||
def handler_func():
|
||||
pass
|
||||
finally:
|
||||
def cleanup_func():
|
||||
pass
|
||||
"""
|
||||
|
||||
result = find_containing_classes(dedent(code), {"target_func", "another_target"})
|
||||
|
||||
assert len(result.children["body"]) == 2
|
||||
if_context_node = result.children["body"][0]
|
||||
assert isinstance(if_context_node.cst_node, cst.If)
|
||||
target_context_node = if_context_node.children["body"][0]
|
||||
assert target_context_node.cst_node.name.value == "target_func"
|
||||
assert target_context_node.is_target_function
|
||||
|
||||
try_context_node = result.children["body"][1]
|
||||
assert isinstance(try_context_node.cst_node, cst.Try)
|
||||
another_target_context_node = try_context_node.children["body"][0]
|
||||
assert another_target_context_node.cst_node.name.value == "another_target"
|
||||
assert another_target_context_node.is_target_function
|
||||
|
||||
|
||||
def test_nested_classes():
|
||||
"""Test handling of nested classes with target functions."""
|
||||
code = """
|
||||
class OuterClass:
|
||||
class InnerClass:
|
||||
def target_method(self):
|
||||
pass
|
||||
|
||||
def other_method(self):
|
||||
pass
|
||||
"""
|
||||
|
||||
result = find_containing_classes(dedent(code), {"OuterClass.InnerClass.target_method"})
|
||||
|
||||
# Verify the class hierarchy
|
||||
assert len(result.children["body"]) == 1
|
||||
outer_class = result.children["body"][0]
|
||||
assert isinstance(outer_class.cst_node, cst.ClassDef)
|
||||
assert outer_class.cst_node.name.value == "OuterClass"
|
||||
|
||||
assert len(outer_class.children["body"]) == 1
|
||||
inner_class = outer_class.children["body"][0]
|
||||
assert isinstance(inner_class.cst_node, cst.ClassDef)
|
||||
assert inner_class.cst_node.name.value == "InnerClass"
|
||||
|
||||
assert len(inner_class.children["body"]) == 1
|
||||
target_method = inner_class.children["body"][0]
|
||||
assert target_method.is_target_function
|
||||
assert target_method.cst_node.name.value == "target_method"
|
||||
|
||||
|
||||
def test_no_classes():
|
||||
"""Test handling of target functions without any classes."""
|
||||
code = """
|
||||
def function1():
|
||||
pass
|
||||
|
||||
def target_function():
|
||||
def nested_function():
|
||||
pass
|
||||
return nested_function
|
||||
|
||||
def function2():
|
||||
pass
|
||||
"""
|
||||
|
||||
result = find_containing_classes(dedent(code), {"target_function"})
|
||||
|
||||
# Should find only the target function
|
||||
assert len(result.children["body"]) == 1
|
||||
assert result.children["body"][0].is_target_function
|
||||
assert result.children["body"][0].cst_node.name.value == "target_function"
|
||||
|
||||
|
||||
def test_no_classes_if_else():
|
||||
"""Test handling of target functions in if/else blocks."""
|
||||
code = """
|
||||
def function1():
|
||||
pass
|
||||
if x:
|
||||
def target_function():
|
||||
return "hello"
|
||||
else:
|
||||
def function2():
|
||||
pass
|
||||
"""
|
||||
|
||||
result = find_containing_classes(dedent(code), {"target_function"})
|
||||
|
||||
assert result.children["body"][0].children["body"][0].is_target_function
|
||||
assert result.children["body"][0].children["body"][0].cst_node.name.value == "target_function"
|
||||
|
||||
|
||||
def test_no_classes_else():
|
||||
"""Test handling of target functions in else blocks."""
|
||||
code = """
|
||||
def function1():
|
||||
pass
|
||||
if x:
|
||||
x += 2
|
||||
else:
|
||||
def target_function():
|
||||
return "hello"
|
||||
"""
|
||||
|
||||
result = find_containing_classes(dedent(code), {"target_function"})
|
||||
|
||||
assert result.children["body"][0].children["orelse"][0].is_target_function
|
||||
assert result.children["body"][0].children["orelse"][0].cst_node.name.value == "target_function"
|
||||
|
||||
|
||||
def test_comments_and_decorators():
|
||||
"""Test that comments and decorators are preserved."""
|
||||
code = """
|
||||
# Top level comment
|
||||
@decorator
|
||||
class TestClass:
|
||||
# Class comment
|
||||
@method_decorator
|
||||
def target_method(self):
|
||||
# Method comment
|
||||
pass
|
||||
"""
|
||||
|
||||
result = find_containing_classes(dedent(code), {"TestClass.target_method"})
|
||||
|
||||
# Verify class has decorator
|
||||
assert len(result.children["body"]) == 1
|
||||
class_node = result.children["body"][0]
|
||||
assert len(class_node.cst_node.decorators) == 1
|
||||
|
||||
# Verify method has decorator
|
||||
method_node = class_node.children["body"][0]
|
||||
assert len(method_node.cst_node.decorators) == 1
|
||||
|
||||
|
||||
def test_same_name_different_paths():
|
||||
"""Test handling of functions with same name but different qualified paths."""
|
||||
code = """
|
||||
class ClassA:
|
||||
def process(self):
|
||||
pass
|
||||
|
||||
class ClassB:
|
||||
def process(self):
|
||||
pass
|
||||
|
||||
def process():
|
||||
pass
|
||||
|
||||
class Outer:
|
||||
class Inner:
|
||||
def process(self):
|
||||
pass
|
||||
|
||||
def process(self):
|
||||
pass
|
||||
"""
|
||||
|
||||
# Test finding specific instances
|
||||
result = find_containing_classes(dedent(code), {"ClassA.process", "Outer.Inner.process"})
|
||||
|
||||
# Test for just top-level process
|
||||
result_top = find_containing_classes(dedent(code), {"process"})
|
||||
assert len(result_top.children["body"]) == 1
|
||||
assert result_top.children["body"][0].is_target_function
|
||||
assert result_top.children["body"][0].cst_node.name.value == "process"
|
||||
|
||||
# Test for just Outer.process
|
||||
result_outer = find_containing_classes(dedent(code), {"Outer.process"})
|
||||
assert result_outer.children["body"][0].cst_node.name.value == "Outer"
|
||||
assert result_outer.children["body"][0].children["body"][0].cst_node.name.value == "process"
|
||||
|
||||
# Test for just Inner.process
|
||||
result_inner = find_containing_classes(dedent(code), {"Outer.Inner.process"})
|
||||
outer = result_inner.children["body"][0]
|
||||
assert outer.cst_node.name.value == "Outer"
|
||||
inner = outer.children["body"][0]
|
||||
assert inner.cst_node.name.value == "Inner"
|
||||
assert inner.children["body"][0].cst_node.name.value == "process"
|
||||
409
tests/test_context_cst.py
Normal file
409
tests/test_context_cst.py
Normal file
|
|
@ -0,0 +1,409 @@
|
|||
from typing import Set
|
||||
|
||||
import libcst as cst
|
||||
from codeflash.optimization.context_cst import (
|
||||
prune_module, # replace 'yourmodule' with the actual module where prune_module is defined
|
||||
)
|
||||
|
||||
|
||||
def test_top_level_target_function():
|
||||
code = """
|
||||
def foo():
|
||||
pass
|
||||
|
||||
def bar():
|
||||
pass
|
||||
"""
|
||||
module = cst.parse_module(code)
|
||||
targets: Set[str] = {"bar"}
|
||||
|
||||
pruned = prune_module(module, targets)
|
||||
expected = """
|
||||
def bar():
|
||||
pass
|
||||
"""
|
||||
assert pruned.code.strip() == expected.strip()
|
||||
|
||||
|
||||
def test_no_targets_found():
|
||||
code = """
|
||||
def foo():
|
||||
pass
|
||||
|
||||
x = 10
|
||||
"""
|
||||
module = cst.parse_module(code)
|
||||
targets: Set[str] = {"bar"} # 'bar' doesn't exist in code
|
||||
|
||||
pruned = prune_module(module, targets)
|
||||
expected = "" # no targets found, return empty module
|
||||
assert pruned.code.strip() == expected.strip()
|
||||
|
||||
|
||||
def test_class_with_target_function():
|
||||
code = """
|
||||
class MyClass:
|
||||
def helper(self):
|
||||
pass
|
||||
|
||||
def target_method(self):
|
||||
return 42
|
||||
|
||||
def unrelated():
|
||||
pass
|
||||
"""
|
||||
module = cst.parse_module(code)
|
||||
targets: Set[str] = {"MyClass.target_method"}
|
||||
|
||||
pruned = prune_module(module, targets)
|
||||
# We expect to keep MyClass and only target_method in it
|
||||
expected = """
|
||||
class MyClass:
|
||||
|
||||
def target_method(self):
|
||||
return 42
|
||||
"""
|
||||
assert pruned.code.strip() == expected.strip()
|
||||
|
||||
|
||||
def test_nested_class_with_target_function():
|
||||
code = """
|
||||
class Outer:
|
||||
def outer_method(self):
|
||||
pass
|
||||
|
||||
class Inner:
|
||||
def inner_helper(self):
|
||||
pass
|
||||
|
||||
def target_func(self):
|
||||
print("Target")
|
||||
|
||||
def top_level():
|
||||
pass
|
||||
"""
|
||||
module = cst.parse_module(code)
|
||||
targets: Set[str] = {"Outer.Inner.target_func"}
|
||||
|
||||
pruned = prune_module(module, targets)
|
||||
# We must keep Outer, Inner, and the target_func inside Inner
|
||||
expected = """
|
||||
class Outer:
|
||||
|
||||
class Inner:
|
||||
|
||||
def target_func(self):
|
||||
print("Target")
|
||||
"""
|
||||
assert pruned.code.strip() == expected.strip()
|
||||
|
||||
|
||||
def test_if_statements_leading_to_target():
|
||||
code = """
|
||||
def foo():
|
||||
if True:
|
||||
def target():
|
||||
return "yes"
|
||||
else:
|
||||
print("nope")
|
||||
|
||||
def bar():
|
||||
pass
|
||||
"""
|
||||
module = cst.parse_module(code)
|
||||
targets: Set[str] = {"target"}
|
||||
|
||||
pruned = prune_module(module, targets)
|
||||
# We keep foo, because inside its if body there's the target function.
|
||||
# The else is removed as it doesn't lead to a target. bar is removed as well.
|
||||
expected = ""
|
||||
assert pruned.code.strip() == expected.strip()
|
||||
|
||||
|
||||
def test_class_with_no_target_functions():
|
||||
code = """
|
||||
class A:
|
||||
def no_target(self):
|
||||
pass
|
||||
|
||||
x = 5
|
||||
"""
|
||||
module = cst.parse_module(code)
|
||||
targets: Set[str] = {"SomeOtherClass.some_func"}
|
||||
|
||||
pruned = prune_module(module, targets)
|
||||
expected = "" # no targets in code, empty result
|
||||
assert pruned.code.strip() == expected.strip()
|
||||
|
||||
|
||||
def test_class_in_else_block():
|
||||
code = """
|
||||
if y is False:
|
||||
x = 10
|
||||
else:
|
||||
class MyClass:
|
||||
def not_target(self):
|
||||
pass
|
||||
|
||||
def target_func(self):
|
||||
return "found me!"
|
||||
"""
|
||||
module = cst.parse_module(code)
|
||||
targets: Set[str] = {"MyClass.target_func"}
|
||||
|
||||
pruned = prune_module(module, targets)
|
||||
# Even though MyClass is in the else block, we have a target method inside it.
|
||||
# We expect to keep the wrapper function, the else block, and the class with just the target method.
|
||||
expected = """
|
||||
if y is False:
|
||||
pass
|
||||
else:
|
||||
class MyClass:
|
||||
|
||||
def target_func(self):
|
||||
return "found me!"
|
||||
"""
|
||||
assert pruned.code.strip() == expected.strip()
|
||||
|
||||
|
||||
def test_class_in_if_block():
|
||||
code = """
|
||||
if y is False:
|
||||
class MyClass:
|
||||
def not_target(self):
|
||||
pass
|
||||
|
||||
def target_func(self):
|
||||
return "found me!"
|
||||
else:
|
||||
x = 10
|
||||
"""
|
||||
module = cst.parse_module(code)
|
||||
targets: Set[str] = {"MyClass.target_func"}
|
||||
|
||||
pruned = prune_module(module, targets)
|
||||
# Even though MyClass is in the else block, we have a target method inside it.
|
||||
# We expect to keep the wrapper function, the else block, and the class with just the target method.
|
||||
expected = """
|
||||
if y is False:
|
||||
class MyClass:
|
||||
|
||||
def target_func(self):
|
||||
return "found me!"
|
||||
else:
|
||||
pass
|
||||
"""
|
||||
assert pruned.code.strip() == expected.strip()
|
||||
|
||||
|
||||
def test_functions_same_name_different_scopes():
|
||||
code = """
|
||||
def foo():
|
||||
return "top-level"
|
||||
|
||||
class Outer:
|
||||
def foo():
|
||||
return "in Outer"
|
||||
|
||||
class Another:
|
||||
def foo():
|
||||
return "in Another"
|
||||
"""
|
||||
module = cst.parse_module(code)
|
||||
# Only match the "Outer.foo" function, not the top-level "foo" or "Another.foo"
|
||||
targets: Set[str] = {"Outer.foo"}
|
||||
|
||||
pruned = prune_module(module, targets)
|
||||
expected = """
|
||||
class Outer:
|
||||
def foo():
|
||||
return "in Outer"
|
||||
"""
|
||||
assert pruned.code.strip() == expected.strip()
|
||||
|
||||
|
||||
def test_if_elif_else_block():
|
||||
code = """
|
||||
if condition:
|
||||
def not_target():
|
||||
pass
|
||||
elif other_condition:
|
||||
def another_not_target():
|
||||
pass
|
||||
else:
|
||||
def target_in_else():
|
||||
return "Found"
|
||||
|
||||
def unrelated():
|
||||
return "no"
|
||||
"""
|
||||
module = cst.parse_module(code)
|
||||
targets: Set[str] = {"target_in_else"}
|
||||
|
||||
pruned = prune_module(module, targets)
|
||||
# We keep the whole if-elif-else structure but prune out the non-target branches.
|
||||
expected = """
|
||||
if condition:
|
||||
pass
|
||||
elif other_condition:
|
||||
pass
|
||||
else:
|
||||
def target_in_else():
|
||||
return "Found"
|
||||
"""
|
||||
assert pruned.code.strip() == expected.strip()
|
||||
|
||||
|
||||
def test_nested_if_in_else_block():
|
||||
code = """
|
||||
if top_level:
|
||||
def no_target():
|
||||
return 1
|
||||
else:
|
||||
if nested_condition:
|
||||
def target_func():
|
||||
return "nested target"
|
||||
else:
|
||||
def another_no():
|
||||
pass
|
||||
|
||||
def outside():
|
||||
pass
|
||||
"""
|
||||
module = cst.parse_module(code)
|
||||
targets: Set[str] = {"target_func"}
|
||||
|
||||
pruned = prune_module(module, targets)
|
||||
# We keep the top if-else structure, and inside the else, we keep the nested if block with the target.
|
||||
expected = """
|
||||
if top_level:
|
||||
pass
|
||||
else:
|
||||
if nested_condition:
|
||||
def target_func():
|
||||
return "nested target"
|
||||
"""
|
||||
assert pruned.code.strip() == expected.strip()
|
||||
|
||||
|
||||
def test_class_in_try_block():
|
||||
code = """
|
||||
try:
|
||||
class MyClass:
|
||||
def target_method(self):
|
||||
return "reached!"
|
||||
except ValueError:
|
||||
def not_target():
|
||||
return "no"
|
||||
else:
|
||||
def also_not_target():
|
||||
return "no"
|
||||
finally:
|
||||
def final_not_target():
|
||||
return "no"
|
||||
"""
|
||||
module = cst.parse_module(code)
|
||||
targets: Set[str] = {"MyClass.target_method"}
|
||||
|
||||
pruned = prune_module(module, targets)
|
||||
# We keep the try block because it contains the target class,
|
||||
# remove except, else, and finally since they don't lead to the target.
|
||||
expected = """
|
||||
try:
|
||||
class MyClass:
|
||||
|
||||
def target_method(self):
|
||||
return "reached!"
|
||||
except ValueError:
|
||||
pass
|
||||
else:
|
||||
pass
|
||||
finally:
|
||||
pass
|
||||
"""
|
||||
assert pruned.code.strip() == expected.strip()
|
||||
|
||||
|
||||
def test_target_in_except_block():
|
||||
code = """
|
||||
try:
|
||||
def no_target():
|
||||
return "no"
|
||||
except KeyError:
|
||||
def target_func():
|
||||
return "caught target"
|
||||
except ValueError:
|
||||
def also_no():
|
||||
return "no"
|
||||
else:
|
||||
def no_again():
|
||||
return "no"
|
||||
finally:
|
||||
def final_no():
|
||||
return "no"
|
||||
"""
|
||||
module = cst.parse_module(code)
|
||||
targets: Set[str] = {"target_func"}
|
||||
|
||||
pruned = prune_module(module, targets)
|
||||
# Remove the try, else, and finally bodies that don't lead to the target.
|
||||
# Keep the except KeyError block with the target, and leave other except blocks as pass.
|
||||
expected = """
|
||||
try:
|
||||
pass
|
||||
except KeyError:
|
||||
def target_func():
|
||||
return "caught target"
|
||||
except ValueError:
|
||||
pass
|
||||
else:
|
||||
pass
|
||||
finally:
|
||||
pass
|
||||
"""
|
||||
assert pruned.code.strip() == expected.strip()
|
||||
|
||||
|
||||
def test_target_in_multiple_except_blocks():
|
||||
code = """
|
||||
try:
|
||||
def first_no():
|
||||
return 1
|
||||
except KeyError:
|
||||
def second_no():
|
||||
return 2
|
||||
except ValueError:
|
||||
def target_in_value():
|
||||
return 3
|
||||
except TypeError:
|
||||
def another_no():
|
||||
return 4
|
||||
else:
|
||||
def else_no():
|
||||
return 5
|
||||
finally:
|
||||
def final_no():
|
||||
return 6
|
||||
"""
|
||||
module = cst.parse_module(code)
|
||||
targets: Set[str] = {"target_in_value"}
|
||||
|
||||
pruned = prune_module(module, targets)
|
||||
# We only keep the try/except/finally structure necessary to reach target_in_value.
|
||||
# The except for ValueError must remain with the target, others become pass.
|
||||
expected = """
|
||||
try:
|
||||
pass
|
||||
except KeyError:
|
||||
pass
|
||||
except ValueError:
|
||||
def target_in_value():
|
||||
return 3
|
||||
except TypeError:
|
||||
pass
|
||||
else:
|
||||
pass
|
||||
finally:
|
||||
pass
|
||||
"""
|
||||
assert pruned.code.strip() == expected.strip()
|
||||
339
tests/test_create_ro_context.py
Normal file
339
tests/test_create_ro_context.py
Normal file
|
|
@ -0,0 +1,339 @@
|
|||
from textwrap import dedent
|
||||
|
||||
import pytest
|
||||
from codeflash.optimization.cst_context import create_read_only_context, find_containing_classes, print_tree
|
||||
|
||||
|
||||
def test_basic_class():
|
||||
code = """
|
||||
class TestClass:
|
||||
class_var = "value"
|
||||
|
||||
def target_method(self):
|
||||
print("This should be stubbed")
|
||||
|
||||
def other_method(self):
|
||||
print("This too")
|
||||
"""
|
||||
|
||||
expected = """
|
||||
class TestClass:
|
||||
class_var = "value"
|
||||
"""
|
||||
|
||||
result = find_containing_classes(dedent(code), {"TestClass.target_method"})
|
||||
output = create_read_only_context(result)
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
def test_dunder_methods():
|
||||
code = """
|
||||
class TestClass:
|
||||
def __init__(self):
|
||||
self.x = 42
|
||||
|
||||
def __str__(self):
|
||||
return f"Value: {self.x}"
|
||||
|
||||
def target_method(self):
|
||||
print("stub me")
|
||||
"""
|
||||
|
||||
expected = """
|
||||
class TestClass:
|
||||
def __init__(self):
|
||||
self.x = 42
|
||||
|
||||
def __str__(self):
|
||||
return f"Value: {self.x}"
|
||||
|
||||
"""
|
||||
|
||||
result = find_containing_classes(dedent(code), {"TestClass.target_method"})
|
||||
output = create_read_only_context(result)
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
def test_target_in_nested_class():
|
||||
"""Test that attempting to find a target in a nested class raises an error."""
|
||||
code = """
|
||||
class Outer:
|
||||
outer_var = 1
|
||||
|
||||
class Inner:
|
||||
inner_var = 2
|
||||
|
||||
def target_method(self):
|
||||
print("stub this")
|
||||
"""
|
||||
|
||||
with pytest.raises(ValueError, match="No target functions found in the provided code"):
|
||||
find_containing_classes(dedent(code), {"Outer.Inner.target_method"})
|
||||
|
||||
|
||||
def test_docstrings():
|
||||
code = """
|
||||
class TestClass:
|
||||
\"\"\"Class docstring.\"\"\"
|
||||
|
||||
def target_method(self):
|
||||
\"\"\"Method docstring.\"\"\"
|
||||
print("stub this")
|
||||
|
||||
def other_method(self):
|
||||
\"\"\"Other docstring.\"\"\"
|
||||
print("stub this too")
|
||||
"""
|
||||
|
||||
expected = """
|
||||
class TestClass:
|
||||
\"\"\"Class docstring.\"\"\"
|
||||
|
||||
"""
|
||||
|
||||
result = find_containing_classes(dedent(code), {"TestClass.target_method"})
|
||||
output = create_read_only_context(result)
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
def test_method_signatures():
|
||||
code = """
|
||||
class TestClass:
|
||||
@property
|
||||
def target_method(self) -> str:
|
||||
\"\"\"Property docstring.\"\"\"
|
||||
return "value"
|
||||
|
||||
@classmethod
|
||||
def class_method(cls, param: int = 42) -> None:
|
||||
print("stub this")
|
||||
"""
|
||||
|
||||
expected = """"""
|
||||
|
||||
result = find_containing_classes(dedent(code), {"TestClass.target_method"})
|
||||
output = create_read_only_context(result)
|
||||
print(output)
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
def test_multiple_top_level_targets():
|
||||
code = """
|
||||
class TestClass:
|
||||
def target1(self):
|
||||
print("stub 1")
|
||||
|
||||
def target2(self):
|
||||
print("stub 2")
|
||||
|
||||
def __init__(self):
|
||||
self.x = 42
|
||||
"""
|
||||
|
||||
expected = """
|
||||
class TestClass:
|
||||
|
||||
def __init__(self):
|
||||
self.x = 42
|
||||
"""
|
||||
|
||||
result = find_containing_classes(dedent(code), {"TestClass.target1", "TestClass.target2"})
|
||||
output = create_read_only_context(result)
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
def test_class_annotations():
|
||||
code = """
|
||||
class TestClass:
|
||||
var1: int = 42
|
||||
var2: str
|
||||
|
||||
def target_method(self) -> None:
|
||||
self.var2 = "test"
|
||||
"""
|
||||
|
||||
expected = """
|
||||
class TestClass:
|
||||
var1: int = 42
|
||||
var2: str
|
||||
|
||||
"""
|
||||
|
||||
result = find_containing_classes(dedent(code), {"TestClass.target_method"})
|
||||
output = create_read_only_context(result)
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
def test_class_annotations_if():
|
||||
code = """
|
||||
if True:
|
||||
class TestClass:
|
||||
var1: int = 42
|
||||
var2: str
|
||||
|
||||
def target_method(self) -> None:
|
||||
self.var2 = "test"
|
||||
"""
|
||||
|
||||
expected = """
|
||||
if True:
|
||||
class TestClass:
|
||||
var1: int = 42
|
||||
var2: str
|
||||
|
||||
"""
|
||||
|
||||
result = find_containing_classes(dedent(code), {"TestClass.target_method"})
|
||||
output = create_read_only_context(result)
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
def test_class_annotations_try():
|
||||
code = """
|
||||
try:
|
||||
class TestClass:
|
||||
var1: int = 42
|
||||
var2: str
|
||||
|
||||
def target_method(self) -> None:
|
||||
self.var2 = "test"
|
||||
except Exception:
|
||||
continue
|
||||
"""
|
||||
|
||||
expected = """
|
||||
try:
|
||||
class TestClass:
|
||||
var1: int = 42
|
||||
var2: str
|
||||
except Exception:
|
||||
continue
|
||||
"""
|
||||
|
||||
result = find_containing_classes(dedent(code), {"TestClass.target_method"})
|
||||
output = create_read_only_context(result)
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
def test_class_annotations_else():
|
||||
code = """
|
||||
if x is True:
|
||||
class TestClass:
|
||||
var1: int = 42
|
||||
var2: str
|
||||
|
||||
def wrong_method(self) -> None:
|
||||
print("wrong")
|
||||
else:
|
||||
class TestClass:
|
||||
var1: int = 42
|
||||
var2: str
|
||||
|
||||
def target_method(self) -> None:
|
||||
self.var2 = "test"
|
||||
"""
|
||||
|
||||
expected = """
|
||||
if x is True:
|
||||
class TestClass:
|
||||
var1: int = 42
|
||||
var2: str
|
||||
|
||||
def wrong_method(self) -> None:
|
||||
print("wrong")
|
||||
else:
|
||||
class TestClass:
|
||||
var1: int = 42
|
||||
var2: str
|
||||
|
||||
|
||||
"""
|
||||
|
||||
result = find_containing_classes(dedent(code), {"TestClass.target_method"})
|
||||
print_tree(result)
|
||||
output = create_read_only_context(result)
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
def test_top_level_functions():
|
||||
code = """
|
||||
def target_function(self) -> None:
|
||||
self.var2 = "test"
|
||||
|
||||
def some_function():
|
||||
print("wow")
|
||||
"""
|
||||
|
||||
expected = """"""
|
||||
|
||||
result = find_containing_classes(dedent(code), {"target_function"})
|
||||
output = create_read_only_context(result)
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
def test_module_scope_var():
|
||||
code = """
|
||||
y = 3
|
||||
class TestClass:
|
||||
def __init__(self):
|
||||
self.x = 42
|
||||
|
||||
def __str__(self):
|
||||
return f"Value: {self.x}"
|
||||
|
||||
def target_method(self):
|
||||
print("stub me")
|
||||
"""
|
||||
|
||||
expected = """
|
||||
y = 3
|
||||
class TestClass:
|
||||
def __init__(self):
|
||||
self.x = 42
|
||||
|
||||
def __str__(self):
|
||||
return f"Value: {self.x}"
|
||||
|
||||
"""
|
||||
|
||||
result = find_containing_classes(dedent(code), {"TestClass.target_method"})
|
||||
output = create_read_only_context(result)
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
def test_module_scope_var():
|
||||
code = """
|
||||
if True:
|
||||
y = 3
|
||||
|
||||
class OtherClass:
|
||||
def this_method(self):
|
||||
print("this method")
|
||||
def __init__(self):
|
||||
self.y = y
|
||||
class TestClass:
|
||||
def __init__(self):
|
||||
self.x = 42
|
||||
|
||||
def __str__(self):
|
||||
return f"Value: {self.x}"
|
||||
|
||||
def target_method(self):
|
||||
print("stub me")
|
||||
"""
|
||||
|
||||
expected = """
|
||||
if True:
|
||||
y = 3
|
||||
class TestClass:
|
||||
def __init__(self):
|
||||
self.x = 42
|
||||
|
||||
def __str__(self):
|
||||
return f"Value: {self.x}"
|
||||
|
||||
"""
|
||||
|
||||
result = find_containing_classes(dedent(code), {"TestClass.target_method"})
|
||||
output = create_read_only_context(result)
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
207
tests/test_create_rw_context.py
Normal file
207
tests/test_create_rw_context.py
Normal file
|
|
@ -0,0 +1,207 @@
|
|||
from textwrap import dedent
|
||||
|
||||
import pytest
|
||||
from codeflash.optimization.cst_context import create_read_write_context, find_containing_classes, print_tree
|
||||
|
||||
|
||||
def test_simple_function():
|
||||
code = """
|
||||
def target_function():
|
||||
x = 1
|
||||
y = 2
|
||||
return x + y
|
||||
"""
|
||||
root = find_containing_classes(dedent(code), {"target_function"})
|
||||
result = create_read_write_context(root)
|
||||
|
||||
expected = dedent("""
|
||||
def target_function():
|
||||
x = 1
|
||||
y = 2
|
||||
return x + y
|
||||
""")
|
||||
assert result.strip() == expected.strip()
|
||||
|
||||
|
||||
def test_class_method():
|
||||
code = """
|
||||
class MyClass:
|
||||
def target_function(self):
|
||||
x = 1
|
||||
y = 2
|
||||
return x + y
|
||||
"""
|
||||
root = find_containing_classes(dedent(code), {"MyClass.target_function"})
|
||||
result = create_read_write_context(root)
|
||||
|
||||
expected = dedent("""
|
||||
class MyClass:
|
||||
def target_function(self):
|
||||
x = 1
|
||||
y = 2
|
||||
return x + y
|
||||
""")
|
||||
assert result.strip() == expected.strip()
|
||||
|
||||
|
||||
def test_class_with_attributes():
|
||||
code = """
|
||||
class MyClass:
|
||||
x: int = 1
|
||||
y: str = "hello"
|
||||
|
||||
def target_method(self):
|
||||
return self.x + 42
|
||||
|
||||
def other_method(self):
|
||||
print("this should be excluded")
|
||||
"""
|
||||
root = find_containing_classes(dedent(code), {"MyClass.target_method"})
|
||||
result = create_read_write_context(root)
|
||||
|
||||
expected = dedent("""
|
||||
class MyClass:
|
||||
|
||||
def target_method(self):
|
||||
return self.x + 42
|
||||
""")
|
||||
assert result.strip() == expected.strip()
|
||||
|
||||
|
||||
def test_basic_class_structure():
|
||||
"""Test that nested classes are ignored for target function search."""
|
||||
code = """
|
||||
class Outer:
|
||||
x = 1
|
||||
def target_method(self):
|
||||
return 42
|
||||
|
||||
class Inner:
|
||||
y = 2
|
||||
def not_findable(self):
|
||||
return 42
|
||||
"""
|
||||
root = find_containing_classes(dedent(code), {"Outer.target_method"})
|
||||
result = create_read_write_context(root)
|
||||
|
||||
expected = dedent("""
|
||||
class Outer:
|
||||
def target_method(self):
|
||||
return 42
|
||||
""")
|
||||
assert result.strip() == expected.strip()
|
||||
|
||||
|
||||
def test_top_level_targets():
|
||||
code = """
|
||||
class OuterClass:
|
||||
x = 1
|
||||
def method1(self):
|
||||
return self.x
|
||||
|
||||
def target_function():
|
||||
return 42
|
||||
"""
|
||||
root = find_containing_classes(dedent(code), {"target_function"})
|
||||
result = create_read_write_context(root)
|
||||
|
||||
expected = dedent("""
|
||||
def target_function():
|
||||
return 42
|
||||
""")
|
||||
assert result.strip() == expected.strip()
|
||||
|
||||
|
||||
def test_multiple_top_level_classes():
|
||||
code = """
|
||||
class ClassA:
|
||||
def process(self):
|
||||
return "A"
|
||||
|
||||
class ClassB:
|
||||
def process(self):
|
||||
return "B"
|
||||
|
||||
class ClassC:
|
||||
def process(self):
|
||||
return "C"
|
||||
"""
|
||||
root = find_containing_classes(dedent(code), {"ClassA.process", "ClassC.process"})
|
||||
result = create_read_write_context(root)
|
||||
|
||||
expected = dedent("""
|
||||
class ClassA:
|
||||
def process(self):
|
||||
return "A"
|
||||
|
||||
class ClassC:
|
||||
def process(self):
|
||||
return "C"
|
||||
""")
|
||||
assert result.strip() == expected.strip()
|
||||
|
||||
|
||||
def test_try_except_structure():
|
||||
code = """
|
||||
try:
|
||||
class TargetClass:
|
||||
def target_method(self):
|
||||
return 42
|
||||
except ValueError:
|
||||
class ErrorClass:
|
||||
def handle_error(self):
|
||||
print("error")
|
||||
"""
|
||||
root = find_containing_classes(dedent(code), {"TargetClass.target_method"})
|
||||
print_tree(root)
|
||||
result = create_read_write_context(root)
|
||||
|
||||
expected = dedent("""
|
||||
try:
|
||||
class TargetClass:
|
||||
def target_method(self):
|
||||
return 42
|
||||
except ValueError:
|
||||
class ErrorClass:
|
||||
def handle_error(self):
|
||||
print("error")
|
||||
""")
|
||||
assert result.strip() == expected.strip()
|
||||
|
||||
|
||||
def test_dunder_method():
|
||||
code = """
|
||||
class MyClass:
|
||||
def __init__(self):
|
||||
self.x = 1
|
||||
|
||||
def other_method(self):
|
||||
return "other"
|
||||
|
||||
def target_method(self):
|
||||
return f"Value: {self.x}"
|
||||
"""
|
||||
root = find_containing_classes(dedent(code), {"MyClass.target_method"})
|
||||
result = create_read_write_context(root)
|
||||
|
||||
expected = dedent("""
|
||||
class MyClass:
|
||||
|
||||
def target_method(self):
|
||||
return f"Value: {self.x}"
|
||||
""")
|
||||
assert result.strip() == expected.strip()
|
||||
|
||||
|
||||
def test_no_targets_found():
|
||||
code = """
|
||||
class MyClass:
|
||||
def method(self):
|
||||
pass
|
||||
|
||||
class Inner:
|
||||
def target(self):
|
||||
pass
|
||||
"""
|
||||
with pytest.raises(ValueError, match="No target functions found in the provided code"):
|
||||
find_containing_classes(dedent(code), {"MyClass.Inner.target"})
|
||||
623
tests/test_retriever.py
Normal file
623
tests/test_retriever.py
Normal file
|
|
@ -0,0 +1,623 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import tempfile
|
||||
from argparse import Namespace
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from textwrap import dedent
|
||||
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.models.models import FunctionParent
|
||||
from codeflash.optimization.optimizer import Optimizer
|
||||
from codeflash.optimization.retriever import get_code_optimization_context
|
||||
|
||||
|
||||
class HelperClass:
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
|
||||
def innocent_bystander(self):
|
||||
pass
|
||||
|
||||
def helper_method(self):
|
||||
return self.name
|
||||
|
||||
|
||||
class MainClass:
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
|
||||
def main_method(self):
|
||||
return HelperClass(self.name).helper_method()
|
||||
|
||||
|
||||
class Graph:
|
||||
def __init__(self, vertices):
|
||||
self.graph = defaultdict(list)
|
||||
self.V = vertices # No. of vertices
|
||||
|
||||
def addEdge(self, u, v):
|
||||
self.graph[u].append(v)
|
||||
|
||||
def topologicalSortUtil(self, v, visited, stack):
|
||||
visited[v] = True
|
||||
|
||||
for i in self.graph[v]:
|
||||
if visited[i] == False:
|
||||
self.topologicalSortUtil(i, visited, stack)
|
||||
|
||||
stack.insert(0, v)
|
||||
|
||||
def topologicalSort(self):
|
||||
visited = [False] * self.V
|
||||
stack = []
|
||||
|
||||
for i in range(self.V):
|
||||
if visited[i] == False:
|
||||
self.topologicalSortUtil(i, visited, stack)
|
||||
|
||||
# Print contents of stack
|
||||
return stack
|
||||
|
||||
|
||||
def test_code_replacement10() -> None:
|
||||
file_path = Path(__file__).resolve()
|
||||
|
||||
func_top_optimize = FunctionToOptimize(
|
||||
function_name="main_method", file_path=file_path, parents=[FunctionParent("MainClass", "ClassDef")]
|
||||
)
|
||||
original_code = file_path.read_text()
|
||||
read_write_context, read_only_context = get_code_optimization_context(
|
||||
function_to_optimize=func_top_optimize, project_root_path=file_path.parent, original_source_code=original_code
|
||||
)
|
||||
|
||||
expected_read_write_context = """
|
||||
from __future__ import annotations
|
||||
|
||||
class HelperClass:
|
||||
|
||||
def helper_method(self):
|
||||
return self.name
|
||||
|
||||
|
||||
class MainClass:
|
||||
|
||||
def main_method(self):
|
||||
return HelperClass(self.name).helper_method()
|
||||
"""
|
||||
expected_read_only_context = f"""
|
||||
```python:{file_path}
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
class HelperClass:
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
|
||||
|
||||
class MainClass:
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
|
||||
```
|
||||
"""
|
||||
assert read_write_context.strip() == dedent(expected_read_write_context).strip()
|
||||
assert read_only_context.strip() == dedent(expected_read_only_context).strip()
|
||||
|
||||
|
||||
def test_class_method_dependencies() -> None:
|
||||
file_path = Path(__file__).resolve()
|
||||
|
||||
function_to_optimize = FunctionToOptimize(
|
||||
function_name="topologicalSort",
|
||||
file_path=str(file_path),
|
||||
parents=[FunctionParent(name="Graph", type="ClassDef")],
|
||||
starting_line=None,
|
||||
ending_line=None,
|
||||
)
|
||||
|
||||
with open(file_path) as f:
|
||||
original_code = f.read()
|
||||
read_write_context, read_only_context = get_code_optimization_context(
|
||||
function_to_optimize, file_path.parent.resolve(), original_code
|
||||
)
|
||||
|
||||
expected_read_write_context = """
|
||||
from __future__ import annotations
|
||||
|
||||
class Graph:
|
||||
|
||||
def topologicalSortUtil(self, v, visited, stack):
|
||||
visited[v] = True
|
||||
|
||||
for i in self.graph[v]:
|
||||
if visited[i] == False:
|
||||
self.topologicalSortUtil(i, visited, stack)
|
||||
|
||||
stack.insert(0, v)
|
||||
|
||||
def topologicalSort(self):
|
||||
visited = [False] * self.V
|
||||
stack = []
|
||||
|
||||
for i in range(self.V):
|
||||
if visited[i] == False:
|
||||
self.topologicalSortUtil(i, visited, stack)
|
||||
|
||||
# Print contents of stack
|
||||
return stack
|
||||
|
||||
"""
|
||||
expected_read_only_context = f"""
|
||||
```python:{file_path}
|
||||
from __future__ import annotations
|
||||
from collections import defaultdict
|
||||
|
||||
|
||||
class Graph:
|
||||
def __init__(self, vertices):
|
||||
self.graph = defaultdict(list)
|
||||
self.V = vertices # No. of vertices
|
||||
|
||||
```
|
||||
"""
|
||||
assert read_write_context.strip() == dedent(expected_read_write_context).strip()
|
||||
assert read_only_context.strip() == dedent(expected_read_only_context).strip()
|
||||
|
||||
|
||||
def test_bubble_sort_helper() -> None:
|
||||
path_to_fto = (
|
||||
Path(__file__).resolve().parent.parent
|
||||
/ "code_to_optimize"
|
||||
/ "code_directories"
|
||||
/ "retriever"
|
||||
/ "bubble_sort_imported.py"
|
||||
)
|
||||
|
||||
function_to_optimize = FunctionToOptimize(
|
||||
function_name="sort_from_another_file",
|
||||
file_path=str(path_to_fto),
|
||||
parents=[],
|
||||
starting_line=None,
|
||||
ending_line=None,
|
||||
)
|
||||
|
||||
with open(path_to_fto) as f:
|
||||
original_code = f.read()
|
||||
read_write_context, read_only_context = get_code_optimization_context(
|
||||
function_to_optimize, Path(__file__).resolve().parent.parent, original_code
|
||||
)
|
||||
|
||||
expected_read_write_context = """
|
||||
from bubble_sort_with_math import sorter
|
||||
import math
|
||||
|
||||
def sort_from_another_file(arr):
|
||||
sorted_arr = sorter(arr)
|
||||
return sorted_arr
|
||||
|
||||
|
||||
|
||||
def sorter(arr):
|
||||
arr.sort()
|
||||
x = math.sqrt(2)
|
||||
print(x)
|
||||
return arr
|
||||
|
||||
"""
|
||||
expected_read_only_context = ""
|
||||
|
||||
assert read_write_context.strip() == dedent(expected_read_write_context).strip()
|
||||
assert read_only_context.strip() == dedent(expected_read_only_context).strip()
|
||||
|
||||
|
||||
def test_repo_helper() -> None:
|
||||
path_to_file = (
|
||||
Path(__file__).resolve().parent.parent / "code_to_optimize" / "code_directories" / "retriever" / "main.py"
|
||||
)
|
||||
path_to_utils = (
|
||||
Path(__file__).resolve().parent.parent / "code_to_optimize" / "code_directories" / "retriever" / "utils.py"
|
||||
)
|
||||
function_to_optimize = FunctionToOptimize(
|
||||
function_name="fetch_and_process_data",
|
||||
file_path=str(path_to_file),
|
||||
parents=[],
|
||||
starting_line=None,
|
||||
ending_line=None,
|
||||
)
|
||||
|
||||
with open(path_to_file) as f:
|
||||
original_code = f.read()
|
||||
read_write_context, read_only_context = get_code_optimization_context(
|
||||
function_to_optimize, Path(__file__).resolve().parent.parent, original_code
|
||||
)
|
||||
expected_read_write_context = """
|
||||
import requests
|
||||
from globals import API_URL
|
||||
from utils import DataProcessor
|
||||
|
||||
def fetch_and_process_data():
|
||||
# Use the global variable for the request
|
||||
response = requests.get(API_URL)
|
||||
response.raise_for_status()
|
||||
|
||||
raw_data = response.text
|
||||
|
||||
# Use code from another file (utils.py)
|
||||
processor = DataProcessor()
|
||||
processed = processor.process_data(raw_data)
|
||||
processed = processor.add_prefix(processed)
|
||||
|
||||
return processed
|
||||
|
||||
|
||||
|
||||
class DataProcessor:
|
||||
|
||||
def process_data(self, raw_data: str) -> str:
|
||||
\"\"\"Process raw data by converting it to uppercase.\"\"\"
|
||||
return raw_data.upper()
|
||||
|
||||
def add_prefix(self, data: str, prefix: str = "PREFIX_") -> str:
|
||||
\"\"\"Add a prefix to the processed data.\"\"\"
|
||||
return prefix + data
|
||||
"""
|
||||
expected_read_only_context = f"""
|
||||
```python:/Users/alvinryanputra/cf/codeflash/cli/code_to_optimize/code_directories/retriever/main.py
|
||||
if __name__ == "__main__":
|
||||
result = fetch_and_process_data()
|
||||
print("Processed data:", result)
|
||||
|
||||
```
|
||||
```python:{path_to_utils}
|
||||
import math
|
||||
|
||||
|
||||
class DataProcessor:
|
||||
\"\"\"A class for processing data.\"\"\"
|
||||
|
||||
number = 1
|
||||
|
||||
def __init__(self, default_prefix: str = "PREFIX_"):
|
||||
\"\"\"Initialize the DataProcessor with a default prefix.\"\"\"
|
||||
self.default_prefix = default_prefix
|
||||
self.number += math.log(self.number)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
\"\"\"Return a string representation of the DataProcessor.\"\"\"
|
||||
return f"DataProcessor(default_prefix={{self.default_prefix!r}})"
|
||||
|
||||
```
|
||||
"""
|
||||
assert read_write_context.strip() == dedent(expected_read_write_context).strip()
|
||||
assert read_only_context.strip() == dedent(expected_read_only_context).strip()
|
||||
|
||||
|
||||
def test_flavio_typed_code_helper() -> None:
|
||||
code = '''
|
||||
|
||||
_P = ParamSpec("_P")
|
||||
_KEY_T = TypeVar("_KEY_T")
|
||||
_STORE_T = TypeVar("_STORE_T")
|
||||
class AbstractCacheBackend(CacheBackend, Protocol[_KEY_T, _STORE_T]):
|
||||
"""Interface for cache backends used by the persistent cache decorator."""
|
||||
|
||||
def __init__(self) -> None: ...
|
||||
|
||||
def hash_key(
|
||||
self,
|
||||
*,
|
||||
func: Callable[_P, Any],
|
||||
args: tuple[Any, ...],
|
||||
kwargs: dict[str, Any],
|
||||
) -> tuple[str, _KEY_T]: ...
|
||||
|
||||
def encode(self, *, data: Any) -> _STORE_T: # noqa: ANN401
|
||||
...
|
||||
|
||||
def decode(self, *, data: _STORE_T) -> Any: # noqa: ANN401
|
||||
...
|
||||
|
||||
def get(self, *, key: tuple[str, _KEY_T]) -> tuple[datetime.datetime, _STORE_T] | None: ...
|
||||
|
||||
def delete(self, *, key: tuple[str, _KEY_T]) -> None: ...
|
||||
|
||||
def put(self, *, key: tuple[str, _KEY_T], data: _STORE_T) -> None: ...
|
||||
|
||||
def get_cache_or_call(
|
||||
self,
|
||||
*,
|
||||
func: Callable[_P, Any],
|
||||
args: tuple[Any, ...],
|
||||
kwargs: dict[str, Any],
|
||||
lifespan: datetime.timedelta,
|
||||
) -> Any: # noqa: ANN401
|
||||
"""
|
||||
Retrieve the cached results for a function call.
|
||||
|
||||
Args:
|
||||
----
|
||||
func (Callable[..., _R]): The function to retrieve cached results for.
|
||||
args (tuple[Any, ...]): The positional arguments passed to the function.
|
||||
kwargs (dict[str, Any]): The keyword arguments passed to the function.
|
||||
lifespan (datetime.timedelta): The maximum age of the cached results.
|
||||
|
||||
Returns:
|
||||
-------
|
||||
_R: The cached results, if available.
|
||||
|
||||
"""
|
||||
if os.environ.get("NO_CACHE"):
|
||||
return func(*args, **kwargs)
|
||||
|
||||
try:
|
||||
key = self.hash_key(func=func, args=args, kwargs=kwargs)
|
||||
except: # noqa: E722
|
||||
# If we can't create a cache key, we should just call the function.
|
||||
logging.warning("Failed to hash cache key for function: %s", func)
|
||||
return func(*args, **kwargs)
|
||||
result_pair = self.get(key=key)
|
||||
|
||||
if result_pair is not None:
|
||||
cached_time, result = result_pair
|
||||
if not os.environ.get("RE_CACHE") and (
|
||||
datetime.datetime.now() < (cached_time + lifespan) # noqa: DTZ005
|
||||
):
|
||||
try:
|
||||
return self.decode(data=result)
|
||||
except CacheBackendDecodeError as e:
|
||||
logging.warning("Failed to decode cache data: %s", e)
|
||||
# If decoding fails we will treat this as a cache miss.
|
||||
# This might happens if underlying class definition of the data changes.
|
||||
self.delete(key=key)
|
||||
result = func(*args, **kwargs)
|
||||
try:
|
||||
self.put(key=key, data=self.encode(data=result))
|
||||
except CacheBackendEncodeError as e:
|
||||
logging.warning("Failed to encode cache data: %s", e)
|
||||
# If encoding fails, we should still return the result.
|
||||
return result
|
||||
|
||||
_P = ParamSpec("_P")
|
||||
_R = TypeVar("_R")
|
||||
_CacheBackendT = TypeVar("_CacheBackendT", bound=CacheBackend)
|
||||
|
||||
|
||||
class _PersistentCache(Generic[_P, _R, _CacheBackendT]):
|
||||
"""
|
||||
A decorator class that provides persistent caching functionality for a function.
|
||||
|
||||
Args:
|
||||
----
|
||||
func (Callable[_P, _R]): The function to be decorated.
|
||||
duration (datetime.timedelta): The duration for which the cached results should be considered valid.
|
||||
backend (_backend): The backend storage for the cached results.
|
||||
|
||||
Attributes:
|
||||
----------
|
||||
__wrapped__ (Callable[_P, _R]): The wrapped function.
|
||||
__duration__ (datetime.timedelta): The duration for which the cached results should be considered valid.
|
||||
__backend__ (_backend): The backend storage for the cached results.
|
||||
|
||||
""" # noqa: E501
|
||||
|
||||
__wrapped__: Callable[_P, _R]
|
||||
__duration__: datetime.timedelta
|
||||
__backend__: _CacheBackendT
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
func: Callable[_P, _R],
|
||||
duration: datetime.timedelta,
|
||||
) -> None:
|
||||
self.__wrapped__ = func
|
||||
self.__duration__ = duration
|
||||
self.__backend__ = AbstractCacheBackend()
|
||||
functools.update_wrapper(self, func)
|
||||
|
||||
def cache_clear(self) -> None:
|
||||
"""Clears the cache for the wrapped function."""
|
||||
self.__backend__.del_func_cache(func=self.__wrapped__)
|
||||
|
||||
def no_cache_call(self, *args: _P.args, **kwargs: _P.kwargs) -> _R:
|
||||
"""
|
||||
Calls the wrapped function without using the cache.
|
||||
|
||||
Args:
|
||||
----
|
||||
*args (_P.args): Positional arguments for the wrapped function.
|
||||
**kwargs (_P.kwargs): Keyword arguments for the wrapped function.
|
||||
|
||||
Returns:
|
||||
-------
|
||||
_R: The result of the wrapped function.
|
||||
|
||||
"""
|
||||
return self.__wrapped__(*args, **kwargs)
|
||||
|
||||
def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R:
|
||||
"""
|
||||
Calls the wrapped function, either using the cache or bypassing it based on environment variables.
|
||||
|
||||
Args:
|
||||
----
|
||||
*args (_P.args): Positional arguments for the wrapped function.
|
||||
**kwargs (_P.kwargs): Keyword arguments for the wrapped function.
|
||||
|
||||
Returns:
|
||||
-------
|
||||
_R: The result of the wrapped function.
|
||||
|
||||
""" # noqa: E501
|
||||
if "NO_CACHE" in os.environ:
|
||||
return self.__wrapped__(*args, **kwargs)
|
||||
os.makedirs(DEFAULT_CACHE_LOCATION, exist_ok=True)
|
||||
return self.__backend__.get_cache_or_call(
|
||||
func=self.__wrapped__,
|
||||
args=args,
|
||||
kwargs=kwargs,
|
||||
lifespan=self.__duration__,
|
||||
)
|
||||
'''
|
||||
with tempfile.NamedTemporaryFile(mode="w") as f:
|
||||
f.write(code)
|
||||
f.flush()
|
||||
file_path = Path(f.name).resolve()
|
||||
opt = Optimizer(
|
||||
Namespace(
|
||||
project_root=file_path.parent.resolve(),
|
||||
disable_telemetry=True,
|
||||
tests_root="tests",
|
||||
test_framework="pytest",
|
||||
pytest_cmd="pytest",
|
||||
experiment_id=None,
|
||||
test_project_root=Path().resolve(),
|
||||
)
|
||||
)
|
||||
function_to_optimize = FunctionToOptimize(
|
||||
function_name="__call__",
|
||||
file_path=file_path,
|
||||
parents=[FunctionParent(name="_PersistentCache", type="ClassDef")],
|
||||
starting_line=None,
|
||||
ending_line=None,
|
||||
)
|
||||
with open(file_path) as f:
|
||||
original_code = f.read()
|
||||
read_write_context, read_only_context = get_code_optimization_context(
|
||||
function_to_optimize, opt.args.project_root, original_code
|
||||
)
|
||||
expected_read_write_context = """
|
||||
class AbstractCacheBackend(CacheBackend, Protocol[_KEY_T, _STORE_T]):
|
||||
|
||||
def get_cache_or_call(
|
||||
self,
|
||||
*,
|
||||
func: Callable[_P, Any],
|
||||
args: tuple[Any, ...],
|
||||
kwargs: dict[str, Any],
|
||||
lifespan: datetime.timedelta,
|
||||
) -> Any: # noqa: ANN401
|
||||
\"\"\"
|
||||
Retrieve the cached results for a function call.
|
||||
|
||||
Args:
|
||||
----
|
||||
func (Callable[..., _R]): The function to retrieve cached results for.
|
||||
args (tuple[Any, ...]): The positional arguments passed to the function.
|
||||
kwargs (dict[str, Any]): The keyword arguments passed to the function.
|
||||
lifespan (datetime.timedelta): The maximum age of the cached results.
|
||||
|
||||
Returns:
|
||||
-------
|
||||
_R: The cached results, if available.
|
||||
|
||||
\"\"\"
|
||||
if os.environ.get("NO_CACHE"):
|
||||
return func(*args, **kwargs)
|
||||
|
||||
try:
|
||||
key = self.hash_key(func=func, args=args, kwargs=kwargs)
|
||||
except: # noqa: E722
|
||||
# If we can't create a cache key, we should just call the function.
|
||||
logging.warning("Failed to hash cache key for function: %s", func)
|
||||
return func(*args, **kwargs)
|
||||
result_pair = self.get(key=key)
|
||||
|
||||
if result_pair is not None:
|
||||
cached_time, result = result_pair
|
||||
if not os.environ.get("RE_CACHE") and (
|
||||
datetime.datetime.now() < (cached_time + lifespan) # noqa: DTZ005
|
||||
):
|
||||
try:
|
||||
return self.decode(data=result)
|
||||
except CacheBackendDecodeError as e:
|
||||
logging.warning("Failed to decode cache data: %s", e)
|
||||
# If decoding fails we will treat this as a cache miss.
|
||||
# This might happens if underlying class definition of the data changes.
|
||||
self.delete(key=key)
|
||||
result = func(*args, **kwargs)
|
||||
try:
|
||||
self.put(key=key, data=self.encode(data=result))
|
||||
except CacheBackendEncodeError as e:
|
||||
logging.warning("Failed to encode cache data: %s", e)
|
||||
# If encoding fails, we should still return the result.
|
||||
return result
|
||||
|
||||
|
||||
class _PersistentCache(Generic[_P, _R, _CacheBackendT]):
|
||||
|
||||
def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R:
|
||||
\"\"\"
|
||||
Calls the wrapped function, either using the cache or bypassing it based on environment variables.
|
||||
|
||||
Args:
|
||||
----
|
||||
*args (_P.args): Positional arguments for the wrapped function.
|
||||
**kwargs (_P.kwargs): Keyword arguments for the wrapped function.
|
||||
|
||||
Returns:
|
||||
-------
|
||||
_R: The result of the wrapped function.
|
||||
|
||||
\"\"\" # noqa: E501
|
||||
if "NO_CACHE" in os.environ:
|
||||
return self.__wrapped__(*args, **kwargs)
|
||||
os.makedirs(DEFAULT_CACHE_LOCATION, exist_ok=True)
|
||||
return self.__backend__.get_cache_or_call(
|
||||
func=self.__wrapped__,
|
||||
args=args,
|
||||
kwargs=kwargs,
|
||||
lifespan=self.__duration__,
|
||||
)
|
||||
"""
|
||||
expected_read_only_context = f'''
|
||||
```python:{file_path}
|
||||
_P = ParamSpec("_P")
|
||||
_KEY_T = TypeVar("_KEY_T")
|
||||
_STORE_T = TypeVar("_STORE_T")
|
||||
class AbstractCacheBackend(CacheBackend, Protocol[_KEY_T, _STORE_T]):
|
||||
"""Interface for cache backends used by the persistent cache decorator."""
|
||||
|
||||
def __init__(self) -> None: ...
|
||||
|
||||
_P = ParamSpec("_P")
|
||||
_R = TypeVar("_R")
|
||||
_CacheBackendT = TypeVar("_CacheBackendT", bound=CacheBackend)
|
||||
|
||||
|
||||
class _PersistentCache(Generic[_P, _R, _CacheBackendT]):
|
||||
"""
|
||||
A decorator class that provides persistent caching functionality for a function.
|
||||
|
||||
Args:
|
||||
----
|
||||
func (Callable[_P, _R]): The function to be decorated.
|
||||
duration (datetime.timedelta): The duration for which the cached results should be considered valid.
|
||||
backend (_backend): The backend storage for the cached results.
|
||||
|
||||
Attributes:
|
||||
----------
|
||||
__wrapped__ (Callable[_P, _R]): The wrapped function.
|
||||
__duration__ (datetime.timedelta): The duration for which the cached results should be considered valid.
|
||||
__backend__ (_backend): The backend storage for the cached results.
|
||||
|
||||
""" # noqa: E501
|
||||
|
||||
__wrapped__: Callable[_P, _R]
|
||||
__duration__: datetime.timedelta
|
||||
__backend__: _CacheBackendT
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
func: Callable[_P, _R],
|
||||
duration: datetime.timedelta,
|
||||
) -> None:
|
||||
self.__wrapped__ = func
|
||||
self.__duration__ = duration
|
||||
self.__backend__ = AbstractCacheBackend()
|
||||
functools.update_wrapper(self, func)
|
||||
|
||||
```
|
||||
'''
|
||||
assert read_write_context.strip() == dedent(expected_read_write_context).strip()
|
||||
assert read_only_context.strip() == dedent(expected_read_only_context).strip()
|
||||
Loading…
Reference in a new issue