# codeflash/tests/test_code_context_extractor.py
from __future__ import annotations
import tempfile
from argparse import Namespace
from collections import defaultdict
from pathlib import Path
import pytest
from codeflash.context.code_context_extractor import get_code_optimization_context
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import FunctionParent
from codeflash.optimization.optimizer import Optimizer
class HelperClass:
    def __init__(self, name):
        self.name = name
    def innocent_bystander(self):
        # Deliberately unused by MainClass: verifies that uncalled methods are
        # pruned from the extracted optimization context.
        # NOTE(review): the rest of this class's source is asserted verbatim by
        # test_code_replacement10, so no comments are added elsewhere.
        pass
    def helper_method(self):
        return self.name
    class NestedClass:
        # NOTE(review): nested-class methods must NOT be reported as helper
        # functions (see the qualified_names assertion in test_code_replacement10).
        def __init__(self, name):
            self.name = name
        def nested_method(self):
            return self.name
def main_method():
    """Module-level decoy sharing its name with MainClass.main_method.

    Exists so the extractor must resolve the qualified method rather than
    this same-named top-level function.
    """
    greeting = "hello"
    return greeting
# Fixture under optimization in test_code_replacement10. NOTE(review): the exact
# source text of this class body (including the absence of comments) is asserted
# verbatim by the expected-context strings below, so the body stays comment-free.
class MainClass:
    def __init__(self, name):
        self.name = name
    def main_method(self):
        self.name = HelperClass.NestedClass("test").nested_method()
        return HelperClass(self.name).helper_method()
class Graph:
    def __init__(self, vertices):
        self.graph = defaultdict(list)
        self.V = vertices # No. of vertices
    def addEdge(self, u, v):
        # Add a directed edge u -> v. NOTE(review): this method is deliberately
        # unused by topologicalSort, so it must be pruned from the extracted
        # context in test_class_method_dependencies. The remainder of this class
        # (including the non-idiomatic `== False` comparisons and the existing
        # comments) is asserted verbatim there and must not be modified.
        self.graph[u].append(v)
    def topologicalSortUtil(self, v, visited, stack):
        visited[v] = True
        for i in self.graph[v]:
            if visited[i] == False:
                self.topologicalSortUtil(i, visited, stack)
        stack.insert(0, v)
    def topologicalSort(self):
        visited = [False] * self.V
        stack = []
        for i in range(self.V):
            if visited[i] == False:
                self.topologicalSortUtil(i, visited, stack)
        # Print contents of stack
        return stack
def test_code_replacement10() -> None:
    """Optimizing MainClass.main_method: nested-class methods are not helpers."""
    this_file = Path(__file__).resolve()
    fto = FunctionToOptimize(
        function_name="main_method", file_path=this_file, parents=[FunctionParent("MainClass", "ClassDef")]
    )
    ctx = get_code_optimization_context(function_to_optimize=fto, project_root_path=this_file.parent)
    helper_names = {helper.qualified_name for helper in ctx.helper_functions}
    assert helper_names == {"HelperClass.helper_method"}  # Nested method should not be in here
    expected_read_write_context = """
from __future__ import annotations
class HelperClass:
    def __init__(self, name):
        self.name = name
    def helper_method(self):
        return self.name
class MainClass:
    def __init__(self, name):
        self.name = name
    def main_method(self):
        self.name = HelperClass.NestedClass("test").nested_method()
        return HelperClass(self.name).helper_method()
"""
    expected_read_only_context = """
"""
    expected_hashing_context = f"""
```python:{this_file.relative_to(this_file.parent)}
class HelperClass:
    def helper_method(self):
        return self.name
class MainClass:
    def main_method(self):
        self.name = HelperClass.NestedClass("test").nested_method()
        return HelperClass(self.name).helper_method()
```
"""
    assert ctx.read_writable_code.strip() == expected_read_write_context.strip()
    assert ctx.read_only_context_code.strip() == expected_read_only_context.strip()
    assert ctx.hashing_code_context.strip() == expected_hashing_context.strip()
def test_class_method_dependencies() -> None:
    """Graph.topologicalSort: the recursive util is included, addEdge is pruned."""
    this_file = Path(__file__).resolve()
    fto = FunctionToOptimize(
        function_name="topologicalSort",
        file_path=str(this_file),
        parents=[FunctionParent(name="Graph", type="ClassDef")],
        starting_line=None,
        ending_line=None,
    )
    ctx = get_code_optimization_context(fto, this_file.parent.resolve())
    expected_read_write_context = """
from __future__ import annotations
from collections import defaultdict
class Graph:
    def __init__(self, vertices):
        self.graph = defaultdict(list)
        self.V = vertices # No. of vertices
    def topologicalSortUtil(self, v, visited, stack):
        visited[v] = True
        for i in self.graph[v]:
            if visited[i] == False:
                self.topologicalSortUtil(i, visited, stack)
        stack.insert(0, v)
    def topologicalSort(self):
        visited = [False] * self.V
        stack = []
        for i in range(self.V):
            if visited[i] == False:
                self.topologicalSortUtil(i, visited, stack)
        # Print contents of stack
        return stack
"""
    expected_read_only_context = ""
    expected_hashing_context = f"""
```python:{this_file.relative_to(this_file.parent.resolve())}
class Graph:
    def topologicalSortUtil(self, v, visited, stack):
        visited[v] = True
        for i in self.graph[v]:
            if visited[i] == False:
                self.topologicalSortUtil(i, visited, stack)
        stack.insert(0, v)
    def topologicalSort(self):
        visited = [False] * self.V
        stack = []
        for i in range(self.V):
            if visited[i] == False:
                self.topologicalSortUtil(i, visited, stack)
        # Print contents of stack
        return stack
```
"""
    assert ctx.read_writable_code.strip() == expected_read_write_context.strip()
    assert ctx.read_only_context_code.strip() == expected_read_only_context.strip()
    assert ctx.hashing_code_context.strip() == expected_hashing_context.strip()
def test_bubble_sort_helper() -> None:
    """A helper imported from a sibling module is inlined into the read-writable context."""
    repo_root = Path(__file__).resolve().parent.parent
    fto_path = repo_root.joinpath(
        "code_to_optimize", "code_directories", "retriever", "bubble_sort_imported.py"
    )
    fto = FunctionToOptimize(
        function_name="sort_from_another_file",
        file_path=str(fto_path),
        parents=[],
        starting_line=None,
        ending_line=None,
    )
    ctx = get_code_optimization_context(fto, repo_root)
    expected_read_write_context = """
import math
from bubble_sort_with_math import sorter
def sorter(arr):
    arr.sort()
    x = math.sqrt(2)
    print(x)
    return arr
def sort_from_another_file(arr):
    sorted_arr = sorter(arr)
    return sorted_arr
"""
    expected_read_only_context = ""
    expected_hashing_context = """
```python:code_to_optimize/code_directories/retriever/bubble_sort_with_math.py
def sorter(arr):
    arr.sort()
    x = math.sqrt(2)
    print(x)
    return arr
```
```python:code_to_optimize/code_directories/retriever/bubble_sort_imported.py
def sort_from_another_file(arr):
    sorted_arr = sorter(arr)
    return sorted_arr
```
"""
    assert ctx.read_writable_code.strip() == expected_read_write_context.strip()
    assert ctx.read_only_context_code.strip() == expected_read_only_context.strip()
    assert ctx.hashing_code_context.strip() == expected_hashing_context.strip()
