Added helpers of helpers to the read-only context, and slightly refactored the code

This commit is contained in:
Alvin Ryanputra 2024-12-31 17:19:09 -08:00
parent 641088fad5
commit a10b399dbe
6 changed files with 779 additions and 227 deletions

View file

@ -0,0 +1,27 @@
from code_to_optimize.code_directories.retriever.utils import DataProcessor
class DataTransformer:
    # NOTE(review): test fixture — the code-context extraction tests assert on the
    # exact source text of these methods, so any edit here must be mirrored in the
    # expected-context strings of those tests.
    def __init__(self):
        # Most recently transformed payload; None until transform() runs.
        self.data = None

    def transform(self, data):
        # Store the payload and echo it back unchanged.
        self.data = data
        return self.data

    def transform_using_own_method(self, data):
        # Delegates to a sibling method of the same class (helper within own class).
        return self.transform(data)

    def transform_using_same_file_function(self, data):
        # Delegates to a module-level function defined in this same file.
        return update_data(data)

    def transform_data_all_same_file(self, data):
        # Chains a same-file function and an own method, all within this file.
        new_data = update_data(data)
        return self.transform_using_own_method(new_data)

    def circular_dependency(self, data):
        # Calls back into utils.DataProcessor, which in turn imports this class —
        # a deliberate circular helper relationship for the tests.
        return DataProcessor().circular_dependency(data)
def update_data(data):
    # Module-level fixture helper: appends a fixed suffix so tests can recognize
    # that this function was included in the extracted context.
    return data + " updated"

View file

@ -1,5 +1,9 @@
import math
from transform_utils import DataTransformer
GLOBAL_VAR = 10
class DataProcessor:
"""A class for processing data."""
@ -25,3 +29,19 @@ class DataProcessor:
def do_something(self):
print("something")
def transform_data(self, data: str) -> str:
"""Transform the processed data"""
return DataTransformer().transform(data)
def transform_data_own_method(self, data: str) -> str:
"""Transform the processed data using own method"""
return DataTransformer().transform_using_own_method(data)
def transform_data_same_file_function(self, data: str) -> str:
"""Transform the processed data using a function from the same file"""
return DataTransformer().transform_using_same_file_function(data)
def circular_dependency(self, data: str) -> str:
"""Test circular dependency"""
return DataTransformer().circular_dependency(data)

View file

