First part of a refactor. cst_context and retriever implements logic to retrieve helper functions and build up the necessary CST trees/code as read_write and read_only context. In this commit is also a bunch of test files.

This commit is contained in:
Alvin Ryanputra 2024-12-17 11:30:07 -08:00
parent 91ecf113aa
commit 8b3ce0e9b9
17 changed files with 2380 additions and 11 deletions

View file

@ -0,0 +1,6 @@
from bubble_sort_with_math import sorter
def sort_from_another_file(arr):
sorted_arr = sorter(arr)
return sorted_arr

View file

@ -0,0 +1,8 @@
import math
def sorter(arr):
arr.sort()
x = math.sqrt(2)
print(x)
return arr

View file

@ -0,0 +1,2 @@
# Define a global variable
API_URL = "https://api.example.com/data"

View file

@ -0,0 +1,23 @@
import requests # Third-party library
from globals import API_URL # Global variable defined in another file
from utils import DataProcessor
def fetch_and_process_data():
# Use the global variable for the request
response = requests.get(API_URL)
response.raise_for_status()
raw_data = response.text
# Use code from another file (utils.py)
processor = DataProcessor()
processed = processor.process_data(raw_data)
processed = processor.add_prefix(processed)
return processed
if __name__ == "__main__":
result = fetch_and_process_data()
print("Processed data:", result)

View file

@ -0,0 +1,27 @@
import math
class DataProcessor:
"""A class for processing data."""
number = 1
def __init__(self, default_prefix: str = "PREFIX_"):
"""Initialize the DataProcessor with a default prefix."""
self.default_prefix = default_prefix
self.number += math.log(self.number)
def __repr__(self) -> str:
"""Return a string representation of the DataProcessor."""
return f"DataProcessor(default_prefix={self.default_prefix!r})"
def process_data(self, raw_data: str) -> str:
"""Process raw data by converting it to uppercase."""
return raw_data.upper()
def add_prefix(self, data: str, prefix: str = "PREFIX_") -> str:
"""Add a prefix to the processed data."""
return prefix + data
def do_something(self):
print("something")

View file