def test_flavio_typed_code_helper() -> None:
    """Optimize ``_PersistentCache.__call__`` in a heavily typed module.

    The called helper ``get_cache_or_call`` and both ``__init__`` methods become
    read-writable; unreferenced method stubs, class docstrings, class-level
    annotations and module-level TypeVars are demoted to the read-only context;
    the hashing context strips docstrings from the read-writable code.
    """
    # Fixture module (written to a temp .py file below). Names like ParamSpec,
    # CacheBackend, os, logging etc. are intentionally unresolved — the
    # extractor works on source text, not on an importable module.
    code = '''
_P = ParamSpec("_P")
_KEY_T = TypeVar("_KEY_T")
_STORE_T = TypeVar("_STORE_T")
class AbstractCacheBackend(CacheBackend, Protocol[_KEY_T, _STORE_T]):
    """Interface for cache backends used by the persistent cache decorator."""
    def __init__(self) -> None: ...
    def hash_key(
        self,
        *,
        func: Callable[_P, Any],
        args: tuple[Any, ...],
        kwargs: dict[str, Any],
    ) -> tuple[str, _KEY_T]: ...
    def encode(self, *, data: Any) -> _STORE_T:  # noqa: ANN401
        ...
    def decode(self, *, data: _STORE_T) -> Any:  # noqa: ANN401
        ...
    def get(self, *, key: tuple[str, _KEY_T]) -> tuple[datetime.datetime, _STORE_T] | None: ...
    def delete(self, *, key: tuple[str, _KEY_T]) -> None: ...
    def put(self, *, key: tuple[str, _KEY_T], data: _STORE_T) -> None: ...
    def get_cache_or_call(
        self,
        *,
        func: Callable[_P, Any],
        args: tuple[Any, ...],
        kwargs: dict[str, Any],
        lifespan: datetime.timedelta,
    ) -> Any:  # noqa: ANN401
        """
        Retrieve the cached results for a function call.
        Args:
        ----
            func (Callable[..., _R]): The function to retrieve cached results for.
            args (tuple[Any, ...]): The positional arguments passed to the function.
            kwargs (dict[str, Any]): The keyword arguments passed to the function.
            lifespan (datetime.timedelta): The maximum age of the cached results.
        Returns:
        -------
            _R: The cached results, if available.
        """
        if os.environ.get("NO_CACHE"):
            return func(*args, **kwargs)
        try:
            key = self.hash_key(func=func, args=args, kwargs=kwargs)
        except:  # noqa: E722
            # If we can't create a cache key, we should just call the function.
            logging.warning("Failed to hash cache key for function: %s", func)
            return func(*args, **kwargs)
        result_pair = self.get(key=key)
        if result_pair is not None:
            cached_time, result = result_pair
            if not os.environ.get("RE_CACHE") and (
                datetime.datetime.now() < (cached_time + lifespan)  # noqa: DTZ005
            ):
                try:
                    return self.decode(data=result)
                except CacheBackendDecodeError as e:
                    logging.warning("Failed to decode cache data: %s", e)
                    # If decoding fails we will treat this as a cache miss.
                    # This might happens if underlying class definition of the data changes.
            self.delete(key=key)
        result = func(*args, **kwargs)
        try:
            self.put(key=key, data=self.encode(data=result))
        except CacheBackendEncodeError as e:
            logging.warning("Failed to encode cache data: %s", e)
            # If encoding fails, we should still return the result.
        return result
_P = ParamSpec("_P")
_R = TypeVar("_R")
_CacheBackendT = TypeVar("_CacheBackendT", bound=CacheBackend)
class _PersistentCache(Generic[_P, _R, _CacheBackendT]):
    """
    A decorator class that provides persistent caching functionality for a function.
    Args:
    ----
        func (Callable[_P, _R]): The function to be decorated.
        duration (datetime.timedelta): The duration for which the cached results should be considered valid.
        backend (_backend): The backend storage for the cached results.
    Attributes:
    ----------
        __wrapped__ (Callable[_P, _R]): The wrapped function.
        __duration__ (datetime.timedelta): The duration for which the cached results should be considered valid.
        __backend__ (_backend): The backend storage for the cached results.
    """  # noqa: E501
    __wrapped__: Callable[_P, _R]
    __duration__: datetime.timedelta
    __backend__: _CacheBackendT
    def __init__(
        self,
        func: Callable[_P, _R],
        duration: datetime.timedelta,
    ) -> None:
        self.__wrapped__ = func
        self.__duration__ = duration
        self.__backend__ = AbstractCacheBackend()
        functools.update_wrapper(self, func)
    def cache_clear(self) -> None:
        """Clears the cache for the wrapped function."""
        self.__backend__.del_func_cache(func=self.__wrapped__)
    def no_cache_call(self, *args: _P.args, **kwargs: _P.kwargs) -> _R:
        """
        Calls the wrapped function without using the cache.
        Args:
        ----
            *args (_P.args): Positional arguments for the wrapped function.
            **kwargs (_P.kwargs): Keyword arguments for the wrapped function.
        Returns:
        -------
            _R: The result of the wrapped function.
        """
        return self.__wrapped__(*args, **kwargs)
    def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R:
        """
        Calls the wrapped function, either using the cache or bypassing it based on environment variables.
        Args:
        ----
            *args (_P.args): Positional arguments for the wrapped function.
            **kwargs (_P.kwargs): Keyword arguments for the wrapped function.
        Returns:
        -------
            _R: The result of the wrapped function.
        """  # noqa: E501
        if "NO_CACHE" in os.environ:
            return self.__wrapped__(*args, **kwargs)
        os.makedirs(DEFAULT_CACHE_LOCATION, exist_ok=True)
        return self.__backend__.get_cache_or_call(
            func=self.__wrapped__,
            args=args,
            kwargs=kwargs,
            lifespan=self.__duration__,
        )
'''
    # Everything stays inside the context manager: NamedTemporaryFile deletes
    # the fixture on exit, and extraction reads it from disk.
    with tempfile.NamedTemporaryFile(mode="w", suffix=".py") as f:
        f.write(code)
        f.flush()
        file_path = Path(f.name).resolve()
        opt = Optimizer(
            Namespace(
                project_root=file_path.parent.resolve(),
                disable_telemetry=True,
                tests_root="tests",
                test_framework="pytest",
                pytest_cmd="pytest",
                experiment_id=None,
                test_project_root=Path().resolve(),
            )
        )
        function_to_optimize = FunctionToOptimize(
            function_name="__call__",
            file_path=file_path,
            parents=[FunctionParent(name="_PersistentCache", type="ClassDef")],
            starting_line=None,
            ending_line=None,
        )
        code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
        read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
        hashing_context = code_ctx.hashing_code_context
        # Read-writable: target + called helper + both __init__s; method
        # docstrings are retained here.
        expected_read_write_context = """
class AbstractCacheBackend(CacheBackend, Protocol[_KEY_T, _STORE_T]):
    def __init__(self) -> None: ...
    def get_cache_or_call(
        self,
        *,
        func: Callable[_P, Any],
        args: tuple[Any, ...],
        kwargs: dict[str, Any],
        lifespan: datetime.timedelta,
    ) -> Any:  # noqa: ANN401
        \"\"\"
        Retrieve the cached results for a function call.
        Args:
        ----
            func (Callable[..., _R]): The function to retrieve cached results for.
            args (tuple[Any, ...]): The positional arguments passed to the function.
            kwargs (dict[str, Any]): The keyword arguments passed to the function.
            lifespan (datetime.timedelta): The maximum age of the cached results.
        Returns:
        -------
            _R: The cached results, if available.
        \"\"\"
        if os.environ.get("NO_CACHE"):
            return func(*args, **kwargs)
        try:
            key = self.hash_key(func=func, args=args, kwargs=kwargs)
        except:  # noqa: E722
            # If we can't create a cache key, we should just call the function.
            logging.warning("Failed to hash cache key for function: %s", func)
            return func(*args, **kwargs)
        result_pair = self.get(key=key)
        if result_pair is not None:
            cached_time, result = result_pair
            if not os.environ.get("RE_CACHE") and (
                datetime.datetime.now() < (cached_time + lifespan)  # noqa: DTZ005
            ):
                try:
                    return self.decode(data=result)
                except CacheBackendDecodeError as e:
                    logging.warning("Failed to decode cache data: %s", e)
                    # If decoding fails we will treat this as a cache miss.
                    # This might happens if underlying class definition of the data changes.
            self.delete(key=key)
        result = func(*args, **kwargs)
        try:
            self.put(key=key, data=self.encode(data=result))
        except CacheBackendEncodeError as e:
            logging.warning("Failed to encode cache data: %s", e)
            # If encoding fails, we should still return the result.
        return result
class _PersistentCache(Generic[_P, _R, _CacheBackendT]):
    def __init__(
        self,
        func: Callable[_P, _R],
        duration: datetime.timedelta,
    ) -> None:
        self.__wrapped__ = func
        self.__duration__ = duration
        self.__backend__ = AbstractCacheBackend()
        functools.update_wrapper(self, func)
    def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R:
        \"\"\"
        Calls the wrapped function, either using the cache or bypassing it based on environment variables.
        Args:
        ----
            *args (_P.args): Positional arguments for the wrapped function.
            **kwargs (_P.kwargs): Keyword arguments for the wrapped function.
        Returns:
        -------
            _R: The result of the wrapped function.
        \"\"\"  # noqa: E501
        if "NO_CACHE" in os.environ:
            return self.__wrapped__(*args, **kwargs)
        os.makedirs(DEFAULT_CACHE_LOCATION, exist_ok=True)
        return self.__backend__.get_cache_or_call(
            func=self.__wrapped__,
            args=args,
            kwargs=kwargs,
            lifespan=self.__duration__,
        )
"""
        # Read-only: module TypeVars, class docstrings, unreferenced method
        # stubs and class-level annotations.
        expected_read_only_context = f'''
```python:{file_path.relative_to(opt.args.project_root)}
_P = ParamSpec("_P")
_KEY_T = TypeVar("_KEY_T")
_STORE_T = TypeVar("_STORE_T")
class AbstractCacheBackend(CacheBackend, Protocol[_KEY_T, _STORE_T]):
    """Interface for cache backends used by the persistent cache decorator."""
    def hash_key(
        self,
        *,
        func: Callable[_P, Any],
        args: tuple[Any, ...],
        kwargs: dict[str, Any],
    ) -> tuple[str, _KEY_T]: ...
    def encode(self, *, data: Any) -> _STORE_T:  # noqa: ANN401
        ...
    def decode(self, *, data: _STORE_T) -> Any:  # noqa: ANN401
        ...
    def get(self, *, key: tuple[str, _KEY_T]) -> tuple[datetime.datetime, _STORE_T] | None: ...
    def delete(self, *, key: tuple[str, _KEY_T]) -> None: ...
    def put(self, *, key: tuple[str, _KEY_T], data: _STORE_T) -> None: ...
_P = ParamSpec("_P")
_R = TypeVar("_R")
_CacheBackendT = TypeVar("_CacheBackendT", bound=CacheBackend)
class _PersistentCache(Generic[_P, _R, _CacheBackendT]):
    """
    A decorator class that provides persistent caching functionality for a function.
    Args:
    ----
        func (Callable[_P, _R]): The function to be decorated.
        duration (datetime.timedelta): The duration for which the cached results should be considered valid.
        backend (_backend): The backend storage for the cached results.
    Attributes:
    ----------
        __wrapped__ (Callable[_P, _R]): The wrapped function.
        __duration__ (datetime.timedelta): The duration for which the cached results should be considered valid.
        __backend__ (_backend): The backend storage for the cached results.
    """  # noqa: E501
    __wrapped__: Callable[_P, _R]
    __duration__: datetime.timedelta
    __backend__: _CacheBackendT
```
'''
        # Hashing context: read-writable code with docstrings stripped.
        expected_hashing_context = f"""
```python:{file_path.relative_to(opt.args.project_root)}
class AbstractCacheBackend(CacheBackend, Protocol[_KEY_T, _STORE_T]):
    def get_cache_or_call(
        self,
        *,
        func: Callable[_P, Any],
        args: tuple[Any, ...],
        kwargs: dict[str, Any],
        lifespan: datetime.timedelta,
    ) -> Any:  # noqa: ANN401
        if os.environ.get("NO_CACHE"):
            return func(*args, **kwargs)
        try:
            key = self.hash_key(func=func, args=args, kwargs=kwargs)
        except:  # noqa: E722
            # If we can't create a cache key, we should just call the function.
            logging.warning("Failed to hash cache key for function: %s", func)
            return func(*args, **kwargs)
        result_pair = self.get(key=key)
        if result_pair is not None:
            cached_time, result = result_pair
            if not os.environ.get("RE_CACHE") and (
                datetime.datetime.now() < (cached_time + lifespan)  # noqa: DTZ005
            ):
                try:
                    return self.decode(data=result)
                except CacheBackendDecodeError as e:
                    logging.warning("Failed to decode cache data: %s", e)
                    # If decoding fails we will treat this as a cache miss.
                    # This might happens if underlying class definition of the data changes.
            self.delete(key=key)
        result = func(*args, **kwargs)
        try:
            self.put(key=key, data=self.encode(data=result))
        except CacheBackendEncodeError as e:
            logging.warning("Failed to encode cache data: %s", e)
            # If encoding fails, we should still return the result.
        return result
class _PersistentCache(Generic[_P, _R, _CacheBackendT]):
    def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R:
        if "NO_CACHE" in os.environ:
            return self.__wrapped__(*args, **kwargs)
        os.makedirs(DEFAULT_CACHE_LOCATION, exist_ok=True)
        return self.__backend__.get_cache_or_call(
            func=self.__wrapped__,
            args=args,
            kwargs=kwargs,
            lifespan=self.__duration__,
        )
```
"""
        assert read_write_context.strip() == expected_read_write_context.strip()
        assert read_only_context.strip() == expected_read_only_context.strip()
        assert hashing_context.strip() == expected_hashing_context.strip()