@ -15,99 +15,36 @@ from codeflash.code_utils.code_extractor import add_needed_imports_from_module
from codeflash.code_utils.code_utils import get_qualified_name, path_belongs_to_site_packages
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import CodeString, CodeStringsMarkdown
from codeflash.optimization.function_context import belongs_to_class, belongs_to_function
from codeflash.optimization.function_context import belongs_to_function_qualified
def get_code_optimization_context(
function_to_optimize: FunctionToOptimize, project_root_path: Path, token_limit: int = 8000
) -> tuple[str, str]:
function_name = function_to_optimize.function_name
file_path = function_to_optimize.file_path
script = jedi.Script(path=file_path, project=jedi.Project(path=project_root_path))
file_path_to_qualified_function_names = defaultdict(set)
file_path_to_qualified_function_names[file_path].add(function_to_optimize.qualified_name)
read_only_code_markdown = CodeStringsMarkdown()
final_read_writable_code = ""
names = []
for ref in script.get_names(all_scopes=True, definitions=False, references=True):
if ref.full_name:
if function_to_optimize.parents:
# Check if the reference belongs to the specified class when FunctionParent is provided
if belongs_to_class(ref, function_to_optimize.parents[-1].name) and belongs_to_function(
ref, function_name
):
names.append(ref)
elif belongs_to_function(ref, function_name):
names.append(ref)
# Get qualified names and fully qualified names(fqn) of helpers
helpers_of_fto, helpers_of_fto_fqn = get_file_path_to_helper_functions_dict(
{function_to_optimize.file_path: {function_to_optimize.qualified_name}}, project_root_path
)
helpers_of_helpers, helpers_of_helpers_fqn = get_file_path_to_helper_functions_dict(
helpers_of_fto, project_root_path
)
for name in names:
try:
definitions: list[Name] = name.goto(follow_imports=True, follow_builtin_imports=False)
except Exception as e:
try:
logger.exception(f"Error while getting definition for {name.full_name}: {e}")
except Exception as e:
# name.full_name can also throw exceptions sometimes
logger.exception(f"Error while getting definition: {e}")
definitions = []
if definitions:
# TODO: there can be multiple definitions, see how to handle such cases
definition = definitions[0]
definition_path = definition.module_path
# Add function to optimize
helpers_of_fto[function_to_optimize.file_path].add(function_to_optimize.qualified_name)
helpers_of_fto_fqn[function_to_optimize.file_path].add(
function_to_optimize.qualified_name_with_modules_from_root(project_root_path)
)
# The definition is part of this project and not defined within the original function
if (
str(definition_path).startswith(str(project_root_path) + os.sep)
and not path_belongs_to_site_packages(definition_path)
and definition.full_name
and not belongs_to_function(definition, function_name)
and definition.module_name != definition.full_name
):
file_path_to_qualified_function_names[definition_path].add(
get_qualified_name(definition.module_name, definition.full_name)
)
for file_path, qualified_function_names in file_path_to_qualified_function_names.items():
try:
og_code_containing_helpers = file_path.read_text("utf8")
except Exception as e:
logger.exception(f"Error while parsing {file_path}: {e}")
continue
try:
read_writable_code = get_read_writable_code(og_code_containing_helpers, qualified_function_names)
except ValueError as e:
logger.debug(f"Error while getting read-writable code: {e}")
continue
if read_writable_code:
final_read_writable_code += f"\n{read_writable_code}"
final_read_writable_code = add_needed_imports_from_module(
src_module_code=og_code_containing_helpers,
dst_module_code=final_read_writable_code,
src_path=file_path,
dst_path=file_path,
project_root=project_root_path,
helper_functions_fqn=qualified_function_names,
)
try:
read_only_code = get_read_only_code(og_code_containing_helpers, qualified_function_names)
except ValueError as e:
logger.debug(f"Error while getting read-only code: {e}")
continue
read_only_code_with_imports = CodeString(
code=add_needed_imports_from_module(
src_module_code=og_code_containing_helpers,
dst_module_code=read_only_code,
src_path=file_path,
dst_path=file_path,
project_root=project_root_path,
helper_functions_fqn=qualified_function_names,
),
file_path=Path(file_path),
)
if read_only_code_with_imports.code:
read_only_code_markdown.code_strings.append(read_only_code_with_imports)
# Extract code
final_read_writable_code = get_all_read_writable_code(helpers_of_fto, helpers_of_fto_fqn, project_root_path)
read_only_code_markdown = get_all_read_only_code_context(
helpers_of_fto,
helpers_of_fto_fqn,
helpers_of_helpers,
helpers_of_helpers_fqn,
project_root_path,
remove_docstrings=False,
)
# Handle token limits
tokenizer = tiktoken.encoding_for_model("gpt-4o")
@ -121,12 +58,85 @@ def get_code_optimization_context(
return CodeString(code=final_read_writable_code).code, read_only_code_markdown.markdown
logger.debug("Code context has exceeded token limit, removing docstrings from read-only code")
# Get read-only code context again, this time without docstrings
# Extract read only code without docstrings
read_only_code_no_docstring_markdown = get_all_read_only_code_context(
helpers_of_fto,
helpers_of_fto_fqn,
helpers_of_helpers,
helpers_of_helpers_fqn,
project_root_path,
remove_docstrings=True,
)
read_only_code_no_docstring_markdown_tokens = len(tokenizer.encode(read_only_code_no_docstring_markdown.markdown))
total_tokens = final_read_writable_tokens + read_only_code_no_docstring_markdown_tokens
if total_tokens <= token_limit:
return CodeString(code=final_read_writable_code).code, read_only_code_no_docstring_markdown.markdown
logger.debug("Code context has exceeded token limit, removing read-only code")
return CodeString(code=final_read_writable_code).code, ""
def get_all_read_writable_code(
    helpers_of_fto: dict[Path, set[str]], helpers_of_fto_fqn: dict[Path, set[str]], project_root_path: Path
) -> str:
    """Build the combined read-writable code context.

    Walks every file that contains the function-to-optimize or one of its
    first-degree helpers, extracts the read-writable portion of that file,
    and concatenates the results, pulling in the imports each extracted
    snippet needs.

    Args:
        helpers_of_fto: file path -> qualified names of the fto and its helpers.
        helpers_of_fto_fqn: file path -> fully qualified (module-rooted) names.
        project_root_path: root used to resolve imports for the snippets.

    Returns:
        The concatenated read-writable code as a single string (possibly empty
        if every file failed to read or extract).
    """
    final_read_writable_code = ""
    # Extract code from file paths that contain fto and first degree helpers
    for file_path, qualified_function_names in helpers_of_fto.items():
        try:
            original_code = file_path.read_text("utf8")
        except Exception as e:
            # A file we cannot read is skipped; the rest still contribute context.
            logger.exception(f"Error while parsing {file_path}: {e}")
            continue
        try:
            read_writable_code = get_read_writable_code(original_code, qualified_function_names)
        except ValueError as e:
            # Raised when none of the target functions were found in this file.
            logger.debug(f"Error while getting read-writable code: {e}")
            continue
        if read_writable_code:
            final_read_writable_code += f"\n{read_writable_code}"
            # NOTE(review): the diff rendering lost indentation — this call is
            # assumed to run only when code was extracted from this file; confirm
            # the nesting against the repository.
            final_read_writable_code = add_needed_imports_from_module(
                src_module_code=original_code,
                dst_module_code=final_read_writable_code,
                src_path=file_path,
                dst_path=file_path,
                project_root=project_root_path,
                helper_functions_fqn=helpers_of_fto_fqn[file_path],
            )
    return final_read_writable_code
def get_all_read_only_code_context(
helpers_of_fto: dict[Path, set[str]],
helpers_of_fto_fqn: dict[Path, set[str]],
helpers_of_helpers: dict[Path, set[str]],
helpers_of_helpers_fqn: dict[Path, set[str]],
project_root_path: Path,
remove_docstrings: bool = False,
) -> CodeStringsMarkdown:
# Rearrange to remove overlaps, so we only access each file path once
helpers_of_helpers_no_overlap = defaultdict(set)
helpers_of_helpers_no_overlap_fqn = defaultdict(set)
for file_path in helpers_of_helpers:
if file_path in helpers_of_fto:
# Remove duplicates, in case a helper of helper is also a helper of fto
helpers_of_helpers[file_path] -= helpers_of_fto[file_path]
helpers_of_helpers_fqn[file_path] -= helpers_of_fto_fqn[file_path]
else:
helpers_of_helpers_no_overlap[file_path] = helpers_of_helpers[file_path]
helpers_of_helpers_no_overlap_fqn[file_path] = helpers_of_helpers_fqn[file_path]
read_only_code_markdown = CodeStringsMarkdown()
for file_path, qualified_function_names in file_path_to_qualified_function_names.items():
# Extract code from file paths that contain fto and first degree helpers. helpers of helpers may also be included if they are in the same files
for file_path, qualified_function_names in helpers_of_fto.items():
try:
original_code = file_path.read_text("utf8")
except Exception as e:
logger.exception(f"Error while parsing {file_path}: {e}")
continue
try:
read_only_code = get_read_only_code(
og_code_containing_helpers, qualified_function_names, remove_docstrings=True
original_code, qualified_function_names, helpers_of_helpers.get(file_path, set()), remove_docstrings
)
except ValueError as e:
logger.debug(f"Error while getting read-only code: {e}")
@ -134,24 +144,93 @@ def get_code_optimization_context(
read_only_code_with_imports = CodeString(
code=add_needed_imports_from_module(
src_module_code=og_code_containing_helpers,
src_module_code=original_code,
dst_module_code=read_only_code,
src_path=file_path,
dst_path=file_path,
project_root=project_root_path,
helper_functions_fqn=qualified_function_names,
helper_functions_fqn=helpers_of_fto_fqn[file_path] | helpers_of_helpers_fqn[file_path],
),
file_path=Path(file_path),
)
if read_only_code_with_imports.code:
read_only_code_markdown.code_strings.append(read_only_code_with_imports)
read_only_code_markdown_tokens = len(tokenizer.encode(read_only_code_markdown.markdown))
total_tokens = final_read_writable_tokens + read_only_code_markdown_tokens
if total_tokens <= token_limit:
return CodeString(code=final_read_writable_code).code, read_only_code_markdown.markdown
if read_only_code_with_imports.code:
read_only_code_markdown.code_strings.append(read_only_code_with_imports)
logger.debug("Code context has exceeded token limit, removing read-only code")
return CodeString(code=final_read_writable_code).code, ""
# Extract code from file paths containing helpers of helpers
for file_path, qualified_helper_function_names in helpers_of_helpers_no_overlap.items():
try:
original_code = file_path.read_text("utf8")
except Exception as e:
logger.exception(f"Error while parsing {file_path}: {e}")
continue
try:
read_only_code = get_read_only_code(
original_code, set(), qualified_helper_function_names, remove_docstrings
)
except ValueError as e:
logger.debug(f"Error while getting read-only code: {e}")
continue
read_only_code_with_imports = CodeString(
code=add_needed_imports_from_module(
src_module_code=original_code,
dst_module_code=read_only_code,
src_path=file_path,
dst_path=file_path,
project_root=project_root_path,
helper_functions_fqn=helpers_of_helpers_no_overlap_fqn[file_path],
),
file_path=Path(file_path),
)
if read_only_code_with_imports.code:
read_only_code_markdown.code_strings.append(read_only_code_with_imports)
return read_only_code_markdown
def get_file_path_to_helper_functions_dict(
    file_path_to_qualified_function_names: dict[Path, set[str]], project_root_path: Path
) -> tuple[dict[Path, set[str]], dict[Path, set[str]]]:
    """Discover the helper functions referenced by the given functions.

    For every (file, qualified function name) pair, uses jedi to resolve each
    reference inside that function to its definition, and records definitions
    that are project-local functions outside the function itself.

    Args:
        file_path_to_qualified_function_names: file path -> qualified names of
            the functions whose helpers should be discovered.
        project_root_path: project root; definitions outside it (or in site
            packages) are ignored.

    Returns:
        A pair of dicts, both keyed by the helper definition's file path:
        the first maps to qualified names (class.method style), the second to
        fully qualified names including the module path.
    """
    file_path_to_helper_function_qualified_names = defaultdict(set)
    file_path_to_helper_function_fqn = defaultdict(set)
    for file_path in file_path_to_qualified_function_names:
        script = jedi.Script(path=file_path, project=jedi.Project(path=project_root_path))
        # All references in the file, fetched once and filtered per function below.
        file_refs = script.get_names(all_scopes=True, definitions=False, references=True)
        for qualified_function_name in file_path_to_qualified_function_names[file_path]:
            # Keep only references that live directly inside this function.
            names = [
                ref
                for ref in file_refs
                if ref.full_name and belongs_to_function_qualified(ref, qualified_function_name)
            ]
            for name in names:
                try:
                    definitions: list[Name] = name.goto(follow_imports=True, follow_builtin_imports=False)
                except Exception as e:
                    try:
                        logger.exception(f"Error while getting definition for {name.full_name}: {e}")
                    except Exception as e:
                        # name.full_name can also throw exceptions sometimes
                        logger.exception(f"Error while getting definition: {e}")
                    definitions = []
                if definitions:
                    # TODO: there can be multiple definitions, see how to handle such cases
                    definition = definitions[0]
                    definition_path = definition.module_path
                    # The definition is part of this project and not defined within the original function
                    if (
                        str(definition_path).startswith(str(project_root_path) + os.sep)
                        and not path_belongs_to_site_packages(definition_path)
                        and definition.full_name
                        and definition.type == "function"
                        and not belongs_to_function_qualified(definition, qualified_function_name)
                    ):
                        file_path_to_helper_function_qualified_names[definition_path].add(
                            get_qualified_name(definition.module_name, definition.full_name)
                        )
                        file_path_to_helper_function_fqn[definition_path].add(definition.full_name)
    return file_path_to_helper_function_qualified_names, file_path_to_helper_function_fqn
def is_dunder_method(name: str) -> bool:
@ -166,7 +245,6 @@ def get_section_names(node: cst.CSTNode) -> list[str]:
def remove_docstring_from_body(indented_block: cst.IndentedBlock) -> cst.CSTNode:
"""Removes the docstring from an indented block if it exists"""
print(indented_block)
if not isinstance(indented_block.body[0], cst.SimpleStatementLine):
return indented_block
first_stmt = indented_block.body[0].body[0]
@ -268,7 +346,11 @@ def get_read_writable_code(code: str, target_functions: set[str]) -> str:
def prune_cst_for_read_only_code(
node: cst.CSTNode, target_functions: set[str], prefix: str = "", remove_docstrings: bool = False
node: cst.CSTNode,
target_functions: set[str],
helpers_of_helper_functions: set[str],
prefix: str = "",
remove_docstrings: bool = False,
) -> tuple[cst.CSTNode | None, bool]:
"""Recursively filter the node for read-only context:
@ -284,6 +366,8 @@ def prune_cst_for_read_only_code(
if isinstance(node, cst.FunctionDef):
qualified_name = f"{prefix}.{node.name.value}" if prefix else node.name.value
# If it's a target function, remove it but mark found_target = True
if qualified_name in helpers_of_helper_functions:
return node, True
if qualified_name in target_functions:
return None, True
# Keep only dunder methods
@ -309,15 +393,9 @@ def prune_cst_for_read_only_code(
new_class_body: list[CSTNode] = []
for stmt in node.body.body:
filtered, found_target = prune_cst_for_read_only_code(
stmt, target_functions, class_prefix, remove_docstrings=remove_docstrings
stmt, target_functions, helpers_of_helper_functions, class_prefix, remove_docstrings=remove_docstrings
)
found_in_class |= found_target
if isinstance(filtered, cst.FunctionDef):
# Check if it's a target or non-dunder method
qname = f"{class_prefix}.{filtered.name.value}"
if qname in target_functions or not is_dunder_method(filtered.name.value):
continue
if filtered:
new_class_body.append(filtered)
@ -345,7 +423,7 @@ def prune_cst_for_read_only_code(
section_found_target = False
for child in original_content:
filtered, found_target = prune_cst_for_read_only_code(
child, target_functions, prefix, remove_docstrings=remove_docstrings
child, target_functions, helpers_of_helper_functions, prefix, remove_docstrings=remove_docstrings
)
if filtered:
new_children.append(filtered)
@ -356,25 +434,30 @@ def prune_cst_for_read_only_code(
updates[section] = new_children
elif original_content is not None:
filtered, found_target = prune_cst_for_read_only_code(
original_content, target_functions, prefix, remove_docstrings=remove_docstrings
original_content,
target_functions,
helpers_of_helper_functions,
prefix,
remove_docstrings=remove_docstrings,
)
found_any_target |= found_target
if filtered:
updates[section] = filtered
if updates:
return (node.with_changes(**updates), found_any_target)
return node, found_any_target
return None, False
def get_read_only_code(code: str, target_functions: set[str], remove_docstrings: bool = False) -> str:
def get_read_only_code(
code: str, target_functions: set[str], helpers_of_helper_functions: set[str], remove_docstrings: bool = False
) -> str:
"""Creates a read-only version of the code by parsing and filtering the code to keep only
class contextual information, and other module scoped variables.
"""
module = cst.parse_module(code)
filtered_node, found_target = prune_cst_for_read_only_code(
module, target_functions, remove_docstrings=remove_docstrings
module, target_functions, helpers_of_helper_functions, remove_docstrings=remove_docstrings
)
if not found_target:
raise ValueError("No target functions found in the provided code")

View file

@ -12,7 +12,11 @@ from jedi.api.classes import Name
from codeflash.cli_cmds.console import logger
from codeflash.code_utils.code_extractor import get_code
from codeflash.code_utils.code_utils import module_name_from_file_path, path_belongs_to_site_packages
from codeflash.code_utils.code_utils import (
get_qualified_name,
module_name_from_file_path,
path_belongs_to_site_packages,
)
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import FunctionParent, FunctionSource
@ -26,7 +30,7 @@ def belongs_to_method(name: Name, class_name: str, method_name: str) -> bool:
def belongs_to_function(name: Name, function_name: str) -> bool:
"""Check if the given jedi Name is a direct child of the specified function"""
"""Check if the given jedi Name is a direct child of the specified function."""
if name.name == function_name: # Handles function definition and recursive function calls
return False
if name := name.parent():
@ -36,13 +40,28 @@ def belongs_to_function(name: Name, function_name: str) -> bool:
def belongs_to_class(name: Name, class_name: str) -> bool:
    """Check if given jedi Name is a direct child of the specified class."""
    # Walk up the enclosing scopes; the first class found decides the answer.
    current = name.parent()
    while current:
        if current.type == "class":
            return current.name == class_name
        current = current.parent()
    # No enclosing class at all.
    return False
def belongs_to_function_qualified(name: Name, qualified_function_name: str) -> bool:
    """Check if the given jedi Name is a direct child of the specified function, matched by qualified function name."""
    try:
        if get_qualified_name(name.module_name, name.full_name) == qualified_function_name:
            # Handles function definition and recursive function calls
            return False
        # NOTE: the walrus rebinds `name` to its parent scope; the except below
        # therefore logs the parent's full_name if the second lookup raises.
        if name := name.parent():
            if name.type == "function":
                return get_qualified_name(name.module_name, name.full_name) == qualified_function_name
        return False
    except ValueError as e:
        # get_qualified_name raises ValueError when the full name does not start
        # with the module name; treat that as "does not belong".
        logger.exception(f"Error while checking if {name.full_name} belongs to {qualified_function_name}: {e}")
        return False
def get_type_annotation_context(
function: FunctionToOptimize, jedi_script: jedi.Script, project_root_path: Path
) -> tuple[list[FunctionSource], set[tuple[str, str]]]:

View file

@ -24,6 +24,10 @@ class HelperClass:
return self.name
def main_method():
return "hello"
class MainClass:
def __init__(self, name):
self.name = name
@ -67,7 +71,7 @@ def test_code_replacement10() -> None:
func_top_optimize = FunctionToOptimize(
function_name="main_method", file_path=file_path, parents=[FunctionParent("MainClass", "ClassDef")]
)
original_code = file_path.read_text()
read_write_context, read_only_context = get_code_optimization_context(
function_to_optimize=func_top_optimize, project_root_path=file_path.parent
)
@ -90,7 +94,6 @@ def test_code_replacement10() -> None:
```python:{file_path}
from __future__ import annotations
class HelperClass:
def __init__(self, name):
self.name = name
@ -151,7 +154,6 @@ class Graph:
from __future__ import annotations
from collections import defaultdict
class Graph:
def __init__(self, vertices):
self.graph = defaultdict(list)
@ -184,14 +186,8 @@ def test_bubble_sort_helper() -> None:
)
expected_read_write_context = """
from bubble_sort_with_math import sorter
import math
def sort_from_another_file(arr):
sorted_arr = sorter(arr)
return sorted_arr
from bubble_sort_with_math import sorter
def sorter(arr):
arr.sort()
@ -199,6 +195,12 @@ def sorter(arr):
print(x)
return arr
def sort_from_another_file(arr):
sorted_arr = sorter(arr)
return sorted_arr
"""
expected_read_only_context = ""
@ -206,84 +208,6 @@ def sorter(arr):
assert read_only_context.strip() == dedent(expected_read_only_context).strip()
def test_repo_helper() -> None:
path_to_file = (
Path(__file__).resolve().parent.parent / "code_to_optimize" / "code_directories" / "retriever" / "main.py"
)
path_to_utils = (
Path(__file__).resolve().parent.parent / "code_to_optimize" / "code_directories" / "retriever" / "utils.py"
)
function_to_optimize = FunctionToOptimize(
function_name="fetch_and_process_data",
file_path=str(path_to_file),
parents=[],
starting_line=None,
ending_line=None,
)
read_write_context, read_only_context = get_code_optimization_context(
function_to_optimize, Path(__file__).resolve().parent.parent
)
expected_read_write_context = """
import requests
from globals import API_URL
from utils import DataProcessor
def fetch_and_process_data():
# Use the global variable for the request
response = requests.get(API_URL)
response.raise_for_status()
raw_data = response.text
# Use code from another file (utils.py)
processor = DataProcessor()
processed = processor.process_data(raw_data)
processed = processor.add_prefix(processed)
return processed
class DataProcessor:
def process_data(self, raw_data: str) -> str:
\"\"\"Process raw data by converting it to uppercase.\"\"\"
return raw_data.upper()
def add_prefix(self, data: str, prefix: str = "PREFIX_") -> str:
\"\"\"Add a prefix to the processed data.\"\"\"
return prefix + data
"""
expected_read_only_context = f"""
```python:{path_to_file}
if __name__ == "__main__":
result = fetch_and_process_data()
print("Processed data:", result)
```
```python:{path_to_utils}
import math
class DataProcessor:
\"\"\"A class for processing data.\"\"\"
number = 1
def __init__(self, default_prefix: str = "PREFIX_"):
\"\"\"Initialize the DataProcessor with a default prefix.\"\"\"
self.default_prefix = default_prefix
self.number += math.log(self.number)
def __repr__(self) -> str:
\"\"\"Return a string representation of the DataProcessor.\"\"\"
return f"DataProcessor(default_prefix={{self.default_prefix!r}})"
```
"""
assert read_write_context.strip() == dedent(expected_read_write_context).strip()
assert read_only_context.strip() == dedent(expected_read_only_context).strip()
def test_flavio_typed_code_helper() -> None:
code = '''
@ -569,7 +493,27 @@ class AbstractCacheBackend(CacheBackend, Protocol[_KEY_T, _STORE_T]):
"""Interface for cache backends used by the persistent cache decorator."""
def __init__(self) -> None: ...
def hash_key(
self,
*,
func: Callable[_P, Any],
args: tuple[Any, ...],
kwargs: dict[str, Any],
) -> tuple[str, _KEY_T]: ...
def encode(self, *, data: Any) -> _STORE_T: # noqa: ANN401
...
def decode(self, *, data: _STORE_T) -> Any: # noqa: ANN401
...
def get(self, *, key: tuple[str, _KEY_T]) -> tuple[datetime.datetime, _STORE_T] | None: ...
def delete(self, *, key: tuple[str, _KEY_T]) -> None: ...
def put(self, *, key: tuple[str, _KEY_T], data: _STORE_T) -> None: ...
_P = ParamSpec("_P")
_R = TypeVar("_R")
_CacheBackendT = TypeVar("_CacheBackendT", bound=CacheBackend)
@ -885,3 +829,462 @@ class HelperClass:
read_write_context, read_only_context = get_code_optimization_context(
function_to_optimize, opt.args.project_root
)
def test_repo_helper() -> None:
path_to_file = (
Path(__file__).resolve().parent.parent / "code_to_optimize" / "code_directories" / "retriever" / "main.py"
)
path_to_utils = (
Path(__file__).resolve().parent.parent / "code_to_optimize" / "code_directories" / "retriever" / "utils.py"
)
function_to_optimize = FunctionToOptimize(
function_name="fetch_and_process_data",
file_path=str(path_to_file),
parents=[],
starting_line=None,
ending_line=None,
)
read_write_context, read_only_context = get_code_optimization_context(
function_to_optimize,
Path(__file__).resolve().parent.parent / "code_to_optimize" / "code_directories" / "retriever",
)
expected_read_write_context = """
import requests
from globals import API_URL
from utils import DataProcessor
class DataProcessor:
def process_data(self, raw_data: str) -> str:
\"\"\"Process raw data by converting it to uppercase.\"\"\"
return raw_data.upper()
def add_prefix(self, data: str, prefix: str = "PREFIX_") -> str:
\"\"\"Add a prefix to the processed data.\"\"\"
return prefix + data
def fetch_and_process_data():
# Use the global variable for the request
response = requests.get(API_URL)
response.raise_for_status()
raw_data = response.text
# Use code from another file (utils.py)
processor = DataProcessor()
processed = processor.process_data(raw_data)
processed = processor.add_prefix(processed)
return processed
"""
expected_read_only_context = f"""
```python:{path_to_utils}
import math
GLOBAL_VAR = 10
class DataProcessor:
\"\"\"A class for processing data.\"\"\"
number = 1
def __init__(self, default_prefix: str = "PREFIX_"):
\"\"\"Initialize the DataProcessor with a default prefix.\"\"\"
self.default_prefix = default_prefix
self.number += math.log(self.number)
def __repr__(self) -> str:
\"\"\"Return a string representation of the DataProcessor.\"\"\"
return f"DataProcessor(default_prefix={{self.default_prefix!r}})"
```
```python:{path_to_file}
if __name__ == "__main__":
result = fetch_and_process_data()
print("Processed data:", result)
```
"""
assert read_write_context.strip() == dedent(expected_read_write_context).strip()
assert read_only_context.strip() == dedent(expected_read_only_context).strip()
def test_repo_helper_of_helper() -> None:
path_to_file = (
Path(__file__).resolve().parent.parent / "code_to_optimize" / "code_directories" / "retriever" / "main.py"
)
path_to_utils = (
Path(__file__).resolve().parent.parent / "code_to_optimize" / "code_directories" / "retriever" / "utils.py"
)
path_to_transform_utils = (
Path(__file__).resolve().parent.parent
/ "code_to_optimize"
/ "code_directories"
/ "retriever"
/ "transform_utils.py"
)
function_to_optimize = FunctionToOptimize(
function_name="fetch_and_transform_data",
file_path=str(path_to_file),
parents=[],
starting_line=None,
ending_line=None,
)
read_write_context, read_only_context = get_code_optimization_context(
function_to_optimize,
Path(__file__).resolve().parent.parent / "code_to_optimize" / "code_directories" / "retriever",
)
expected_read_write_context = """
from transform_utils import DataTransformer
import requests
from globals import API_URL
from utils import DataProcessor
class DataProcessor:
def process_data(self, raw_data: str) -> str:
\"\"\"Process raw data by converting it to uppercase.\"\"\"
return raw_data.upper()
def transform_data(self, data: str) -> str:
\"\"\"Transform the processed data\"\"\"
return DataTransformer().transform(data)
def fetch_and_transform_data():
# Use the global variable for the request
response = requests.get(API_URL)
raw_data = response.text
# Use code from another file (utils.py)
processor = DataProcessor()
processed = processor.process_data(raw_data)
transformed = processor.transform_data(processed)
return transformed
"""
expected_read_only_context = f"""
```python:{path_to_utils}
import math
GLOBAL_VAR = 10
class DataProcessor:
\"\"\"A class for processing data.\"\"\"
number = 1
def __init__(self, default_prefix: str = "PREFIX_"):
\"\"\"Initialize the DataProcessor with a default prefix.\"\"\"
self.default_prefix = default_prefix
self.number += math.log(self.number)
def __repr__(self) -> str:
\"\"\"Return a string representation of the DataProcessor.\"\"\"
return f"DataProcessor(default_prefix={{self.default_prefix!r}})"
```
```python:{path_to_file}
if __name__ == "__main__":
result = fetch_and_process_data()
print("Processed data:", result)
```
```python:{path_to_transform_utils}
class DataTransformer:
def __init__(self):
self.data = None
def transform(self, data):
self.data = data
return self.data
```
"""
assert read_write_context.strip() == dedent(expected_read_write_context).strip()
assert read_only_context.strip() == dedent(expected_read_only_context).strip()
def test_repo_helper_of_helper_same_class() -> None:
path_to_utils = (
Path(__file__).resolve().parent.parent / "code_to_optimize" / "code_directories" / "retriever" / "utils.py"
)
path_to_transform_utils = (
Path(__file__).resolve().parent.parent
/ "code_to_optimize"
/ "code_directories"
/ "retriever"
/ "transform_utils.py"
)
function_to_optimize = FunctionToOptimize(
function_name="transform_data_own_method",
file_path=str(path_to_utils),
parents=[FunctionParent(name="DataProcessor", type="ClassDef")],
starting_line=None,
ending_line=None,
)
read_write_context, read_only_context = get_code_optimization_context(
function_to_optimize,
Path(__file__).resolve().parent.parent / "code_to_optimize" / "code_directories" / "retriever",
)
expected_read_write_context = """
from transform_utils import DataTransformer
class DataTransformer:
def transform_using_own_method(self, data):
return self.transform(data)
class DataProcessor:
def transform_data_own_method(self, data: str) -> str:
\"\"\"Transform the processed data using own method\"\"\"
return DataTransformer().transform_using_own_method(data)
"""
expected_read_only_context = f"""
```python:{path_to_transform_utils}
class DataTransformer:
def __init__(self):
self.data = None
def transform(self, data):
self.data = data
return self.data
```
```python:{path_to_utils}
import math
GLOBAL_VAR = 10
class DataProcessor:
\"\"\"A class for processing data.\"\"\"
number = 1
def __init__(self, default_prefix: str = "PREFIX_"):
\"\"\"Initialize the DataProcessor with a default prefix.\"\"\"
self.default_prefix = default_prefix
self.number += math.log(self.number)
def __repr__(self) -> str:
\"\"\"Return a string representation of the DataProcessor.\"\"\"
return f"DataProcessor(default_prefix={{self.default_prefix!r}})"
```
"""
assert read_write_context.strip() == dedent(expected_read_write_context).strip()
assert read_only_context.strip() == dedent(expected_read_only_context).strip()
def test_repo_helper_of_helper_same_file() -> None:
    """Helper-of-helper in the same file lands in the read-write context.

    ``DataProcessor.transform_data_same_file_function`` delegates to
    ``DataTransformer.transform_using_same_file_function``, which itself calls the
    module-level ``update_data``. The full helper chain must appear in the
    read-write context; untouched definitions stay read-only.
    """
    # All fixture files live under the shared retriever test directory.
    retriever_dir = Path(__file__).resolve().parent.parent / "code_to_optimize" / "code_directories" / "retriever"
    path_to_utils = retriever_dir / "utils.py"
    path_to_transform_utils = retriever_dir / "transform_utils.py"
    function_to_optimize = FunctionToOptimize(
        function_name="transform_data_same_file_function",
        file_path=str(path_to_utils),
        parents=[FunctionParent(name="DataProcessor", type="ClassDef")],
        starting_line=None,
        ending_line=None,
    )
    read_write_context, read_only_context = get_code_optimization_context(function_to_optimize, retriever_dir)
    expected_read_write_context = """
from transform_utils import DataTransformer
class DataTransformer:
    def transform_using_same_file_function(self, data):
        return update_data(data)
class DataProcessor:
    def transform_data_same_file_function(self, data: str) -> str:
        \"\"\"Transform the processed data using a function from the same file\"\"\"
        return DataTransformer().transform_using_same_file_function(data)
"""
    expected_read_only_context = f"""
```python:{path_to_transform_utils}
class DataTransformer:
    def __init__(self):
        self.data = None
def update_data(data):
    return data + " updated"
```
```python:{path_to_utils}
import math
GLOBAL_VAR = 10
class DataProcessor:
    \"\"\"A class for processing data.\"\"\"
    number = 1
    def __init__(self, default_prefix: str = "PREFIX_"):
        \"\"\"Initialize the DataProcessor with a default prefix.\"\"\"
        self.default_prefix = default_prefix
        self.number += math.log(self.number)
    def __repr__(self) -> str:
        \"\"\"Return a string representation of the DataProcessor.\"\"\"
        return f"DataProcessor(default_prefix={{self.default_prefix!r}})"
```
"""
    assert dedent(expected_read_write_context).strip() == read_write_context.strip()
    assert dedent(expected_read_only_context).strip() == read_only_context.strip()
def test_repo_helper_all_same_file() -> None:
    """All helpers resolved from a single file.

    ``DataTransformer.transform_data_all_same_file`` only calls its own method
    and a module-level function from the same file, so the read-write context
    needs no imports and the read-only context covers one file only.
    """
    # All fixture files live under the shared retriever test directory.
    retriever_dir = Path(__file__).resolve().parent.parent / "code_to_optimize" / "code_directories" / "retriever"
    path_to_transform_utils = retriever_dir / "transform_utils.py"
    function_to_optimize = FunctionToOptimize(
        function_name="transform_data_all_same_file",
        file_path=str(path_to_transform_utils),
        parents=[FunctionParent(name="DataTransformer", type="ClassDef")],
        starting_line=None,
        ending_line=None,
    )
    read_write_context, read_only_context = get_code_optimization_context(function_to_optimize, retriever_dir)
    expected_read_write_context = """
class DataTransformer:
    def transform_using_own_method(self, data):
        return self.transform(data)
    def transform_data_all_same_file(self, data):
        new_data = update_data(data)
        return self.transform_using_own_method(new_data)
def update_data(data):
    return data + " updated"
"""
    expected_read_only_context = f"""
```python:{path_to_transform_utils}
class DataTransformer:
    def __init__(self):
        self.data = None
    def transform(self, data):
        self.data = data
        return self.data
```
"""
    assert dedent(expected_read_write_context).strip() == read_write_context.strip()
    assert dedent(expected_read_only_context).strip() == read_only_context.strip()
def test_repo_helper_circular_dependency() -> None:
    """Mutual recursion between two classes in different files terminates.

    ``DataTransformer.circular_dependency`` calls ``DataProcessor.circular_dependency``
    and vice versa; context extraction must include both methods in the read-write
    context exactly once instead of looping forever.
    """
    # All fixture files live under the shared retriever test directory.
    retriever_dir = Path(__file__).resolve().parent.parent / "code_to_optimize" / "code_directories" / "retriever"
    path_to_utils = retriever_dir / "utils.py"
    path_to_transform_utils = retriever_dir / "transform_utils.py"
    function_to_optimize = FunctionToOptimize(
        function_name="circular_dependency",
        file_path=str(path_to_transform_utils),
        parents=[FunctionParent(name="DataTransformer", type="ClassDef")],
        starting_line=None,
        ending_line=None,
    )
    read_write_context, read_only_context = get_code_optimization_context(function_to_optimize, retriever_dir)
    expected_read_write_context = """
from transform_utils import DataTransformer
from code_to_optimize.code_directories.retriever.utils import DataProcessor
class DataProcessor:
    def circular_dependency(self, data: str) -> str:
        \"\"\"Test circular dependency\"\"\"
        return DataTransformer().circular_dependency(data)
class DataTransformer:
    def circular_dependency(self, data):
        return DataProcessor().circular_dependency(data)
"""
    expected_read_only_context = f"""
```python:{path_to_utils}
import math
GLOBAL_VAR = 10
class DataProcessor:
    \"\"\"A class for processing data.\"\"\"
    number = 1
    def __init__(self, default_prefix: str = "PREFIX_"):
        \"\"\"Initialize the DataProcessor with a default prefix.\"\"\"
        self.default_prefix = default_prefix
        self.number += math.log(self.number)
    def __repr__(self) -> str:
        \"\"\"Return a string representation of the DataProcessor.\"\"\"
        return f"DataProcessor(default_prefix={{self.default_prefix!r}})"
```
```python:{path_to_transform_utils}
class DataTransformer:
    def __init__(self):
        self.data = None
```
"""
    assert dedent(expected_read_write_context).strip() == read_write_context.strip()
    assert dedent(expected_read_only_context).strip() == read_only_context.strip()

View file

@ -21,7 +21,7 @@ def test_basic_class() -> None:
class_var = "value"
"""
output = get_read_only_code(dedent(code), {"TestClass.target_method"})
output = get_read_only_code(dedent(code), {"TestClass.target_method"}, set())
assert dedent(expected).strip() == output.strip()
@ -47,7 +47,7 @@ def test_dunder_methods() -> None:
return f"Value: {self.x}"
"""
output = get_read_only_code(dedent(code), {"TestClass.target_method"})
output = get_read_only_code(dedent(code), {"TestClass.target_method"}, set())
assert dedent(expected).strip() == output.strip()
@ -75,7 +75,7 @@ def test_dunder_methods_remove_docstring() -> None:
return f"Value: {self.x}"
"""
output = get_read_only_code(dedent(code), {"TestClass.target_method"}, remove_docstrings=True)
output = get_read_only_code(dedent(code), {"TestClass.target_method"}, set(), remove_docstrings=True)
assert dedent(expected).strip() == output.strip()
@ -102,7 +102,7 @@ def test_class_remove_docstring() -> None:
return f"Value: {self.x}"
"""
output = get_read_only_code(dedent(code), {"TestClass.target_method"}, remove_docstrings=True)
output = get_read_only_code(dedent(code), {"TestClass.target_method"}, set(), remove_docstrings=True)
assert dedent(expected).strip() == output.strip()
@ -131,7 +131,7 @@ def test_mixed_remove_docstring() -> None:
return f"Value: {self.x}"
"""
output = get_read_only_code(dedent(code), {"TestClass.target_method"}, remove_docstrings=True)
output = get_read_only_code(dedent(code), {"TestClass.target_method"}, set(), remove_docstrings=True)
assert dedent(expected).strip() == output.strip()
@ -149,7 +149,7 @@ def test_target_in_nested_class() -> None:
"""
with pytest.raises(ValueError, match="No target functions found in the provided code"):
get_read_only_code(dedent(code), {"Outer.Inner.target_method"})
get_read_only_code(dedent(code), {"Outer.Inner.target_method"}, set())
def test_docstrings() -> None:
@ -171,7 +171,7 @@ def test_docstrings() -> None:
\"\"\"Class docstring.\"\"\"
"""
output = get_read_only_code(dedent(code), {"TestClass.target_method"})
output = get_read_only_code(dedent(code), {"TestClass.target_method"}, set())
assert dedent(expected).strip() == output.strip()
@ -190,7 +190,7 @@ def test_method_signatures() -> None:
expected = """"""
output = get_read_only_code(dedent(code), {"TestClass.target_method"})
output = get_read_only_code(dedent(code), {"TestClass.target_method"}, set())
assert dedent(expected).strip() == output.strip()
@ -214,7 +214,7 @@ def test_multiple_top_level_targets() -> None:
self.x = 42
"""
output = get_read_only_code(dedent(code), {"TestClass.target1", "TestClass.target2"})
output = get_read_only_code(dedent(code), {"TestClass.target1", "TestClass.target2"}, set())
assert dedent(expected).strip() == output.strip()
@ -234,7 +234,7 @@ def test_class_annotations() -> None:
var2: str
"""
output = get_read_only_code(dedent(code), {"TestClass.target_method"})
output = get_read_only_code(dedent(code), {"TestClass.target_method"}, set())
assert dedent(expected).strip() == output.strip()
@ -256,7 +256,7 @@ def test_class_annotations_if() -> None:
var2: str
"""
output = get_read_only_code(dedent(code), {"TestClass.target_method"})
output = get_read_only_code(dedent(code), {"TestClass.target_method"}, set())
assert dedent(expected).strip() == output.strip()
@ -282,7 +282,7 @@ def test_class_annotations_try() -> None:
continue
"""
output = get_read_only_code(dedent(code), {"TestClass.target_method"})
output = get_read_only_code(dedent(code), {"TestClass.target_method"}, set())
assert dedent(expected).strip() == output.strip()
@ -318,7 +318,7 @@ def test_class_annotations_else() -> None:
var2: str
"""
output = get_read_only_code(dedent(code), {"TestClass.target_method"})
output = get_read_only_code(dedent(code), {"TestClass.target_method"}, set())
assert dedent(expected).strip() == output.strip()
@ -333,7 +333,7 @@ def test_top_level_functions() -> None:
expected = """"""
output = get_read_only_code(dedent(code), {"target_function"})
output = get_read_only_code(dedent(code), {"target_function"}, set())
assert dedent(expected).strip() == output.strip()
@ -352,7 +352,7 @@ def test_module_var() -> None:
x = 5
"""
output = get_read_only_code(dedent(code), {"target_function"})
output = get_read_only_code(dedent(code), {"target_function"}, set())
assert dedent(expected).strip() == output.strip()
@ -379,7 +379,7 @@ def test_module_var_if() -> None:
z = 10
"""
output = get_read_only_code(dedent(code), {"target_function"})
output = get_read_only_code(dedent(code), {"target_function"}, set())
assert dedent(expected).strip() == output.strip()
@ -414,7 +414,7 @@ def test_conditional_class_definitions() -> None:
platform = "other"
"""
output = get_read_only_code(dedent(code), {"PlatformClass.target_method"})
output = get_read_only_code(dedent(code), {"PlatformClass.target_method"}, set())
assert dedent(expected).strip() == output.strip()
@ -473,7 +473,7 @@ def test_multiple_except_clauses() -> None:
error_type = "cleanup"
"""
output = get_read_only_code(dedent(code), {"TestClass.target_method"})
output = get_read_only_code(dedent(code), {"TestClass.target_method"}, set())
assert dedent(expected).strip() == output.strip()
@ -526,7 +526,7 @@ def test_with_statement_and_loops() -> None:
context = "cleanup"
"""
output = get_read_only_code(dedent(code), {"TestClass.target_method"})
output = get_read_only_code(dedent(code), {"TestClass.target_method"}, set())
assert dedent(expected).strip() == output.strip()
@ -575,7 +575,7 @@ def test_async_with_try_except() -> None:
status = "cancelled"
"""
output = get_read_only_code(dedent(code), {"TestClass.target_method"})
output = get_read_only_code(dedent(code), {"TestClass.target_method"}, set())
assert dedent(expected).strip() == output.strip()
@ -685,7 +685,7 @@ def test_simplified_complete_implementation() -> None:
self.error = str(e)
"""
output = get_read_only_code(dedent(code), {"DataProcessor.target_method", "ResultHandler.target_method"})
output = get_read_only_code(dedent(code), {"DataProcessor.target_method", "ResultHandler.target_method"}, set())
assert dedent(expected).strip() == output.strip()
@ -795,6 +795,6 @@ def test_simplified_complete_implementation_no_docstring() -> None:
"""
output = get_read_only_code(
dedent(code), {"DataProcessor.target_method", "ResultHandler.target_method"}, remove_docstrings=True
dedent(code), {"DataProcessor.target_method", "ResultHandler.target_method"}, set(), remove_docstrings=True
)
assert dedent(expected).strip() == output.strip()