@ -74,6 +74,7 @@ class AiServiceClient:
def optimize_python_code(
self,
source_code: str,
dependency_code: str,
trace_id: str,
num_candidates: int = 10,
experiment_metadata: ExperimentMetadata | None = None,
@ -83,15 +84,20 @@ class AiServiceClient:
Parameters
----------
- source_code (str): The python code to optimize.
- num_variants (int): Number of optimization variants to generate. Default is 10.
- read_write_context (str) : The python code to be used as read-write context.
- read_only_context (str): The python code to be used as read-only context.
- trace_id (str): Trace id of optimization run
- num_candidates (int): Number of optimization variants to generate. Default is 10.
- experiment_metadata (Optional[ExperimentalMetadata, None]): Any available experiment metadata for this optimization
Returns
-------
- List[Optimization]: A list of Optimization objects.
- List[OptimizationCandidate]: A list of Optimization Candidates.
"""
payload = {
"source_code": source_code,
"dependency_code": dependency_code,
"num_variants": num_candidates,
"trace_id": trace_id,
"python_version": platform.python_version(),

View file

@ -104,6 +104,72 @@ def add_needed_imports_from_module(
return dst_module_code
def add_needed_imports_from_module_2(
src_module_code: str,
dst_module_code: str,
src_path: Path,
dst_path: Path,
project_root: Path,
helper_functions_fully_qualified_names: list[str] | None = None,
) -> str:
"""Copy of add_needed_imports_from_module. will remove in a future refactor. This function simply changes the 'helper_functions' argument"""
src_module_code = delete___future___aliased_imports(src_module_code)
if helper_functions_fully_qualified_names is None:
helper_functions_fully_qualified_names = []
src_module_and_package: ModuleNameAndPackage = calculate_module_and_package(project_root, src_path)
dst_module_and_package: ModuleNameAndPackage = calculate_module_and_package(project_root, dst_path)
dst_context: CodemodContext = CodemodContext(
filename=src_path.name,
full_module_name=dst_module_and_package.name,
full_package_name=dst_module_and_package.package,
)
gatherer: GatherImportsVisitor = GatherImportsVisitor(
CodemodContext(
filename=src_path.name,
full_module_name=src_module_and_package.name,
full_package_name=src_module_and_package.package,
)
)
cst.parse_module(src_module_code).visit(gatherer)
try:
for mod in gatherer.module_imports:
AddImportsVisitor.add_needed_import(dst_context, mod)
RemoveImportsVisitor.remove_unused_import(dst_context, mod)
for mod, obj_seq in gatherer.object_mapping.items():
for obj in obj_seq:
if f"{mod}.{obj}" in helper_functions_fully_qualified_names:
continue # Skip adding imports for helper functions already in the context
AddImportsVisitor.add_needed_import(dst_context, mod, obj)
RemoveImportsVisitor.remove_unused_import(dst_context, mod, obj)
except Exception as e:
logger.exception(f"Error adding imports to destination module code: {e}")
return dst_module_code
for mod, asname in gatherer.module_aliases.items():
AddImportsVisitor.add_needed_import(dst_context, mod, asname=asname)
RemoveImportsVisitor.remove_unused_import(dst_context, mod, asname=asname)
for mod, alias_pairs in gatherer.alias_mapping.items():
for alias_pair in alias_pairs:
if f"{mod}.{alias_pair[0]}" in helper_functions_fully_qualified_names:
continue
AddImportsVisitor.add_needed_import(dst_context, mod, alias_pair[0], asname=alias_pair[1])
RemoveImportsVisitor.remove_unused_import(dst_context, mod, alias_pair[0], asname=alias_pair[1])
try:
parsed_module = cst.parse_module(dst_module_code)
except cst.ParserSyntaxError as e:
logger.exception(f"Syntax error in destination module code: {e}")
return dst_module_code # Return the original code if there's a syntax error
try:
transformed_module = AddImportsVisitor(dst_context).transform_module(parsed_module)
transformed_module = RemoveImportsVisitor(dst_context).transform_module(transformed_module)
return transformed_module.code.lstrip("\n")
except Exception as e:
logger.exception(f"Error adding imports to destination module code: {e}")
return dst_module_code
def get_code(functions_to_optimize: list[FunctionToOptimize]) -> tuple[str | None, set[tuple[str, str]]]:
"""Return the code for a function or methods in a Python module. functions_to_optimize is either a singleton
FunctionToOptimize instance, which represents either a function at the module level or a method of a class at the

View file

@ -54,6 +54,8 @@ class BestOptimization(BaseModel):
class CodeOptimizationContext(BaseModel):
code_to_optimize_with_helpers: str
read_write_context_code: str = ""
read_only_context_code: str = ""
contextual_dunder_methods: set[tuple[str, str]]
helper_functions: list[FunctionSource]
preexisting_objects: list[tuple[str, list[FunctionParent]]]

View file

@ -0,0 +1,260 @@
from __future__ import annotations
import libcst as cst
from pydantic import BaseModel, Field
class CSTContextNode(BaseModel):
"""A node in the context tree, representing a single node in the CST. This context tree is a part of the main CST tree that contains nodes that lead to a target function.
The corresponding cst_node is stored so that the cst can be rebuilt flexibly, based on whatever information is needed.
In the future, this tree can be used as the code replacer, since we know where the target functions are located in the tree.
"""
cst_node: cst.CSTNode | None
children: dict[str, list[CSTContextNode] | CSTContextNode] = Field(default_factory=dict)
is_target_function: bool = False
target_functions: set[str] = Field(default_factory=set)
class Config:
arbitrary_types_allowed = True
def add_child(self, section: str, child: CSTContextNode):
"""Add a child node to the specified section. sections are either body, orelse, finalbody, or handlers."""
original_section = getattr(self.cst_node, section, None)
if isinstance(original_section, (list, tuple)):
if section not in self.children:
self.children[section] = []
self.children[section].append(child)
else:
self.children[section] = child
def build_context_tree(context_node: CSTContextNode, prefix: str = "") -> bool:
"""Recursively builds a context tree from a CST, tracking target functions and their containing structures.
Args:
context_node: Current node in the context tree
prefix: Prefix to add to the target function names to create qualified names
Returns:
bool: True if a target function was found in this branch
"""
def process_node(node: cst.CSTNode, section: str) -> bool:
if isinstance(node, cst.ClassDef):
if prefix: # Don't go into nested classes
return False
class_node = CSTContextNode(cst_node=node, target_functions=context_node.target_functions)
if build_context_tree(class_node, f"{prefix}.{node.name.value}" if prefix else node.name.value):
context_node.add_child(section, class_node)
return True
elif isinstance(node, cst.FunctionDef):
qualified_name = f"{prefix}.{node.name.value}" if prefix else node.name.value
if qualified_name in context_node.target_functions:
func_node = CSTContextNode(
cst_node=node, is_target_function=True, target_functions=context_node.target_functions
)
context_node.add_child(section, func_node)
return True
elif isinstance(node, cst.CSTNode):
other_node = CSTContextNode(cst_node=node, target_functions=context_node.target_functions)
if build_context_tree(other_node, prefix):
context_node.add_child(section, other_node)
return True
return False
has_target = False
node = context_node.cst_node
# Check each section directly
for section_name in ["body", "orelse", "finalbody", "handlers"]:
section_content = getattr(node, section_name, None)
if section_content is not None:
if isinstance(section_content, (list, tuple)):
has_target_list = [process_node(section_node, section_name) for section_node in section_content]
has_target |= any(has_target_list)
else:
has_target |= process_node(section_content, section_name)
return has_target
def find_containing_classes(code: str, target_functions: set[str]) -> CSTContextNode:
"""Parse the code and find all class definitions containing the target functions."""
root = CSTContextNode(cst_node=cst.parse_module(code), target_functions=target_functions)
if not build_context_tree(root):
raise ValueError("No target functions found in the provided code")
return root
def create_read_write_context(context_node: CSTContextNode) -> str:
"""Rebuilds a CST tree to create our read-write context"""
def rebuild_node(node: CSTContextNode) -> cst.CSTNode:
if node.is_target_function:
return node.cst_node
updates = {}
for section_name, children in node.children.items():
if isinstance(children, list):
updates[section_name] = [rebuild_node(child) for child in children]
else:
updates[section_name] = rebuild_node(children)
return node.cst_node.with_changes(**updates)
if not isinstance(context_node.cst_node, cst.Module):
raise ValueError("Root of context tree must be a Module node")
rebuilt_module = rebuild_node(context_node)
return rebuilt_module.code
def create_read_only_context(context_node: CSTContextNode) -> str:
"""Creates a read-only version of the context tree where:
- Global variables and typing information are preserved
- Class definitions preserve all information (variables, docstrings, etc.)
- Only dunder methods are preserved
- All other methods (including target functions) are removed completely
"""
def rebuild_non_context_node(node: cst.CSTNode) -> cst.CSTNode | None:
"""Recursively rebuild non-context nodes, preserving all statements except class and function definitions."""
if isinstance(node, (cst.ClassDef, cst.FunctionDef, cst.Import, cst.ImportFrom)):
return None
updates = {}
for section_name in ["body", "orelse", "finalbody", "handlers"]:
section_content = getattr(node, section_name, None)
if section_content is not None:
if isinstance(section_content, (list, tuple)):
rebuilt_children = []
for child in section_content:
rebuilt_child = rebuild_non_context_node(child)
if rebuilt_child is not None:
rebuilt_children.append(rebuilt_child)
if rebuilt_children:
updates[section_name] = rebuilt_children
else:
rebuilt_child = rebuild_non_context_node(section_content)
if rebuilt_child is not None:
updates[section_name] = rebuilt_child
if not updates:
return node
return node.with_changes(**updates)
def rebuild_node(node: CSTContextNode) -> cst.CSTNode | None:
# Remove target functions completely
if node.is_target_function:
return None
# For regular functions, remove them
if isinstance(node.cst_node, cst.FunctionDef):
return None
# For class definitions, preserve structure but remove non-dunder methods
if isinstance(node.cst_node, cst.ClassDef):
return rebuild_class(node)
updates = {}
for section_name in ["body", "orelse", "finalbody", "handlers"]:
section_content = getattr(node.cst_node, section_name, None)
if section_content is not None:
context_children = node.children.get(section_name, [])
context_children = [context_children] if not isinstance(context_children, list) else context_children
if isinstance(section_content, (list, tuple)):
rebuilt_children = []
# Create a map of context children positions
context_positions = {id(child.cst_node): i for i, child in enumerate(context_children)}
for i, child_node in enumerate(section_content):
# If this node is in context, use rebuild_node
if id(child_node) in context_positions:
context_idx = context_positions[id(child_node)]
rebuilt_child = rebuild_node(context_children[context_idx])
else:
# Otherwise use rebuild_non_context_node
rebuilt_child = rebuild_non_context_node(child_node)
if rebuilt_child is not None:
rebuilt_children.append(rebuilt_child)
if rebuilt_children:
updates[section_name] = rebuilt_children
else:
# Single node case
if context_children: # If we have a context child
rebuilt_child = rebuild_node(context_children[0])
else:
rebuilt_child = rebuild_non_context_node(section_content)
if rebuilt_child is not None:
updates[section_name] = rebuilt_child
if not updates:
return None
return node.cst_node.with_changes(**updates)
def is_dunder_method(func_node: cst.FunctionDef) -> bool:
"""Check if a function is a dunder method."""
name = func_node.name.value
return name.startswith("__") and name.endswith("__")
def rebuild_class(node: CSTContextNode) -> cst.ClassDef:
"""Rebuilds a class definition, preserving only structure, variables, and dunder methods."""
class_node = node.cst_node
new_body = []
if isinstance(class_node.body, cst.IndentedBlock):
body_statements = class_node.body.body
else:
body_statements = [class_node.body]
for stmt in body_statements:
if isinstance(stmt, cst.FunctionDef):
if is_dunder_method(stmt):
if f"{class_node.name.value}.{stmt.name.value}" in node.target_functions:
# target function is a dunder method already shown in read-write context
continue
# Preserve only dunder methods
new_body.append(stmt)
else:
# Keep all other class contents (variables, docstrings, etc.)
new_body.append(stmt)
if not new_body:
return None
return class_node.with_changes(body=cst.IndentedBlock(new_body))
if not isinstance(context_node.cst_node, cst.Module):
raise ValueError("Root of context tree must be a Module node")
rebuilt_module = rebuild_node(context_node)
if rebuilt_module is None:
return ""
return rebuilt_module.code
def print_tree(node: CSTContextNode, level: int = 0):
"""Helper function to visualize the full CST node structure recursively"""
indent = " " * level
print(f"\n{indent}CSTContextNode:")
print(f"{indent} is_target_function: {node.is_target_function}")
print(f"{indent} cst_node type: {type(node.cst_node)}")
print(f"{indent} children:")
for section_name, children in node.children.items():
print(f"{indent} {section_name}:")
if isinstance(children, list):
for child in children:
print_tree(child, level + 3)
else:
print_tree(children, level + 3)

View file

@ -78,6 +78,8 @@ if TYPE_CHECKING:
from codeflash.models.models import CoverageData, FunctionCalledInTest, FunctionSource, OptimizedCandidate
from codeflash.optimization import retriever
class Optimizer:
def __init__(self, args: Namespace) -> None:
@ -272,13 +274,15 @@ class Optimizer:
f"Generating new tests and optimizations for function {function_to_optimize.function_name}", transient=True
):
generated_results = self.generate_tests_and_optimizations(
code_context.code_to_optimize_with_helpers,
function_to_optimize,
code_context.helper_functions,
Path(original_module_path),
function_trace_id,
generated_test_paths,
function_to_optimize_ast,
code_to_optimize_with_helpers=code_context.code_to_optimize_with_helpers,
read_write_context_code=code_context.read_write_context_code,
read_only_context_code=code_context.read_only_context_code,
function_to_optimize=function_to_optimize,
helper_functions=code_context.helper_functions,
module_path=Path(original_module_path),
function_trace_id=function_trace_id,
generated_test_paths=generated_test_paths,
function_to_optimize_ast=function_to_optimize_ast,
run_experiment=should_run_experiment,
)
@ -709,9 +713,16 @@ class Optimizer:
)
preexisting_objects = find_preexisting_objects(code_to_optimize_with_helpers)
contextual_dunder_methods.update(helper_dunder_methods)
# Will eventually refactor to use this function instead of the above
read_write_context, read_only_context = retriever.get_code_optimization_context(
function_to_optimize, project_root, original_source_code
)
return Success(
CodeOptimizationContext(
code_to_optimize_with_helpers=code_to_optimize_with_helpers_and_imports,
read_write_context_code=read_write_context,
read_only_context_code=read_only_context,
contextual_dunder_methods=contextual_dunder_methods,
helper_functions=helper_functions,
preexisting_objects=preexisting_objects,
@ -795,6 +806,8 @@ class Optimizer:
def generate_tests_and_optimizations(
self,
code_to_optimize_with_helpers: str,
read_write_context_code: str,
read_only_context_code: str,
function_to_optimize: FunctionToOptimize,
helper_functions: list[FunctionSource],
module_path: Path,
@ -819,7 +832,8 @@ class Optimizer:
)
future_optimization_candidates = executor.submit(
self.aiservice_client.optimize_python_code,
code_to_optimize_with_helpers,
read_write_context_code,
read_only_context_code,
function_trace_id[:-4] + "EXP0" if run_experiment else function_trace_id,
N_CANDIDATES,
ExperimentMetadata(id=self.experiment_id, group="control") if run_experiment else None,
@ -833,7 +847,8 @@ class Optimizer:
if run_experiment:
future_candidates_exp = executor.submit(
self.local_aiservice_client.optimize_python_code,
code_to_optimize_with_helpers,
read_write_context_code,
read_only_context_code,
function_trace_id[:-4] + "EXP1",
N_CANDIDATES,
ExperimentMetadata(id=self.experiment_id, group="experiment"),

View file

@ -0,0 +1,115 @@
import os
from collections import defaultdict
from pathlib import Path
import jedi
import libcst as cst
from jedi.api.classes import Name
from returns.result import Result
from codeflash.cli_cmds.console import logger
from codeflash.code_utils.code_extractor import add_needed_imports_from_module_2
from codeflash.code_utils.code_utils import get_qualified_name, path_belongs_to_site_packages
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import CodeOptimizationContext
from codeflash.optimization.cst_context import (
CSTContextNode,
build_context_tree,
create_read_only_context,
create_read_write_context,
)
from codeflash.optimization.function_context import belongs_to_class, belongs_to_function
def get_code_optimization_context(
function_to_optimize: FunctionToOptimize, project_root_path: Path, original_source_code: str
) -> Result[CodeOptimizationContext, str]:
function_name = function_to_optimize.function_name
file_path = function_to_optimize.file_path
script = jedi.Script(path=file_path, project=jedi.Project(path=project_root_path))
file_path_to_qualified_function_names = defaultdict(set)
file_path_to_qualified_function_names[file_path].add(function_to_optimize.qualified_name)
read_write_list = []
read_only_list = []
read_write_string = ""
read_only_string = ""
names = []
for ref in script.get_names(all_scopes=True, definitions=False, references=True):
if ref.full_name:
if function_to_optimize.parents:
# Check if the reference belongs to the specified class when FunctionParent is provided
if belongs_to_class(ref, function_to_optimize.parents[-1].name) and belongs_to_function(
ref, function_name
):
names.append(ref)
elif belongs_to_function(ref, function_name):
names.append(ref)
for name in names:
try:
definitions: list[Name] = name.goto(follow_imports=True, follow_builtin_imports=False)
except Exception as e:
try:
logger.exception(f"Error while getting definition for {name.full_name}: {e}")
except Exception as e:
# name.full_name can also throw exceptions sometimes
logger.exception(f"Error while getting definition: {e}")
definitions = []
if definitions:
# TODO: there can be multiple definitions, see how to handle such cases
definition = definitions[0]
definition_path = definition.module_path
# The definition is part of this project and not defined within the original function
if (
str(definition_path).startswith(str(project_root_path) + os.sep)
and not path_belongs_to_site_packages(definition_path)
and definition.full_name
and not belongs_to_function(definition, function_name)
):
file_path_to_qualified_function_names[definition_path].add(
get_qualified_name(definition.module_name, definition.full_name)
)
for file_path, qualified_function_names in file_path_to_qualified_function_names.items():
try:
og_code_containing_helpers = file_path.read_text("utf8")
context_tree_root = cst.parse_module(og_code_containing_helpers)
except Exception as e:
logger.exception(f"Error while parsing {file_path}: {e}")
continue
context_tree_root = CSTContextNode(cst_node=context_tree_root, target_functions=qualified_function_names)
if not build_context_tree(context_tree_root, ""):
logger.debug(
f"{qualified_function_names} was not found in {file_path} when retrieving code optimization context"
)
continue
read_write_context_string = f"{create_read_write_context(context_tree_root)}"
if read_write_context_string:
read_write_string += f"\n{read_write_context_string}"
read_write_string = add_needed_imports_from_module_2(
og_code_containing_helpers,
read_write_string,
file_path,
file_path,
project_root_path,
list(qualified_function_names),
)
read_only_context_string = f"{create_read_only_context(context_tree_root)}\n"
read_only_context_string_with_imports = add_needed_imports_from_module_2(
og_code_containing_helpers,
read_only_context_string,
file_path,
file_path,
project_root_path,
list(qualified_function_names),
)
if read_only_context_string_with_imports:
read_only_list.append(f"```python:{file_path}\n{read_only_context_string_with_imports}```")
read_only_string = "\n".join(read_only_list)
print("read_write_string: \n\n", read_write_string)
print("read_only_string: \n\n", read_only_string)
return read_write_string, read_only_string

View file

@ -0,0 +1,261 @@
from textwrap import dedent
import libcst as cst
from codeflash.optimization.cst_context import find_containing_classes, print_tree
def test_basic_function_identification():
"""Test that target functions are correctly identified."""
code = """
def target_func():
pass
def other_func():
pass
"""
result = find_containing_classes(dedent(code), {"target_func"})
# Should find exactly one function
assert len(result.children["body"]) == 1
assert result.children["body"][0].is_target_function
assert isinstance(result.children["body"][0].cst_node, cst.FunctionDef)
assert result.children["body"][0].cst_node.name.value == "target_func"
def test_class_with_target_function():
"""Test that classes containing target functions are included with only the class definition."""
code = """
class TestClass:
def target_method(self):
pass
def other_method(self):
pass
"""
result = find_containing_classes(dedent(code), {"TestClass.target_method"})
print_tree(result)
assert len(result.children["body"]) == 1
class_node = result.children["body"][0]
assert isinstance(class_node.cst_node, cst.ClassDef)
assert class_node.cst_node.name.value == "TestClass"
assert len(class_node.children["body"]) == 1
assert class_node.children["body"][0].is_target_function
assert class_node.children["body"][0].cst_node.name.value == "target_method"
def test_class_without_target_function():
"""Test that classes without target functions are not included."""
code = """
class TestClass:
def method1(self):
pass
def method2(self):
pass
"""
result = find_containing_classes(dedent(code), {"other_func"})
# Should have no children since no target functions were found
assert "body" not in result.children or not result.children["body"]
def test_control_flow_structures():
"""Test handling of control flow structures containing target functions."""
code = """
if True:
def target_func():
pass
else:
def other_func():
pass
try:
def another_target():
pass
except Exception:
def handler_func():
pass
finally:
def cleanup_func():
pass
"""
result = find_containing_classes(dedent(code), {"target_func", "another_target"})
assert len(result.children["body"]) == 2
if_context_node = result.children["body"][0]
assert isinstance(if_context_node.cst_node, cst.If)
target_context_node = if_context_node.children["body"][0]
assert target_context_node.cst_node.name.value == "target_func"
assert target_context_node.is_target_function
try_context_node = result.children["body"][1]
assert isinstance(try_context_node.cst_node, cst.Try)
another_target_context_node = try_context_node.children["body"][0]
assert another_target_context_node.cst_node.name.value == "another_target"
assert another_target_context_node.is_target_function
def test_nested_classes():
"""Test handling of nested classes with target functions."""
code = """
class OuterClass:
class InnerClass:
def target_method(self):
pass
def other_method(self):
pass
"""
result = find_containing_classes(dedent(code), {"OuterClass.InnerClass.target_method"})
# Verify the class hierarchy
assert len(result.children["body"]) == 1
outer_class = result.children["body"][0]
assert isinstance(outer_class.cst_node, cst.ClassDef)
assert outer_class.cst_node.name.value == "OuterClass"
assert len(outer_class.children["body"]) == 1
inner_class = outer_class.children["body"][0]
assert isinstance(inner_class.cst_node, cst.ClassDef)
assert inner_class.cst_node.name.value == "InnerClass"
assert len(inner_class.children["body"]) == 1
target_method = inner_class.children["body"][0]
assert target_method.is_target_function
assert target_method.cst_node.name.value == "target_method"
def test_no_classes():
"""Test handling of target functions without any classes."""
code = """
def function1():
pass
def target_function():
def nested_function():
pass
return nested_function
def function2():
pass
"""
result = find_containing_classes(dedent(code), {"target_function"})
# Should find only the target function
assert len(result.children["body"]) == 1
assert result.children["body"][0].is_target_function
assert result.children["body"][0].cst_node.name.value == "target_function"
def test_no_classes_if_else():
"""Test handling of target functions in if/else blocks."""
code = """
def function1():
pass
if x:
def target_function():
return "hello"
else:
def function2():
pass
"""
result = find_containing_classes(dedent(code), {"target_function"})
assert result.children["body"][0].children["body"][0].is_target_function
assert result.children["body"][0].children["body"][0].cst_node.name.value == "target_function"
def test_no_classes_else():
"""Test handling of target functions in else blocks."""
code = """
def function1():
pass
if x:
x += 2
else:
def target_function():
return "hello"
"""
result = find_containing_classes(dedent(code), {"target_function"})
assert result.children["body"][0].children["orelse"][0].is_target_function
assert result.children["body"][0].children["orelse"][0].cst_node.name.value == "target_function"
def test_comments_and_decorators():
"""Test that comments and decorators are preserved."""
code = """
# Top level comment
@decorator
class TestClass:
# Class comment
@method_decorator
def target_method(self):
# Method comment
pass
"""
result = find_containing_classes(dedent(code), {"TestClass.target_method"})
# Verify class has decorator
assert len(result.children["body"]) == 1
class_node = result.children["body"][0]
assert len(class_node.cst_node.decorators) == 1
# Verify method has decorator
method_node = class_node.children["body"][0]
assert len(method_node.cst_node.decorators) == 1
def test_same_name_different_paths():
"""Test handling of functions with same name but different qualified paths."""
code = """
class ClassA:
def process(self):
pass
class ClassB:
def process(self):
pass
def process():
pass
class Outer:
class Inner:
def process(self):
pass
def process(self):
pass
"""
# Test finding specific instances
result = find_containing_classes(dedent(code), {"ClassA.process", "Outer.Inner.process"})
# Test for just top-level process
result_top = find_containing_classes(dedent(code), {"process"})
assert len(result_top.children["body"]) == 1
assert result_top.children["body"][0].is_target_function
assert result_top.children["body"][0].cst_node.name.value == "process"
# Test for just Outer.process
result_outer = find_containing_classes(dedent(code), {"Outer.process"})
assert result_outer.children["body"][0].cst_node.name.value == "Outer"
assert result_outer.children["body"][0].children["body"][0].cst_node.name.value == "process"
# Test for just Inner.process
result_inner = find_containing_classes(dedent(code), {"Outer.Inner.process"})
outer = result_inner.children["body"][0]
assert outer.cst_node.name.value == "Outer"
inner = outer.children["body"][0]
assert inner.cst_node.name.value == "Inner"
assert inner.children["body"][0].cst_node.name.value == "process"

409
tests/test_context_cst.py Normal file
View file

@ -0,0 +1,409 @@
from typing import Set
import libcst as cst
from codeflash.optimization.context_cst import (
prune_module, # replace 'yourmodule' with the actual module where prune_module is defined
)
def test_top_level_target_function():
code = """
def foo():
pass
def bar():
pass
"""
module = cst.parse_module(code)
targets: Set[str] = {"bar"}
pruned = prune_module(module, targets)
expected = """
def bar():
pass
"""
assert pruned.code.strip() == expected.strip()
def test_no_targets_found():
code = """
def foo():
pass
x = 10
"""
module = cst.parse_module(code)
targets: Set[str] = {"bar"} # 'bar' doesn't exist in code
pruned = prune_module(module, targets)
expected = "" # no targets found, return empty module
assert pruned.code.strip() == expected.strip()
def test_class_with_target_function():
code = """
class MyClass:
def helper(self):
pass
def target_method(self):
return 42
def unrelated():
pass
"""
module = cst.parse_module(code)
targets: Set[str] = {"MyClass.target_method"}
pruned = prune_module(module, targets)
# We expect to keep MyClass and only target_method in it
expected = """
class MyClass:
def target_method(self):
return 42
"""
assert pruned.code.strip() == expected.strip()
def test_nested_class_with_target_function():
code = """
class Outer:
def outer_method(self):
pass
class Inner:
def inner_helper(self):
pass
def target_func(self):
print("Target")
def top_level():
pass
"""
module = cst.parse_module(code)
targets: Set[str] = {"Outer.Inner.target_func"}
pruned = prune_module(module, targets)
# We must keep Outer, Inner, and the target_func inside Inner
expected = """
class Outer:
class Inner:
def target_func(self):
print("Target")
"""
assert pruned.code.strip() == expected.strip()
def test_if_statements_leading_to_target():
code = """
def foo():
if True:
def target():
return "yes"
else:
print("nope")
def bar():
pass
"""
module = cst.parse_module(code)
targets: Set[str] = {"target"}
pruned = prune_module(module, targets)
# We keep foo, because inside its if body there's the target function.
# The else is removed as it doesn't lead to a target. bar is removed as well.
expected = ""
assert pruned.code.strip() == expected.strip()
def test_class_with_no_target_functions():
code = """
class A:
def no_target(self):
pass
x = 5
"""
module = cst.parse_module(code)
targets: Set[str] = {"SomeOtherClass.some_func"}
pruned = prune_module(module, targets)
expected = "" # no targets in code, empty result
assert pruned.code.strip() == expected.strip()
def test_class_in_else_block():
code = """
if y is False:
x = 10
else:
class MyClass:
def not_target(self):
pass
def target_func(self):
return "found me!"
"""
module = cst.parse_module(code)
targets: Set[str] = {"MyClass.target_func"}
pruned = prune_module(module, targets)
# Even though MyClass is in the else block, we have a target method inside it.
# We expect to keep the wrapper function, the else block, and the class with just the target method.
expected = """
if y is False:
pass
else:
class MyClass:
def target_func(self):
return "found me!"
"""
assert pruned.code.strip() == expected.strip()
def test_class_in_if_block():
code = """
if y is False:
class MyClass:
def not_target(self):
pass
def target_func(self):
return "found me!"
else:
x = 10
"""
module = cst.parse_module(code)
targets: Set[str] = {"MyClass.target_func"}
pruned = prune_module(module, targets)
# Even though MyClass is in the else block, we have a target method inside it.
# We expect to keep the wrapper function, the else block, and the class with just the target method.
expected = """
if y is False:
class MyClass:
def target_func(self):
return "found me!"
else:
pass
"""
assert pruned.code.strip() == expected.strip()
def test_functions_same_name_different_scopes():
code = """
def foo():
return "top-level"
class Outer:
def foo():
return "in Outer"
class Another:
def foo():
return "in Another"
"""
module = cst.parse_module(code)
# Only match the "Outer.foo" function, not the top-level "foo" or "Another.foo"
targets: Set[str] = {"Outer.foo"}
pruned = prune_module(module, targets)
expected = """
class Outer:
def foo():
return "in Outer"
"""
assert pruned.code.strip() == expected.strip()
def test_if_elif_else_block():
code = """
if condition:
def not_target():
pass
elif other_condition:
def another_not_target():
pass
else:
def target_in_else():
return "Found"
def unrelated():
return "no"
"""
module = cst.parse_module(code)
targets: Set[str] = {"target_in_else"}
pruned = prune_module(module, targets)
# We keep the whole if-elif-else structure but prune out the non-target branches.
expected = """
if condition:
pass
elif other_condition:
pass
else:
def target_in_else():
return "Found"
"""
assert pruned.code.strip() == expected.strip()
def test_nested_if_in_else_block():
code = """
if top_level:
def no_target():
return 1
else:
if nested_condition:
def target_func():
return "nested target"
else:
def another_no():
pass
def outside():
pass
"""
module = cst.parse_module(code)
targets: Set[str] = {"target_func"}
pruned = prune_module(module, targets)
# We keep the top if-else structure, and inside the else, we keep the nested if block with the target.
expected = """
if top_level:
pass
else:
if nested_condition:
def target_func():
return "nested target"
"""
assert pruned.code.strip() == expected.strip()
def test_class_in_try_block():
code = """
try:
class MyClass:
def target_method(self):
return "reached!"
except ValueError:
def not_target():
return "no"
else:
def also_not_target():
return "no"
finally:
def final_not_target():
return "no"
"""
module = cst.parse_module(code)
targets: Set[str] = {"MyClass.target_method"}
pruned = prune_module(module, targets)
# We keep the try block because it contains the target class,
# remove except, else, and finally since they don't lead to the target.
expected = """
try:
class MyClass:
def target_method(self):
return "reached!"
except ValueError:
pass
else:
pass
finally:
pass
"""
assert pruned.code.strip() == expected.strip()
def test_target_in_except_block():
code = """
try:
def no_target():
return "no"
except KeyError:
def target_func():
return "caught target"
except ValueError:
def also_no():
return "no"
else:
def no_again():
return "no"
finally:
def final_no():
return "no"
"""
module = cst.parse_module(code)
targets: Set[str] = {"target_func"}
pruned = prune_module(module, targets)
# Remove the try, else, and finally bodies that don't lead to the target.
# Keep the except KeyError block with the target, and leave other except blocks as pass.
expected = """
try:
pass
except KeyError:
def target_func():
return "caught target"
except ValueError:
pass
else:
pass
finally:
pass
"""
assert pruned.code.strip() == expected.strip()
def test_target_in_multiple_except_blocks():
code = """
try:
def first_no():
return 1
except KeyError:
def second_no():
return 2
except ValueError:
def target_in_value():
return 3
except TypeError:
def another_no():
return 4
else:
def else_no():
return 5
finally:
def final_no():
return 6
"""
module = cst.parse_module(code)
targets: Set[str] = {"target_in_value"}
pruned = prune_module(module, targets)
# We only keep the try/except/finally structure necessary to reach target_in_value.
# The except for ValueError must remain with the target, others become pass.
expected = """
try:
pass
except KeyError:
pass
except ValueError:
def target_in_value():
return 3
except TypeError:
pass
else:
pass
finally:
pass
"""
assert pruned.code.strip() == expected.strip()

View file

@ -0,0 +1,339 @@
from textwrap import dedent
import pytest
from codeflash.optimization.cst_context import create_read_only_context, find_containing_classes, print_tree
def test_basic_class():
code = """
class TestClass:
class_var = "value"
def target_method(self):
print("This should be stubbed")
def other_method(self):
print("This too")
"""
expected = """
class TestClass:
class_var = "value"
"""
result = find_containing_classes(dedent(code), {"TestClass.target_method"})
output = create_read_only_context(result)
assert dedent(expected).strip() == output.strip()
def test_dunder_methods():
code = """
class TestClass:
def __init__(self):
self.x = 42
def __str__(self):
return f"Value: {self.x}"
def target_method(self):
print("stub me")
"""
expected = """
class TestClass:
def __init__(self):
self.x = 42
def __str__(self):
return f"Value: {self.x}"
"""
result = find_containing_classes(dedent(code), {"TestClass.target_method"})
output = create_read_only_context(result)
assert dedent(expected).strip() == output.strip()
def test_target_in_nested_class():
"""Test that attempting to find a target in a nested class raises an error."""
code = """
class Outer:
outer_var = 1
class Inner:
inner_var = 2
def target_method(self):
print("stub this")
"""
with pytest.raises(ValueError, match="No target functions found in the provided code"):
find_containing_classes(dedent(code), {"Outer.Inner.target_method"})
def test_docstrings():
code = """
class TestClass:
\"\"\"Class docstring.\"\"\"
def target_method(self):
\"\"\"Method docstring.\"\"\"
print("stub this")
def other_method(self):
\"\"\"Other docstring.\"\"\"
print("stub this too")
"""
expected = """
class TestClass:
\"\"\"Class docstring.\"\"\"
"""
result = find_containing_classes(dedent(code), {"TestClass.target_method"})
output = create_read_only_context(result)
assert dedent(expected).strip() == output.strip()
def test_method_signatures():
code = """
class TestClass:
@property
def target_method(self) -> str:
\"\"\"Property docstring.\"\"\"
return "value"
@classmethod
def class_method(cls, param: int = 42) -> None:
print("stub this")
"""
expected = """"""
result = find_containing_classes(dedent(code), {"TestClass.target_method"})
output = create_read_only_context(result)
print(output)
assert dedent(expected).strip() == output.strip()
def test_multiple_top_level_targets():
code = """
class TestClass:
def target1(self):
print("stub 1")
def target2(self):
print("stub 2")
def __init__(self):
self.x = 42
"""
expected = """
class TestClass:
def __init__(self):
self.x = 42
"""
result = find_containing_classes(dedent(code), {"TestClass.target1", "TestClass.target2"})
output = create_read_only_context(result)
assert dedent(expected).strip() == output.strip()
def test_class_annotations():
code = """
class TestClass:
var1: int = 42
var2: str
def target_method(self) -> None:
self.var2 = "test"
"""
expected = """
class TestClass:
var1: int = 42
var2: str
"""
result = find_containing_classes(dedent(code), {"TestClass.target_method"})
output = create_read_only_context(result)
assert dedent(expected).strip() == output.strip()
def test_class_annotations_if():
code = """
if True:
class TestClass:
var1: int = 42
var2: str
def target_method(self) -> None:
self.var2 = "test"
"""
expected = """
if True:
class TestClass:
var1: int = 42
var2: str
"""
result = find_containing_classes(dedent(code), {"TestClass.target_method"})
output = create_read_only_context(result)
assert dedent(expected).strip() == output.strip()
def test_class_annotations_try():
code = """
try:
class TestClass:
var1: int = 42
var2: str
def target_method(self) -> None:
self.var2 = "test"
except Exception:
continue
"""
expected = """
try:
class TestClass:
var1: int = 42
var2: str
except Exception:
continue
"""
result = find_containing_classes(dedent(code), {"TestClass.target_method"})
output = create_read_only_context(result)
assert dedent(expected).strip() == output.strip()
def test_class_annotations_else():
code = """
if x is True:
class TestClass:
var1: int = 42
var2: str
def wrong_method(self) -> None:
print("wrong")
else:
class TestClass:
var1: int = 42
var2: str
def target_method(self) -> None:
self.var2 = "test"
"""
expected = """
if x is True:
class TestClass:
var1: int = 42
var2: str
def wrong_method(self) -> None:
print("wrong")
else:
class TestClass:
var1: int = 42
var2: str
"""
result = find_containing_classes(dedent(code), {"TestClass.target_method"})
print_tree(result)
output = create_read_only_context(result)
assert dedent(expected).strip() == output.strip()
def test_top_level_functions():
code = """
def target_function(self) -> None:
self.var2 = "test"
def some_function():
print("wow")
"""
expected = """"""
result = find_containing_classes(dedent(code), {"target_function"})
output = create_read_only_context(result)
assert dedent(expected).strip() == output.strip()
def test_module_scope_var():
code = """
y = 3
class TestClass:
def __init__(self):
self.x = 42
def __str__(self):
return f"Value: {self.x}"
def target_method(self):
print("stub me")
"""
expected = """
y = 3
class TestClass:
def __init__(self):
self.x = 42
def __str__(self):
return f"Value: {self.x}"
"""
result = find_containing_classes(dedent(code), {"TestClass.target_method"})
output = create_read_only_context(result)
assert dedent(expected).strip() == output.strip()
def test_module_scope_var():
code = """
if True:
y = 3
class OtherClass:
def this_method(self):
print("this method")
def __init__(self):
self.y = y
class TestClass:
def __init__(self):
self.x = 42
def __str__(self):
return f"Value: {self.x}"
def target_method(self):
print("stub me")
"""
expected = """
if True:
y = 3
class TestClass:
def __init__(self):
self.x = 42
def __str__(self):
return f"Value: {self.x}"
"""
result = find_containing_classes(dedent(code), {"TestClass.target_method"})
output = create_read_only_context(result)
assert dedent(expected).strip() == output.strip()

View file

@ -0,0 +1,207 @@
from textwrap import dedent
import pytest
from codeflash.optimization.cst_context import create_read_write_context, find_containing_classes, print_tree
def test_simple_function():
code = """
def target_function():
x = 1
y = 2
return x + y
"""
root = find_containing_classes(dedent(code), {"target_function"})
result = create_read_write_context(root)
expected = dedent("""
def target_function():
x = 1
y = 2
return x + y
""")
assert result.strip() == expected.strip()
def test_class_method():
code = """
class MyClass:
def target_function(self):
x = 1
y = 2
return x + y
"""
root = find_containing_classes(dedent(code), {"MyClass.target_function"})
result = create_read_write_context(root)
expected = dedent("""
class MyClass:
def target_function(self):
x = 1
y = 2
return x + y
""")
assert result.strip() == expected.strip()
def test_class_with_attributes():
code = """
class MyClass:
x: int = 1
y: str = "hello"
def target_method(self):
return self.x + 42
def other_method(self):
print("this should be excluded")
"""
root = find_containing_classes(dedent(code), {"MyClass.target_method"})
result = create_read_write_context(root)
expected = dedent("""
class MyClass:
def target_method(self):
return self.x + 42
""")
assert result.strip() == expected.strip()
def test_basic_class_structure():
"""Test that nested classes are ignored for target function search."""
code = """
class Outer:
x = 1
def target_method(self):
return 42
class Inner:
y = 2
def not_findable(self):
return 42
"""
root = find_containing_classes(dedent(code), {"Outer.target_method"})
result = create_read_write_context(root)
expected = dedent("""
class Outer:
def target_method(self):
return 42
""")
assert result.strip() == expected.strip()
def test_top_level_targets():
code = """
class OuterClass:
x = 1
def method1(self):
return self.x
def target_function():
return 42
"""
root = find_containing_classes(dedent(code), {"target_function"})
result = create_read_write_context(root)
expected = dedent("""
def target_function():
return 42
""")
assert result.strip() == expected.strip()
def test_multiple_top_level_classes():
code = """
class ClassA:
def process(self):
return "A"
class ClassB:
def process(self):
return "B"
class ClassC:
def process(self):
return "C"
"""
root = find_containing_classes(dedent(code), {"ClassA.process", "ClassC.process"})
result = create_read_write_context(root)
expected = dedent("""
class ClassA:
def process(self):
return "A"
class ClassC:
def process(self):
return "C"
""")
assert result.strip() == expected.strip()
def test_try_except_structure():
code = """
try:
class TargetClass:
def target_method(self):
return 42
except ValueError:
class ErrorClass:
def handle_error(self):
print("error")
"""
root = find_containing_classes(dedent(code), {"TargetClass.target_method"})
print_tree(root)
result = create_read_write_context(root)
expected = dedent("""
try:
class TargetClass:
def target_method(self):
return 42
except ValueError:
class ErrorClass:
def handle_error(self):
print("error")
""")
assert result.strip() == expected.strip()
def test_dunder_method():
code = """
class MyClass:
def __init__(self):
self.x = 1
def other_method(self):
return "other"
def target_method(self):
return f"Value: {self.x}"
"""
root = find_containing_classes(dedent(code), {"MyClass.target_method"})
result = create_read_write_context(root)
expected = dedent("""
class MyClass:
def target_method(self):
return f"Value: {self.x}"
""")
assert result.strip() == expected.strip()
def test_no_targets_found():
code = """
class MyClass:
def method(self):
pass
class Inner:
def target(self):
pass
"""
with pytest.raises(ValueError, match="No target functions found in the provided code"):
find_containing_classes(dedent(code), {"MyClass.Inner.target"})

623
tests/test_retriever.py Normal file
View file

@ -0,0 +1,623 @@
from __future__ import annotations
import tempfile
from argparse import Namespace
from collections import defaultdict
from pathlib import Path
from textwrap import dedent
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import FunctionParent
from codeflash.optimization.optimizer import Optimizer
from codeflash.optimization.retriever import get_code_optimization_context
class HelperClass:
def __init__(self, name):
self.name = name
def innocent_bystander(self):
pass
def helper_method(self):
return self.name
class MainClass:
def __init__(self, name):
self.name = name
def main_method(self):
return HelperClass(self.name).helper_method()
class Graph:
def __init__(self, vertices):
self.graph = defaultdict(list)
self.V = vertices # No. of vertices
def addEdge(self, u, v):
self.graph[u].append(v)
def topologicalSortUtil(self, v, visited, stack):
visited[v] = True
for i in self.graph[v]:
if visited[i] == False:
self.topologicalSortUtil(i, visited, stack)
stack.insert(0, v)
def topologicalSort(self):
visited = [False] * self.V
stack = []
for i in range(self.V):
if visited[i] == False:
self.topologicalSortUtil(i, visited, stack)
# Print contents of stack
return stack
def test_code_replacement10() -> None:
file_path = Path(__file__).resolve()
func_top_optimize = FunctionToOptimize(
function_name="main_method", file_path=file_path, parents=[FunctionParent("MainClass", "ClassDef")]
)
original_code = file_path.read_text()
read_write_context, read_only_context = get_code_optimization_context(
function_to_optimize=func_top_optimize, project_root_path=file_path.parent, original_source_code=original_code
)
expected_read_write_context = """
from __future__ import annotations
class HelperClass:
def helper_method(self):
return self.name
class MainClass:
def main_method(self):
return HelperClass(self.name).helper_method()
"""
expected_read_only_context = f"""
```python:{file_path}
from __future__ import annotations
class HelperClass:
def __init__(self, name):
self.name = name
class MainClass:
def __init__(self, name):
self.name = name
```
"""
assert read_write_context.strip() == dedent(expected_read_write_context).strip()
assert read_only_context.strip() == dedent(expected_read_only_context).strip()
def test_class_method_dependencies() -> None:
file_path = Path(__file__).resolve()
function_to_optimize = FunctionToOptimize(
function_name="topologicalSort",
file_path=str(file_path),
parents=[FunctionParent(name="Graph", type="ClassDef")],
starting_line=None,
ending_line=None,
)
with open(file_path) as f:
original_code = f.read()
read_write_context, read_only_context = get_code_optimization_context(
function_to_optimize, file_path.parent.resolve(), original_code
)
expected_read_write_context = """
from __future__ import annotations
class Graph:
def topologicalSortUtil(self, v, visited, stack):
visited[v] = True
for i in self.graph[v]:
if visited[i] == False:
self.topologicalSortUtil(i, visited, stack)
stack.insert(0, v)
def topologicalSort(self):
visited = [False] * self.V
stack = []
for i in range(self.V):
if visited[i] == False:
self.topologicalSortUtil(i, visited, stack)
# Print contents of stack
return stack
"""
expected_read_only_context = f"""
```python:{file_path}
from __future__ import annotations
from collections import defaultdict
class Graph:
def __init__(self, vertices):
self.graph = defaultdict(list)
self.V = vertices # No. of vertices
```
"""
assert read_write_context.strip() == dedent(expected_read_write_context).strip()
assert read_only_context.strip() == dedent(expected_read_only_context).strip()
def test_bubble_sort_helper() -> None:
path_to_fto = (
Path(__file__).resolve().parent.parent
/ "code_to_optimize"
/ "code_directories"
/ "retriever"
/ "bubble_sort_imported.py"
)
function_to_optimize = FunctionToOptimize(
function_name="sort_from_another_file",
file_path=str(path_to_fto),
parents=[],
starting_line=None,
ending_line=None,
)
with open(path_to_fto) as f:
original_code = f.read()
read_write_context, read_only_context = get_code_optimization_context(
function_to_optimize, Path(__file__).resolve().parent.parent, original_code
)
expected_read_write_context = """
from bubble_sort_with_math import sorter
import math
def sort_from_another_file(arr):
sorted_arr = sorter(arr)
return sorted_arr
def sorter(arr):
arr.sort()
x = math.sqrt(2)
print(x)
return arr
"""
expected_read_only_context = ""
assert read_write_context.strip() == dedent(expected_read_write_context).strip()
assert read_only_context.strip() == dedent(expected_read_only_context).strip()
def test_repo_helper() -> None:
path_to_file = (
Path(__file__).resolve().parent.parent / "code_to_optimize" / "code_directories" / "retriever" / "main.py"
)
path_to_utils = (
Path(__file__).resolve().parent.parent / "code_to_optimize" / "code_directories" / "retriever" / "utils.py"
)
function_to_optimize = FunctionToOptimize(
function_name="fetch_and_process_data",
file_path=str(path_to_file),
parents=[],
starting_line=None,
ending_line=None,
)
with open(path_to_file) as f:
original_code = f.read()
read_write_context, read_only_context = get_code_optimization_context(
function_to_optimize, Path(__file__).resolve().parent.parent, original_code
)
expected_read_write_context = """
import requests
from globals import API_URL
from utils import DataProcessor
def fetch_and_process_data():
# Use the global variable for the request
response = requests.get(API_URL)
response.raise_for_status()
raw_data = response.text
# Use code from another file (utils.py)
processor = DataProcessor()
processed = processor.process_data(raw_data)
processed = processor.add_prefix(processed)
return processed
class DataProcessor:
def process_data(self, raw_data: str) -> str:
\"\"\"Process raw data by converting it to uppercase.\"\"\"
return raw_data.upper()
def add_prefix(self, data: str, prefix: str = "PREFIX_") -> str:
\"\"\"Add a prefix to the processed data.\"\"\"
return prefix + data
"""
expected_read_only_context = f"""
```python:/Users/alvinryanputra/cf/codeflash/cli/code_to_optimize/code_directories/retriever/main.py
if __name__ == "__main__":
result = fetch_and_process_data()
print("Processed data:", result)
```
```python:{path_to_utils}
import math
class DataProcessor:
\"\"\"A class for processing data.\"\"\"
number = 1
def __init__(self, default_prefix: str = "PREFIX_"):
\"\"\"Initialize the DataProcessor with a default prefix.\"\"\"
self.default_prefix = default_prefix
self.number += math.log(self.number)
def __repr__(self) -> str:
\"\"\"Return a string representation of the DataProcessor.\"\"\"
return f"DataProcessor(default_prefix={{self.default_prefix!r}})"
```
"""
assert read_write_context.strip() == dedent(expected_read_write_context).strip()
assert read_only_context.strip() == dedent(expected_read_only_context).strip()
def test_flavio_typed_code_helper() -> None:
code = '''
_P = ParamSpec("_P")
_KEY_T = TypeVar("_KEY_T")
_STORE_T = TypeVar("_STORE_T")
class AbstractCacheBackend(CacheBackend, Protocol[_KEY_T, _STORE_T]):
"""Interface for cache backends used by the persistent cache decorator."""
def __init__(self) -> None: ...
def hash_key(
self,
*,
func: Callable[_P, Any],
args: tuple[Any, ...],
kwargs: dict[str, Any],
) -> tuple[str, _KEY_T]: ...
def encode(self, *, data: Any) -> _STORE_T: # noqa: ANN401
...
def decode(self, *, data: _STORE_T) -> Any: # noqa: ANN401
...
def get(self, *, key: tuple[str, _KEY_T]) -> tuple[datetime.datetime, _STORE_T] | None: ...
def delete(self, *, key: tuple[str, _KEY_T]) -> None: ...
def put(self, *, key: tuple[str, _KEY_T], data: _STORE_T) -> None: ...
def get_cache_or_call(
self,
*,
func: Callable[_P, Any],
args: tuple[Any, ...],
kwargs: dict[str, Any],
lifespan: datetime.timedelta,
) -> Any: # noqa: ANN401
"""
Retrieve the cached results for a function call.
Args:
----
func (Callable[..., _R]): The function to retrieve cached results for.
args (tuple[Any, ...]): The positional arguments passed to the function.
kwargs (dict[str, Any]): The keyword arguments passed to the function.
lifespan (datetime.timedelta): The maximum age of the cached results.
Returns:
-------
_R: The cached results, if available.
"""
if os.environ.get("NO_CACHE"):
return func(*args, **kwargs)
try:
key = self.hash_key(func=func, args=args, kwargs=kwargs)
except: # noqa: E722
# If we can't create a cache key, we should just call the function.
logging.warning("Failed to hash cache key for function: %s", func)
return func(*args, **kwargs)
result_pair = self.get(key=key)
if result_pair is not None:
cached_time, result = result_pair
if not os.environ.get("RE_CACHE") and (
datetime.datetime.now() < (cached_time + lifespan) # noqa: DTZ005
):
try:
return self.decode(data=result)
except CacheBackendDecodeError as e:
logging.warning("Failed to decode cache data: %s", e)
# If decoding fails we will treat this as a cache miss.
# This might happens if underlying class definition of the data changes.
self.delete(key=key)
result = func(*args, **kwargs)
try:
self.put(key=key, data=self.encode(data=result))
except CacheBackendEncodeError as e:
logging.warning("Failed to encode cache data: %s", e)
# If encoding fails, we should still return the result.
return result
_P = ParamSpec("_P")
_R = TypeVar("_R")
_CacheBackendT = TypeVar("_CacheBackendT", bound=CacheBackend)
class _PersistentCache(Generic[_P, _R, _CacheBackendT]):
"""
A decorator class that provides persistent caching functionality for a function.
Args:
----
func (Callable[_P, _R]): The function to be decorated.
duration (datetime.timedelta): The duration for which the cached results should be considered valid.
backend (_backend): The backend storage for the cached results.
Attributes:
----------
__wrapped__ (Callable[_P, _R]): The wrapped function.
__duration__ (datetime.timedelta): The duration for which the cached results should be considered valid.
__backend__ (_backend): The backend storage for the cached results.
""" # noqa: E501
__wrapped__: Callable[_P, _R]
__duration__: datetime.timedelta
__backend__: _CacheBackendT
def __init__(
self,
func: Callable[_P, _R],
duration: datetime.timedelta,
) -> None:
self.__wrapped__ = func
self.__duration__ = duration
self.__backend__ = AbstractCacheBackend()
functools.update_wrapper(self, func)
def cache_clear(self) -> None:
"""Clears the cache for the wrapped function."""
self.__backend__.del_func_cache(func=self.__wrapped__)
def no_cache_call(self, *args: _P.args, **kwargs: _P.kwargs) -> _R:
"""
Calls the wrapped function without using the cache.
Args:
----
*args (_P.args): Positional arguments for the wrapped function.
**kwargs (_P.kwargs): Keyword arguments for the wrapped function.
Returns:
-------
_R: The result of the wrapped function.
"""
return self.__wrapped__(*args, **kwargs)
def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R:
"""
Calls the wrapped function, either using the cache or bypassing it based on environment variables.
Args:
----
*args (_P.args): Positional arguments for the wrapped function.
**kwargs (_P.kwargs): Keyword arguments for the wrapped function.
Returns:
-------
_R: The result of the wrapped function.
""" # noqa: E501
if "NO_CACHE" in os.environ:
return self.__wrapped__(*args, **kwargs)
os.makedirs(DEFAULT_CACHE_LOCATION, exist_ok=True)
return self.__backend__.get_cache_or_call(
func=self.__wrapped__,
args=args,
kwargs=kwargs,
lifespan=self.__duration__,
)
'''
with tempfile.NamedTemporaryFile(mode="w") as f:
f.write(code)
f.flush()
file_path = Path(f.name).resolve()
opt = Optimizer(
Namespace(
project_root=file_path.parent.resolve(),
disable_telemetry=True,
tests_root="tests",
test_framework="pytest",
pytest_cmd="pytest",
experiment_id=None,
test_project_root=Path().resolve(),
)
)
function_to_optimize = FunctionToOptimize(
function_name="__call__",
file_path=file_path,
parents=[FunctionParent(name="_PersistentCache", type="ClassDef")],
starting_line=None,
ending_line=None,
)
with open(file_path) as f:
original_code = f.read()
read_write_context, read_only_context = get_code_optimization_context(
function_to_optimize, opt.args.project_root, original_code
)
expected_read_write_context = """
class AbstractCacheBackend(CacheBackend, Protocol[_KEY_T, _STORE_T]):
def get_cache_or_call(
self,
*,
func: Callable[_P, Any],
args: tuple[Any, ...],
kwargs: dict[str, Any],
lifespan: datetime.timedelta,
) -> Any: # noqa: ANN401
\"\"\"
Retrieve the cached results for a function call.
Args:
----
func (Callable[..., _R]): The function to retrieve cached results for.
args (tuple[Any, ...]): The positional arguments passed to the function.
kwargs (dict[str, Any]): The keyword arguments passed to the function.
lifespan (datetime.timedelta): The maximum age of the cached results.
Returns:
-------
_R: The cached results, if available.
\"\"\"
if os.environ.get("NO_CACHE"):
return func(*args, **kwargs)
try:
key = self.hash_key(func=func, args=args, kwargs=kwargs)
except: # noqa: E722
# If we can't create a cache key, we should just call the function.
logging.warning("Failed to hash cache key for function: %s", func)
return func(*args, **kwargs)
result_pair = self.get(key=key)
if result_pair is not None:
cached_time, result = result_pair
if not os.environ.get("RE_CACHE") and (
datetime.datetime.now() < (cached_time + lifespan) # noqa: DTZ005
):
try:
return self.decode(data=result)
except CacheBackendDecodeError as e:
logging.warning("Failed to decode cache data: %s", e)
# If decoding fails we will treat this as a cache miss.
# This might happens if underlying class definition of the data changes.
self.delete(key=key)
result = func(*args, **kwargs)
try:
self.put(key=key, data=self.encode(data=result))
except CacheBackendEncodeError as e:
logging.warning("Failed to encode cache data: %s", e)
# If encoding fails, we should still return the result.
return result
class _PersistentCache(Generic[_P, _R, _CacheBackendT]):
def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R:
\"\"\"
Calls the wrapped function, either using the cache or bypassing it based on environment variables.
Args:
----
*args (_P.args): Positional arguments for the wrapped function.
**kwargs (_P.kwargs): Keyword arguments for the wrapped function.
Returns:
-------
_R: The result of the wrapped function.
\"\"\" # noqa: E501
if "NO_CACHE" in os.environ:
return self.__wrapped__(*args, **kwargs)
os.makedirs(DEFAULT_CACHE_LOCATION, exist_ok=True)
return self.__backend__.get_cache_or_call(
func=self.__wrapped__,
args=args,
kwargs=kwargs,
lifespan=self.__duration__,
)
"""
expected_read_only_context = f'''
```python:{file_path}
_P = ParamSpec("_P")
_KEY_T = TypeVar("_KEY_T")
_STORE_T = TypeVar("_STORE_T")
class AbstractCacheBackend(CacheBackend, Protocol[_KEY_T, _STORE_T]):
"""Interface for cache backends used by the persistent cache decorator."""
def __init__(self) -> None: ...
_P = ParamSpec("_P")
_R = TypeVar("_R")
_CacheBackendT = TypeVar("_CacheBackendT", bound=CacheBackend)
class _PersistentCache(Generic[_P, _R, _CacheBackendT]):
"""
A decorator class that provides persistent caching functionality for a function.
Args:
----
func (Callable[_P, _R]): The function to be decorated.
duration (datetime.timedelta): The duration for which the cached results should be considered valid.
backend (_backend): The backend storage for the cached results.
Attributes:
----------
__wrapped__ (Callable[_P, _R]): The wrapped function.
__duration__ (datetime.timedelta): The duration for which the cached results should be considered valid.
__backend__ (_backend): The backend storage for the cached results.
""" # noqa: E501
__wrapped__: Callable[_P, _R]
__duration__: datetime.timedelta
__backend__: _CacheBackendT
def __init__(
self,
func: Callable[_P, _R],
duration: datetime.timedelta,
) -> None:
self.__wrapped__ = func
self.__duration__ = duration
self.__backend__ = AbstractCacheBackend()
functools.update_wrapper(self, func)
```
'''
assert read_write_context.strip() == dedent(expected_read_write_context).strip()
assert read_only_context.strip() == dedent(expected_read_only_context).strip()