def test_example_class() -> None:
    """Class docstrings and unused dunders go to the read-only context; the
    target method, both ``__init__`` methods and the used helper stay
    read-writable (method docstrings retained); hashing strips docstrings.
    """
    # Fixture: MyClass.target_method is the target, HelperClass.helper_method
    # its only real dependency; __repr__ is deliberately unused.
    code = """
class MyClass:
    \"\"\"A class with a helper method.\"\"\"
    def __init__(self):
        self.x = 1
    def target_method(self):
        y = HelperClass().helper_method()
class HelperClass:
    \"\"\"A helper class for MyClass.\"\"\"
    def __init__(self):
        \"\"\"Initialize the HelperClass.\"\"\"
        self.x = 1
    def __repr__(self):
        \"\"\"Return a string representation of the HelperClass.\"\"\"
        return "HelperClass" + str(self.x)
    def helper_method(self):
        return self.x
"""
    with tempfile.NamedTemporaryFile(mode="w") as f:
        f.write(code)
        f.flush()
        file_path = Path(f.name).resolve()
        opt = Optimizer(
            Namespace(
                project_root=file_path.parent.resolve(),
                disable_telemetry=True,
                tests_root="tests",
                test_framework="pytest",
                pytest_cmd="pytest",
                experiment_id=None,
                test_project_root=Path().resolve(),
            )
        )
        function_to_optimize = FunctionToOptimize(
            function_name="target_method",
            file_path=file_path,
            parents=[FunctionParent(name="MyClass", type="ClassDef")],
            starting_line=None,
            ending_line=None,
        )
        code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
        read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
        hashing_context = code_ctx.hashing_code_context
        expected_read_write_context = """
class MyClass:
    def __init__(self):
        self.x = 1
    def target_method(self):
        y = HelperClass().helper_method()
class HelperClass:
    def __init__(self):
        \"\"\"Initialize the HelperClass.\"\"\"
        self.x = 1
    def helper_method(self):
        return self.x
"""
        expected_read_only_context = f"""
```python:{file_path.relative_to(opt.args.project_root)}
class MyClass:
    \"\"\"A class with a helper method.\"\"\"
class HelperClass:
    \"\"\"A helper class for MyClass.\"\"\"
    def __repr__(self):
        \"\"\"Return a string representation of the HelperClass.\"\"\"
        return "HelperClass" + str(self.x)
```
"""
        expected_hashing_context = f"""
```python:{file_path.relative_to(opt.args.project_root)}
class MyClass:
    def target_method(self):
        y = HelperClass().helper_method()
class HelperClass:
    def helper_method(self):
        return self.x
```
"""
        assert read_write_context.strip() == expected_read_write_context.strip()
        assert read_only_context.strip() == expected_read_only_context.strip()
        assert hashing_context.strip() == expected_hashing_context.strip()
def test_example_class_token_limit_1() -> None:
    """An oversized class docstring pushes the read-only context over the token
    limit, so read-only docstrings are removed (read-writable code untouched).
    """
    # ~1000 repeated sentences: enough to blow the read-only token budget.
    docstring_filler = " ".join(
        ["This is a long docstring that will be used to fill up the token limit." for _ in range(1000)]
    )
    code = f"""
class MyClass:
    \"\"\"A class with a helper method.
    {docstring_filler}\"\"\"
    def __init__(self):
        self.x = 1
    def target_method(self):
        \"\"\"Docstring for target method\"\"\"
        y = HelperClass().helper_method()
class HelperClass:
    \"\"\"A helper class for MyClass.\"\"\"
    def __init__(self):
        \"\"\"Initialize the HelperClass.\"\"\"
        self.x = 1
    def __repr__(self):
        \"\"\"Return a string representation of the HelperClass.\"\"\"
        return "HelperClass" + str(self.x)
    def helper_method(self):
        return self.x
"""
    with tempfile.NamedTemporaryFile(mode="w") as f:
        f.write(code)
        f.flush()
        file_path = Path(f.name).resolve()
        opt = Optimizer(
            Namespace(
                project_root=file_path.parent.resolve(),
                disable_telemetry=True,
                tests_root="tests",
                test_framework="pytest",
                pytest_cmd="pytest",
                experiment_id=None,
                test_project_root=Path().resolve(),
            )
        )
        function_to_optimize = FunctionToOptimize(
            function_name="target_method",
            file_path=file_path,
            parents=[FunctionParent(name="MyClass", type="ClassDef")],
            starting_line=None,
            ending_line=None,
        )
        code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
        read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
        hashing_context = code_ctx.hashing_code_context
        # In this scenario, the read-only code context is too long, so the read-only docstrings are removed.
        expected_read_write_context = """
class MyClass:
    def __init__(self):
        self.x = 1
    def target_method(self):
        \"\"\"Docstring for target method\"\"\"
        y = HelperClass().helper_method()
class HelperClass:
    def __init__(self):
        \"\"\"Initialize the HelperClass.\"\"\"
        self.x = 1
    def helper_method(self):
        return self.x
"""
        # MyClass keeps only a `pass` body once its (huge) docstring is dropped.
        expected_read_only_context = f"""
```python:{file_path.relative_to(opt.args.project_root)}
class MyClass:
    pass
class HelperClass:
    def __repr__(self):
        return "HelperClass" + str(self.x)
```
"""
        expected_hashing_context = f"""
```python:{file_path.relative_to(opt.args.project_root)}
class MyClass:
    def target_method(self):
        y = HelperClass().helper_method()
class HelperClass:
    def helper_method(self):
        return self.x
```
"""
        assert read_write_context.strip() == expected_read_write_context.strip()
        assert read_only_context.strip() == expected_read_only_context.strip()
        assert hashing_context.strip() == expected_hashing_context.strip()
def test_example_class_token_limit_2() -> None:
    """A huge module-level string global makes the read-only context exceed the
    limit even after docstring removal, so the read-only context is dropped
    entirely; the read-writable context is unaffected.
    """
    string_filler = " ".join(
        ["This is a long string that will be used to fill up the token limit." for _ in range(1000)]
    )
    # NOTE(review): `x = '{string_filler}'` is a module-level global (not part of
    # target_method) — it therefore lands in the read-only context, which is what
    # this scenario is exercising; TODO confirm placement against the extractor.
    code = f"""
class MyClass:
    \"\"\"A class with a helper method. \"\"\"
    def __init__(self):
        self.x = 1
    def target_method(self):
        \"\"\"Docstring for target method\"\"\"
        y = HelperClass().helper_method()
x = '{string_filler}'
class HelperClass:
    \"\"\"A helper class for MyClass.\"\"\"
    def __init__(self):
        \"\"\"Initialize the HelperClass.\"\"\"
        self.x = 1
    def __repr__(self):
        \"\"\"Return a string representation of the HelperClass.\"\"\"
        return "HelperClass" + str(self.x)
    def helper_method(self):
        return self.x
"""
    with tempfile.NamedTemporaryFile(mode="w") as f:
        f.write(code)
        f.flush()
        file_path = Path(f.name).resolve()
        opt = Optimizer(
            Namespace(
                project_root=file_path.parent.resolve(),
                disable_telemetry=True,
                tests_root="tests",
                test_framework="pytest",
                pytest_cmd="pytest",
                experiment_id=None,
                test_project_root=Path().resolve(),
            )
        )
        function_to_optimize = FunctionToOptimize(
            function_name="target_method",
            file_path=file_path,
            parents=[FunctionParent(name="MyClass", type="ClassDef")],
            starting_line=None,
            ending_line=None,
        )
        # Tight optimization-context limit (8000) with a generous testgen limit.
        code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root, 8000, 100000)
        read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
        hashing_context = code_ctx.hashing_code_context
        # In this scenario, the read-only code context is too long even after removing docstrings, hence we remove it completely.
        expected_read_write_context = """
class MyClass:
    def __init__(self):
        self.x = 1
    def target_method(self):
        \"\"\"Docstring for target method\"\"\"
        y = HelperClass().helper_method()
class HelperClass:
    def __init__(self):
        \"\"\"Initialize the HelperClass.\"\"\"
        self.x = 1
    def helper_method(self):
        return self.x
"""
        expected_read_only_context = ""
        expected_hashing_context = f"""
```python:{file_path.relative_to(opt.args.project_root)}
class MyClass:
    def target_method(self):
        y = HelperClass().helper_method()
class HelperClass:
    def helper_method(self):
        return self.x
```
"""
        assert read_write_context.strip() == expected_read_write_context.strip()
        assert read_only_context.strip() == expected_read_only_context.strip()
        assert hashing_context.strip() == expected_hashing_context.strip()
def test_example_class_token_limit_3() -> None:
    """A huge docstring on the target method itself blows the read-writable
    token budget; extraction must raise rather than silently truncate.
    """
    filler = " ".join(
        ["This is a long string that will be used to fill up the token limit." for _ in range(1000)]
    )
    code = f"""
class MyClass:
    \"\"\"A class with a helper method. \"\"\"
    def __init__(self):
        self.x = 1
    def target_method(self):
        \"\"\"{filler}\"\"\"
        y = HelperClass().helper_method()
class HelperClass:
    \"\"\"A helper class for MyClass.\"\"\"
    def __init__(self):
        \"\"\"Initialize the HelperClass.\"\"\"
        self.x = 1
    def __repr__(self):
        \"\"\"Return a string representation of the HelperClass.\"\"\"
        return "HelperClass" + str(self.x)
    def helper_method(self):
        return self.x
"""
    with tempfile.NamedTemporaryFile(mode="w") as fixture_file:
        fixture_file.write(code)
        fixture_file.flush()
        fixture_path = Path(fixture_file.name).resolve()
        optimizer = Optimizer(
            Namespace(
                project_root=fixture_path.parent.resolve(),
                disable_telemetry=True,
                tests_root="tests",
                test_framework="pytest",
                pytest_cmd="pytest",
                experiment_id=None,
                test_project_root=Path().resolve(),
            )
        )
        fto = FunctionToOptimize(
            function_name="target_method",
            file_path=fixture_path,
            parents=[FunctionParent(name="MyClass", type="ClassDef")],
            starting_line=None,
            ending_line=None,
        )
        # In this scenario, the read-writable code is too long, so we abort.
        with pytest.raises(ValueError, match="Read-writable code has exceeded token limit, cannot proceed"):
            get_code_optimization_context(fto, optimizer.args.project_root)
def test_example_class_token_limit_4() -> None:
    """With default limits, a huge module-level string global makes the testgen
    code context (which includes module globals) exceed its token limit, so
    extraction must abort.
    """
    string_filler = " ".join(
        ["This is a long string that will be used to fill up the token limit." for _ in range(1000)]
    )
    # NOTE(review): as in token_limit_2, `x = '{string_filler}'` is a
    # module-level global — TODO confirm placement against the extractor.
    code = f"""
class MyClass:
    \"\"\"A class with a helper method. \"\"\"
    def __init__(self):
        self.x = 1
    def target_method(self):
        \"\"\"Docstring for target method\"\"\"
        y = HelperClass().helper_method()
x = '{string_filler}'
class HelperClass:
    \"\"\"A helper class for MyClass.\"\"\"
    def __init__(self):
        \"\"\"Initialize the HelperClass.\"\"\"
        self.x = 1
    def __repr__(self):
        \"\"\"Return a string representation of the HelperClass.\"\"\"
        return "HelperClass" + str(self.x)
    def helper_method(self):
        return self.x
"""
    with tempfile.NamedTemporaryFile(mode="w") as f:
        f.write(code)
        f.flush()
        file_path = Path(f.name).resolve()
        opt = Optimizer(
            Namespace(
                project_root=file_path.parent.resolve(),
                disable_telemetry=True,
                tests_root="tests",
                test_framework="pytest",
                pytest_cmd="pytest",
                experiment_id=None,
                test_project_root=Path().resolve(),
            )
        )
        function_to_optimize = FunctionToOptimize(
            function_name="target_method",
            file_path=file_path,
            parents=[FunctionParent(name="MyClass", type="ClassDef")],
            starting_line=None,
            ending_line=None,
        )
        # In this scenario, the testgen code context is too long, so we abort.
        with pytest.raises(ValueError, match="Testgen code context has exceeded token limit, cannot proceed"):
            code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
def test_repo_helper() -> None:
    """Cross-file helper extraction against the on-disk `retriever` fixture repo.

    The DataProcessor methods that fetch_and_process_data actually calls (plus
    __init__) are inlined into the read-writable context; the unused class
    attribute and __repr__ stay read-only in utils.py's context block.
    """
    project_root = Path(__file__).resolve().parent.parent / "code_to_optimize" / "code_directories" / "retriever"
    path_to_file = project_root / "main.py"
    path_to_utils = project_root / "utils.py"
    function_to_optimize = FunctionToOptimize(
        function_name="fetch_and_process_data",
        file_path=str(path_to_file),
        parents=[],
        starting_line=None,
        ending_line=None,
    )
    code_ctx = get_code_optimization_context(function_to_optimize, project_root)
    read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
    hashing_context = code_ctx.hashing_code_context
    expected_read_write_context = """
import math
import requests
from globals import API_URL
from utils import DataProcessor
class DataProcessor:
    def __init__(self, default_prefix: str = "PREFIX_"):
        \"\"\"Initialize the DataProcessor with a default prefix.\"\"\"
        self.default_prefix = default_prefix
        self.number += math.log(self.number)
    def process_data(self, raw_data: str) -> str:
        \"\"\"Process raw data by converting it to uppercase.\"\"\"
        return raw_data.upper()
    def add_prefix(self, data: str, prefix: str = "PREFIX_") -> str:
        \"\"\"Add a prefix to the processed data.\"\"\"
        return prefix + data
def fetch_and_process_data():
    # Use the global variable for the request
    response = requests.get(API_URL)
    response.raise_for_status()
    raw_data = response.text
    # Use code from another file (utils.py)
    processor = DataProcessor()
    processed = processor.process_data(raw_data)
    processed = processor.add_prefix(processed)
    return processed
"""
    # {{...}} below renders literal braces inside this f-string.
    expected_read_only_context = f"""
```python:{path_to_utils.relative_to(project_root)}
class DataProcessor:
    \"\"\"A class for processing data.\"\"\"
    number = 1
    def __repr__(self) -> str:
        \"\"\"Return a string representation of the DataProcessor.\"\"\"
        return f"DataProcessor(default_prefix={{self.default_prefix!r}})"
```
"""
    expected_hashing_context = f"""
```python:{path_to_utils.relative_to(project_root)}
class DataProcessor:
    def process_data(self, raw_data: str) -> str:
        \"\"\"Process raw data by converting it to uppercase.\"\"\"
        return raw_data.upper()
    def add_prefix(self, data: str, prefix: str = "PREFIX_") -> str:
        \"\"\"Add a prefix to the processed data.\"\"\"
        return prefix + data
```
```python:{path_to_file.relative_to(project_root)}
def fetch_and_process_data():
    # Use the global variable for the request
    response = requests.get(API_URL)
    response.raise_for_status()
    raw_data = response.text
    # Use code from another file (utils.py)
    processor = DataProcessor()
    processed = processor.process_data(raw_data)
    processed = processor.add_prefix(processed)
    return processed
```
"""
    assert read_write_context.strip() == expected_read_write_context.strip()
    assert read_only_context.strip() == expected_read_only_context.strip()
    assert hashing_context.strip() == expected_hashing_context.strip()
def test_repo_helper_of_helper() -> None:
    """Extraction must follow a helper (DataProcessor.transform_data) into its own helper in another file."""
    project_root = Path(__file__).resolve().parent.parent / "code_to_optimize" / "code_directories" / "retriever"
    path_to_file = project_root / "main.py"
    path_to_utils = project_root / "utils.py"
    path_to_transform_utils = project_root / "transform_utils.py"
    function_to_optimize = FunctionToOptimize(
        function_name="fetch_and_transform_data",
        file_path=str(path_to_file),
        parents=[],
        starting_line=None,
        ending_line=None,
    )
    code_ctx = get_code_optimization_context(function_to_optimize, project_root)
    read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
    hashing_context = code_ctx.hashing_code_context
    expected_read_write_context = """
import math
from transform_utils import DataTransformer
import requests
from globals import API_URL
from utils import DataProcessor
class DataProcessor:
    def __init__(self, default_prefix: str = "PREFIX_"):
        \"\"\"Initialize the DataProcessor with a default prefix.\"\"\"
        self.default_prefix = default_prefix
        self.number += math.log(self.number)
    def process_data(self, raw_data: str) -> str:
        \"\"\"Process raw data by converting it to uppercase.\"\"\"
        return raw_data.upper()
    def transform_data(self, data: str) -> str:
        \"\"\"Transform the processed data\"\"\"
        return DataTransformer().transform(data)
def fetch_and_transform_data():
    # Use the global variable for the request
    response = requests.get(API_URL)
    raw_data = response.text
    # Use code from another file (utils.py)
    processor = DataProcessor()
    processed = processor.process_data(raw_data)
    transformed = processor.transform_data(processed)
    return transformed
"""
    expected_read_only_context = f"""
```python:{path_to_utils.relative_to(project_root)}
class DataProcessor:
    \"\"\"A class for processing data.\"\"\"
    number = 1
    def __repr__(self) -> str:
        \"\"\"Return a string representation of the DataProcessor.\"\"\"
        return f"DataProcessor(default_prefix={{self.default_prefix!r}})"
```
```python:{path_to_transform_utils.relative_to(project_root)}
class DataTransformer:
    def transform(self, data):
        self.data = data
        return self.data
```
"""
    expected_hashing_context = f"""
```python:{path_to_utils.relative_to(project_root)}
class DataProcessor:
    def process_data(self, raw_data: str) -> str:
        \"\"\"Process raw data by converting it to uppercase.\"\"\"
        return raw_data.upper()
    def transform_data(self, data: str) -> str:
        \"\"\"Transform the processed data\"\"\"
        return DataTransformer().transform(data)
```
```python:{path_to_file.relative_to(project_root)}
def fetch_and_transform_data():
    # Use the global variable for the request
    response = requests.get(API_URL)
    raw_data = response.text
    # Use code from another file (utils.py)
    processor = DataProcessor()
    processed = processor.process_data(raw_data)
    transformed = processor.transform_data(processed)
    return transformed
```
"""
    assert read_write_context.strip() == expected_read_write_context.strip()
    assert read_only_context.strip() == expected_read_only_context.strip()
    assert hashing_context.strip() == expected_hashing_context.strip()
def test_repo_helper_of_helper_same_class() -> None:
    """Helper-of-helper where the second-level helper is another method on the same class (DataTransformer)."""
    project_root = Path(__file__).resolve().parent.parent / "code_to_optimize" / "code_directories" / "retriever"
    path_to_utils = project_root / "utils.py"
    path_to_transform_utils = project_root / "transform_utils.py"
    function_to_optimize = FunctionToOptimize(
        function_name="transform_data_own_method",
        file_path=str(path_to_utils),
        parents=[FunctionParent(name="DataProcessor", type="ClassDef")],
        starting_line=None,
        ending_line=None,
    )
    code_ctx = get_code_optimization_context(function_to_optimize, project_root)
    read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
    hashing_context = code_ctx.hashing_code_context
    expected_read_write_context = """
import math
from transform_utils import DataTransformer
class DataTransformer:
    def __init__(self):
        self.data = None
    def transform_using_own_method(self, data):
        return self.transform(data)
class DataProcessor:
    def __init__(self, default_prefix: str = "PREFIX_"):
        \"\"\"Initialize the DataProcessor with a default prefix.\"\"\"
        self.default_prefix = default_prefix
        self.number += math.log(self.number)
    def transform_data_own_method(self, data: str) -> str:
        \"\"\"Transform the processed data using own method\"\"\"
        return DataTransformer().transform_using_own_method(data)
"""
    expected_read_only_context = f"""
```python:{path_to_transform_utils.relative_to(project_root)}
class DataTransformer:
    def transform(self, data):
        self.data = data
        return self.data
```
```python:{path_to_utils.relative_to(project_root)}
class DataProcessor:
    \"\"\"A class for processing data.\"\"\"
    number = 1
    def __repr__(self) -> str:
        \"\"\"Return a string representation of the DataProcessor.\"\"\"
        return f"DataProcessor(default_prefix={{self.default_prefix!r}})"
```
"""
    expected_hashing_context = f"""
```python:{path_to_utils.relative_to(project_root)}
class DataProcessor:
    def transform_data_own_method(self, data: str) -> str:
        \"\"\"Transform the processed data using own method\"\"\"
        return DataTransformer().transform_using_own_method(data)
```
"""
    assert read_write_context.strip() == expected_read_write_context.strip()
    assert read_only_context.strip() == expected_read_only_context.strip()
    assert hashing_context.strip() == expected_hashing_context.strip()
def test_repo_helper_of_helper_same_file() -> None:
    """Helper-of-helper where the second-level helper is a module-level function in the helper's own file."""
    project_root = Path(__file__).resolve().parent.parent / "code_to_optimize" / "code_directories" / "retriever"
    path_to_utils = project_root / "utils.py"
    path_to_transform_utils = project_root / "transform_utils.py"
    function_to_optimize = FunctionToOptimize(
        function_name="transform_data_same_file_function",
        file_path=str(path_to_utils),
        parents=[FunctionParent(name="DataProcessor", type="ClassDef")],
        starting_line=None,
        ending_line=None,
    )
    code_ctx = get_code_optimization_context(function_to_optimize, project_root)
    read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
    hashing_context = code_ctx.hashing_code_context
    expected_read_write_context = """
import math
from transform_utils import DataTransformer
class DataTransformer:
    def __init__(self):
        self.data = None
    def transform_using_same_file_function(self, data):
        return update_data(data)
class DataProcessor:
    def __init__(self, default_prefix: str = "PREFIX_"):
        \"\"\"Initialize the DataProcessor with a default prefix.\"\"\"
        self.default_prefix = default_prefix
        self.number += math.log(self.number)
    def transform_data_same_file_function(self, data: str) -> str:
        \"\"\"Transform the processed data using a function from the same file\"\"\"
        return DataTransformer().transform_using_same_file_function(data)
"""
    expected_read_only_context = f"""
```python:{path_to_transform_utils.relative_to(project_root)}
def update_data(data):
    return data + " updated"
```
```python:{path_to_utils.relative_to(project_root)}
class DataProcessor:
    \"\"\"A class for processing data.\"\"\"
    number = 1
    def __repr__(self) -> str:
        \"\"\"Return a string representation of the DataProcessor.\"\"\"
        return f"DataProcessor(default_prefix={{self.default_prefix!r}})"
```
"""
    expected_hashing_context = f"""
```python:{path_to_utils.relative_to(project_root)}
class DataProcessor:
    def transform_data_same_file_function(self, data: str) -> str:
        \"\"\"Transform the processed data using a function from the same file\"\"\"
        return DataTransformer().transform_using_same_file_function(data)
```
"""
    assert read_write_context.strip() == expected_read_write_context.strip()
    assert read_only_context.strip() == expected_read_only_context.strip()
    assert hashing_context.strip() == expected_hashing_context.strip()
def test_repo_helper_all_same_file() -> None:
    """All helpers (own method + module-level function) live in the same file as the target method."""
    project_root = Path(__file__).resolve().parent.parent / "code_to_optimize" / "code_directories" / "retriever"
    path_to_transform_utils = project_root / "transform_utils.py"
    function_to_optimize = FunctionToOptimize(
        function_name="transform_data_all_same_file",
        file_path=str(path_to_transform_utils),
        parents=[FunctionParent(name="DataTransformer", type="ClassDef")],
        starting_line=None,
        ending_line=None,
    )
    code_ctx = get_code_optimization_context(function_to_optimize, project_root)
    read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
    hashing_context = code_ctx.hashing_code_context
    expected_read_write_context = """
class DataTransformer:
    def __init__(self):
        self.data = None
    def transform_using_own_method(self, data):
        return self.transform(data)
    def transform_data_all_same_file(self, data):
        new_data = update_data(data)
        return self.transform_using_own_method(new_data)
def update_data(data):
    return data + " updated"
"""
    expected_read_only_context = f"""
```python:{path_to_transform_utils.relative_to(project_root)}
class DataTransformer:
    def transform(self, data):
        self.data = data
        return self.data
```
"""
    expected_hashing_context = f"""
```python:{path_to_transform_utils.relative_to(project_root)}
class DataTransformer:
    def transform_using_own_method(self, data):
        return self.transform(data)
    def transform_data_all_same_file(self, data):
        new_data = update_data(data)
        return self.transform_using_own_method(new_data)
def update_data(data):
    return data + " updated"
```
"""
    assert read_write_context.strip() == expected_read_write_context.strip()
    assert read_only_context.strip() == expected_read_only_context.strip()
    assert hashing_context.strip() == expected_hashing_context.strip()
def test_repo_helper_circular_dependency() -> None:
    """Mutually recursive helpers across two files must not loop the extractor forever."""
    project_root = Path(__file__).resolve().parent.parent / "code_to_optimize" / "code_directories" / "retriever"
    path_to_utils = project_root / "utils.py"
    path_to_transform_utils = project_root / "transform_utils.py"
    function_to_optimize = FunctionToOptimize(
        function_name="circular_dependency",
        file_path=str(path_to_transform_utils),
        parents=[FunctionParent(name="DataTransformer", type="ClassDef")],
        starting_line=None,
        ending_line=None,
    )
    code_ctx = get_code_optimization_context(function_to_optimize, project_root)
    read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
    hashing_context = code_ctx.hashing_code_context
    expected_read_write_context = """
import math
from transform_utils import DataTransformer
from code_to_optimize.code_directories.retriever.utils import DataProcessor
class DataProcessor:
    def __init__(self, default_prefix: str = "PREFIX_"):
        \"\"\"Initialize the DataProcessor with a default prefix.\"\"\"
        self.default_prefix = default_prefix
        self.number += math.log(self.number)
    def circular_dependency(self, data: str) -> str:
        \"\"\"Test circular dependency\"\"\"
        return DataTransformer().circular_dependency(data)
class DataTransformer:
    def __init__(self):
        self.data = None
    def circular_dependency(self, data):
        return DataProcessor().circular_dependency(data)
"""
    expected_read_only_context = f"""
```python:{path_to_utils.relative_to(project_root)}
class DataProcessor:
    \"\"\"A class for processing data.\"\"\"
    number = 1
    def __repr__(self) -> str:
        \"\"\"Return a string representation of the DataProcessor.\"\"\"
        return f"DataProcessor(default_prefix={{self.default_prefix!r}})"
```
"""
    expected_hashing_context = f"""
```python:utils.py
class DataProcessor:
    def circular_dependency(self, data: str) -> str:
        return DataTransformer().circular_dependency(data)
```
```python:{path_to_transform_utils.relative_to(project_root)}
class DataTransformer:
    def circular_dependency(self, data):
        return DataProcessor().circular_dependency(data)
```
"""
    assert read_write_context.strip() == expected_read_write_context.strip()
    assert read_only_context.strip() == expected_read_only_context.strip()
    assert hashing_context.strip() == expected_hashing_context.strip()
def test_indirect_init_helper() -> None:
    """A function reached only via __init__ (outside_method) belongs in the read-only context, not read-write."""
    code = """
class MyClass:
    def __init__(self):
        self.x = 1
        self.y = outside_method()

    def target_method(self):
        return self.x + self.y

def outside_method():
    return 1
"""
    with tempfile.NamedTemporaryFile(mode="w", suffix=".py") as f:
        f.write(code)
        f.flush()
        file_path = Path(f.name).resolve()
        opt = Optimizer(
            Namespace(
                project_root=file_path.parent.resolve(),
                disable_telemetry=True,
                tests_root="tests",
                test_framework="pytest",
                pytest_cmd="pytest",
                experiment_id=None,
                test_project_root=Path().resolve(),
            )
        )
        function_to_optimize = FunctionToOptimize(
            function_name="target_method",
            file_path=file_path,
            parents=[FunctionParent(name="MyClass", type="ClassDef")],
            starting_line=None,
            ending_line=None,
        )
        code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
        read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
        hashing_context = code_ctx.hashing_code_context
        expected_read_write_context = """
class MyClass:
    def __init__(self):
        self.x = 1
        self.y = outside_method()

    def target_method(self):
        return self.x + self.y
"""
        expected_read_only_context = f"""
```python:{file_path.relative_to(opt.args.project_root)}
def outside_method():
    return 1
```
"""
        expected_hashing_context = f"""
```python:{file_path.relative_to(opt.args.project_root)}
class MyClass:
    def target_method(self):
        return self.x + self.y
```
"""
        assert read_write_context.strip() == expected_read_write_context.strip()
        assert read_only_context.strip() == expected_read_only_context.strip()
        assert hashing_context.strip() == expected_hashing_context.strip()
def test_direct_module_import() -> None:
    """A helper referenced through a direct module import (import pkg.module; pkg.module.func()) is extracted."""
    project_root = Path(__file__).resolve().parent.parent / "code_to_optimize" / "code_directories" / "retriever"
    path_to_main = project_root / "main.py"
    path_to_fto = project_root / "import_test.py"
    function_to_optimize = FunctionToOptimize(
        function_name="function_to_optimize",
        file_path=str(path_to_fto),
        parents=[],
        starting_line=None,
        ending_line=None,
    )
    code_ctx = get_code_optimization_context(function_to_optimize, project_root)
    read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
    hashing_context = code_ctx.hashing_code_context
    expected_read_only_context = """
```python:utils.py
from transform_utils import DataTransformer
class DataProcessor:
    \"\"\"A class for processing data.\"\"\"
    number = 1
    def __repr__(self) -> str:
        \"\"\"Return a string representation of the DataProcessor.\"\"\"
        return f"DataProcessor(default_prefix={self.default_prefix!r})"
    def process_data(self, raw_data: str) -> str:
        \"\"\"Process raw data by converting it to uppercase.\"\"\"
        return raw_data.upper()
    def transform_data(self, data: str) -> str:
        \"\"\"Transform the processed data\"\"\"
        return DataTransformer().transform(data)
```"""
    expected_hashing_context = """
```python:main.py
def fetch_and_transform_data():
    # Use the global variable for the request
    response = requests.get(API_URL)
    raw_data = response.text
    # Use code from another file (utils.py)
    processor = DataProcessor()
    processed = processor.process_data(raw_data)
    transformed = processor.transform_data(processed)
    return transformed
```
```python:import_test.py
def function_to_optimize():
    return code_to_optimize.code_directories.retriever.main.fetch_and_transform_data()
```
"""
    expected_read_write_context = """
import requests
from globals import API_URL
from utils import DataProcessor
import code_to_optimize.code_directories.retriever.main
def fetch_and_transform_data():
    # Use the global variable for the request
    response = requests.get(API_URL)
    raw_data = response.text
    # Use code from another file (utils.py)
    processor = DataProcessor()
    processed = processor.process_data(raw_data)
    transformed = processor.transform_data(processed)
    return transformed
def function_to_optimize():
    return code_to_optimize.code_directories.retriever.main.fetch_and_transform_data()
"""
    assert read_write_context.strip() == expected_read_write_context.strip()
    assert read_only_context.strip() == expected_read_only_context.strip()
    assert hashing_context.strip() == expected_hashing_context.strip()
2025-04-19 00:29:38 +00:00
def test_module_import_optimization() -> None:
    """Module-level conditionals in an imported module: only variables/functions actually used survive pruning."""
    main_code = """
import utility_module
class Calculator:
    def __init__(self, precision="high", fallback_precision=None, mode="standard"):
        # This is where we use the imported module
        self.precision = utility_module.select_precision(precision, fallback_precision)
        self.mode = mode
        # Using variables from the utility module
        self.backend = utility_module.CALCULATION_BACKEND
        self.system = utility_module.SYSTEM_TYPE
        self.default_precision = utility_module.DEFAULT_PRECISION
    def add(self, a, b):
        return a + b
    def subtract(self, a, b):
        return a - b
    def calculate(self, operation, x, y):
        if operation == "add":
            return self.add(x, y)
        elif operation == "subtract":
            return self.subtract(x, y)
        else:
            return None
"""

    utility_module_code = """
import sys
import platform
import logging
DEFAULT_PRECISION = "medium"
DEFAULT_MODE = "standard"
# Try-except block with variable definitions
try:
    import numpy as np
    # Used variable in try block
    CALCULATION_BACKEND = "numpy"
    # Unused variable in try block
    VECTOR_DIMENSIONS = 3
except ImportError:
    # Used variable in except block
    CALCULATION_BACKEND = "python"
    # Unused variable in except block
    FALLBACK_WARNING = "NumPy not available, using slower Python implementation"
# Nested if-else with variable definitions
if sys.platform.startswith('win'):
    # Used variable in outer if
    SYSTEM_TYPE = "windows"
    if platform.architecture()[0] == '64bit':
        # Unused variable in nested if
        MEMORY_MODEL = "x64"
    else:
        # Unused variable in nested else
        MEMORY_MODEL = "x86"
elif sys.platform.startswith('linux'):
    # Used variable in outer elif
    SYSTEM_TYPE = "linux"
    # Unused variable in outer elif
    KERNEL_VERSION = platform.release()
else:
    # Used variable in outer else
    SYSTEM_TYPE = "other"
    # Unused variable in outer else
    UNKNOWN_SYSTEM_MSG = "Running on an unrecognized platform"
# Function that will be used in the main code
def select_precision(precision, fallback_precision):
    if precision is None:
        return fallback_precision or DEFAULT_PRECISION
    # Using the variables defined above
    if CALCULATION_BACKEND == "numpy":
        # Higher precision available with NumPy
        precision_options = ["low", "medium", "high", "ultra"]
    else:
        # Limited precision without NumPy
        precision_options = ["low", "medium", "high"]
    if isinstance(precision, str):
        if precision.lower() not in precision_options:
            if fallback_precision:
                return fallback_precision
            else:
                return DEFAULT_PRECISION
        return precision.lower()
    else:
        return DEFAULT_PRECISION
# Function that won't be used
def get_system_details():
    return {
        "system": SYSTEM_TYPE,
        "backend": CALCULATION_BACKEND,
        "default_precision": DEFAULT_PRECISION,
        "python_version": sys.version
    }
"""

    # Create a temporary directory for the test
    with tempfile.TemporaryDirectory() as temp_dir:
        # Set up the package structure
        package_dir = Path(temp_dir) / "package"
        package_dir.mkdir()

        # Create the __init__.py file
        with open(package_dir / "__init__.py", "w") as init_file:
            init_file.write("")

        # Write the utility_module.py file
        with open(package_dir / "utility_module.py", "w") as utility_file:
            utility_file.write(utility_module_code)
            utility_file.flush()

        # Write the main code file
        main_file_path = package_dir / "main_module.py"
        with open(main_file_path, "w") as main_file:
            main_file.write(main_code)
            main_file.flush()

        # Set up the optimizer
        file_path = main_file_path.resolve()
        opt = Optimizer(
            Namespace(
                project_root=package_dir.resolve(),
                disable_telemetry=True,
                tests_root="tests",
                test_framework="pytest",
                pytest_cmd="pytest",
                experiment_id=None,
                test_project_root=Path().resolve(),
            )
        )

        # Define the function to optimize
        function_to_optimize = FunctionToOptimize(
            function_name="calculate",
            file_path=file_path,
            parents=[FunctionParent(name="Calculator", type="ClassDef")],
            starting_line=None,
            ending_line=None,
        )

        # Get the code optimization context
        code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
        read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
        hashing_context = code_ctx.hashing_code_context

        # The expected contexts
        expected_read_write_context = """
import utility_module
class Calculator:
    def __init__(self, precision="high", fallback_precision=None, mode="standard"):
        # This is where we use the imported module
        self.precision = utility_module.select_precision(precision, fallback_precision)
        self.mode = mode
        # Using variables from the utility module
        self.backend = utility_module.CALCULATION_BACKEND
        self.system = utility_module.SYSTEM_TYPE
        self.default_precision = utility_module.DEFAULT_PRECISION
    def add(self, a, b):
        return a + b
    def subtract(self, a, b):
        return a - b
    def calculate(self, operation, x, y):
        if operation == "add":
            return self.add(x, y)
        elif operation == "subtract":
            return self.subtract(x, y)
        else:
            return None
"""
        expected_read_only_context = """
```python:utility_module.py
DEFAULT_PRECISION = "medium"
# Try-except block with variable definitions
try:
    # Used variable in try block
    CALCULATION_BACKEND = "numpy"
except ImportError:
    # Used variable in except block
    CALCULATION_BACKEND = "python"
# Function that will be used in the main code
def select_precision(precision, fallback_precision):
    if precision is None:
        return fallback_precision or DEFAULT_PRECISION
    # Using the variables defined above
    if CALCULATION_BACKEND == "numpy":
        # Higher precision available with NumPy
        precision_options = ["low", "medium", "high", "ultra"]
    else:
        # Limited precision without NumPy
        precision_options = ["low", "medium", "high"]
    if isinstance(precision, str):
        if precision.lower() not in precision_options:
            if fallback_precision:
                return fallback_precision
            else:
                return DEFAULT_PRECISION
        return precision.lower()
    else:
        return DEFAULT_PRECISION
```
"""
        expected_hashing_context = """
```python:main_module.py
class Calculator:
    def add(self, a, b):
        return a + b
    def subtract(self, a, b):
        return a - b
    def calculate(self, operation, x, y):
        if operation == "add":
            return self.add(x, y)
        elif operation == "subtract":
            return self.subtract(x, y)
        else:
            return None
```
"""
        # Verify the contexts match the expected values
        assert read_write_context.strip() == expected_read_write_context.strip()
        assert read_only_context.strip() == expected_read_only_context.strip()
        assert hashing_context.strip() == expected_hashing_context.strip()
def test_module_import_init_fto() -> None:
    """When __init__ itself is the FTO, the module-level helper it calls is pulled into read-write context."""
    main_code = """
import utility_module

class Calculator:
    def __init__(self, precision="high", fallback_precision=None, mode="standard"):
        # This is where we use the imported module
        self.precision = utility_module.select_precision(precision, fallback_precision)
        self.mode = mode

        # Using variables from the utility module
        self.backend = utility_module.CALCULATION_BACKEND
        self.system = utility_module.SYSTEM_TYPE
        self.default_precision = utility_module.DEFAULT_PRECISION
    def add(self, a, b):
        return a + b
    def subtract(self, a, b):
        return a - b
    def calculate(self, operation, x, y):
        if operation == "add":
            return self.add(x, y)
        elif operation == "subtract":
            return self.subtract(x, y)
        else:
            return None
"""

    utility_module_code = """
import sys
import platform
import logging
DEFAULT_PRECISION = "medium"
DEFAULT_MODE = "standard"
# Try-except block with variable definitions
try:
    import numpy as np
    # Used variable in try block
    CALCULATION_BACKEND = "numpy"
    # Unused variable in try block
    VECTOR_DIMENSIONS = 3
except ImportError:
    # Used variable in except block
    CALCULATION_BACKEND = "python"
    # Unused variable in except block
    FALLBACK_WARNING = "NumPy not available, using slower Python implementation"
# Nested if-else with variable definitions
if sys.platform.startswith('win'):
    # Used variable in outer if
    SYSTEM_TYPE = "windows"
    if platform.architecture()[0] == '64bit':
        # Unused variable in nested if
        MEMORY_MODEL = "x64"
    else:
        # Unused variable in nested else
        MEMORY_MODEL = "x86"
elif sys.platform.startswith('linux'):
    # Used variable in outer elif
    SYSTEM_TYPE = "linux"
    # Unused variable in outer elif
    KERNEL_VERSION = platform.release()
else:
    # Used variable in outer else
    SYSTEM_TYPE = "other"
    # Unused variable in outer else
    UNKNOWN_SYSTEM_MSG = "Running on an unrecognized platform"
# Function that will be used in the main code
def select_precision(precision, fallback_precision):
    if precision is None:
        return fallback_precision or DEFAULT_PRECISION
    # Using the variables defined above
    if CALCULATION_BACKEND == "numpy":
        # Higher precision available with NumPy
        precision_options = ["low", "medium", "high", "ultra"]
    else:
        # Limited precision without NumPy
        precision_options = ["low", "medium", "high"]
    if isinstance(precision, str):
        if precision.lower() not in precision_options:
            if fallback_precision:
                return fallback_precision
            else:
                return DEFAULT_PRECISION
        return precision.lower()
    else:
        return DEFAULT_PRECISION
# Function that won't be used
def get_system_details():
    return {
        "system": SYSTEM_TYPE,
        "backend": CALCULATION_BACKEND,
        "default_precision": DEFAULT_PRECISION,
        "python_version": sys.version
    }
"""

    # Create a temporary directory for the test
    with tempfile.TemporaryDirectory() as temp_dir:
        # Set up the package structure
        package_dir = Path(temp_dir) / "package"
        package_dir.mkdir()

        # Create the __init__.py file
        with open(package_dir / "__init__.py", "w") as init_file:
            init_file.write("")

        # Write the utility_module.py file
        with open(package_dir / "utility_module.py", "w") as utility_file:
            utility_file.write(utility_module_code)
            utility_file.flush()

        # Write the main code file
        main_file_path = package_dir / "main_module.py"
        with open(main_file_path, "w") as main_file:
            main_file.write(main_code)
            main_file.flush()

        # Set up the optimizer
        file_path = main_file_path.resolve()
        opt = Optimizer(
            Namespace(
                project_root=package_dir.resolve(),
                disable_telemetry=True,
                tests_root="tests",
                test_framework="pytest",
                pytest_cmd="pytest",
                experiment_id=None,
                test_project_root=Path().resolve(),
            )
        )

        # Define the function to optimize
        function_to_optimize = FunctionToOptimize(
            function_name="__init__",
            file_path=file_path,
            parents=[FunctionParent(name="Calculator", type="ClassDef")],
            starting_line=None,
            ending_line=None,
        )

        # Get the code optimization context
        code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
        read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code

        # The expected contexts
        expected_read_write_context = """
# Function that will be used in the main code
import utility_module
def select_precision(precision, fallback_precision):
    if precision is None:
        return fallback_precision or DEFAULT_PRECISION
    # Using the variables defined above
    if CALCULATION_BACKEND == "numpy":
        # Higher precision available with NumPy
        precision_options = ["low", "medium", "high", "ultra"]
    else:
        # Limited precision without NumPy
        precision_options = ["low", "medium", "high"]
    if isinstance(precision, str):
        if precision.lower() not in precision_options:
            if fallback_precision:
                return fallback_precision
            else:
                return DEFAULT_PRECISION
        return precision.lower()
    else:
        return DEFAULT_PRECISION

class Calculator:
    def __init__(self, precision="high", fallback_precision=None, mode="standard"):
        # This is where we use the imported module
        self.precision = utility_module.select_precision(precision, fallback_precision)
        self.mode = mode

        # Using variables from the utility module
        self.backend = utility_module.CALCULATION_BACKEND
        self.system = utility_module.SYSTEM_TYPE
        self.default_precision = utility_module.DEFAULT_PRECISION
"""
        expected_read_only_context = """
```python:utility_module.py
DEFAULT_PRECISION = "medium"
# Try-except block with variable definitions
try:
    # Used variable in try block
    CALCULATION_BACKEND = "numpy"
except ImportError:
    # Used variable in except block
    CALCULATION_BACKEND = "python"
```
"""
        assert read_write_context.strip() == expected_read_write_context.strip()
        assert read_only_context.strip() == expected_read_only_context.strip()