2468 lines
83 KiB
Python
2468 lines
83 KiB
Python
from __future__ import annotations
|
|
|
|
import sys
|
|
import tempfile
|
|
from argparse import Namespace
|
|
from collections import defaultdict
|
|
from pathlib import Path
|
|
|
|
import pytest
|
|
from codeflash.context.code_context_extractor import get_code_optimization_context
|
|
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
|
from codeflash.models.models import FunctionParent, get_code_block_splitter
|
|
from codeflash.optimization.optimizer import Optimizer
|
|
from codeflash.code_utils.code_replacer import replace_functions_and_add_imports
|
|
from codeflash.code_utils.code_extractor import add_global_assignments
|
|
|
|
|
|
class HelperClass:
    def __init__(self, name):
        self.name = name

    def innocent_bystander(self):
        # Intentional no-op: MainClass.main_method never calls this method.
        pass

    def helper_method(self):
        return self.name

    class NestedClass:
        def __init__(self, name):
            # Stores the given name, the same way HelperClass.__init__ does.
            self.name = name

        def nested_method(self):
            # Returns the stored name unchanged.
            return self.name
|
|
|
|
|
|
def main_method():
    """Return the fixed string ``"hello"``.

    This is the module-level function; ``MainClass`` defines a separate
    method that happens to share the same name.
    """
    greeting = "hello"
    return greeting
|
|
|
|
|
|
class MainClass:
    def __init__(self, name):
        self.name = name

    def main_method(self):
        self.name = HelperClass.NestedClass("test").nested_method()
        return HelperClass(self.name).helper_method()
|
|
|
|
|
|
class Graph:
    def __init__(self, vertices):
        self.graph = defaultdict(list)
        self.V = vertices # No. of vertices

    def addEdge(self, u, v):
        # Record v in u's adjacency list; nothing is added to v's own list.
        self.graph[u].append(v)

    def topologicalSortUtil(self, v, visited, stack):
        visited[v] = True

        for i in self.graph[v]:
            if visited[i] == False:
                self.topologicalSortUtil(i, visited, stack)

        stack.insert(0, v)

    def topologicalSort(self):
        visited = [False] * self.V
        stack = []

        for i in range(self.V):
            if visited[i] == False:
                self.topologicalSortUtil(i, visited, stack)

        # Print contents of stack
        return stack
|
|
|
|
|
|
def test_code_replacement10() -> None:
    """Check the contexts extracted for ``MainClass.main_method`` of this file.

    Only ``HelperClass.helper_method`` is expected among the helper functions,
    and the read-writable, read-only and hashing contexts are compared with
    literal expectations after stripping surrounding whitespace.
    """
    this_file = Path(__file__).resolve()
    project_root = this_file.parent
    rel_path = this_file.relative_to(project_root)

    target = FunctionToOptimize(
        function_name="main_method", file_path=this_file, parents=[FunctionParent("MainClass", "ClassDef")]
    )

    code_ctx = get_code_optimization_context(function_to_optimize=target, project_root_path=project_root)
    # Nested method should not be in here
    assert {helper.qualified_name for helper in code_ctx.helper_functions} == {"HelperClass.helper_method"}
    read_write_context = code_ctx.read_writable_code
    read_only_context = code_ctx.read_only_context_code
    hashing_context = code_ctx.hashing_code_context

    expected_read_write_context = f"""
{get_code_block_splitter(rel_path)}
from __future__ import annotations

class HelperClass:
    def __init__(self, name):
        self.name = name

    def helper_method(self):
        return self.name


class MainClass:
    def __init__(self, name):
        self.name = name

    def main_method(self):
        self.name = HelperClass.NestedClass("test").nested_method()
        return HelperClass(self.name).helper_method()
"""
    expected_read_only_context = """
    """

    expected_hashing_context = f"""
```python:{rel_path}
class HelperClass:

    def helper_method(self):
        return self.name

class MainClass:

    def main_method(self):
        self.name = HelperClass.NestedClass('test').nested_method()
        return HelperClass(self.name).helper_method()
```
"""

    assert read_write_context.flat.strip() == expected_read_write_context.strip()
    assert read_only_context.strip() == expected_read_only_context.strip()
    assert hashing_context.strip() == expected_hashing_context.strip()
|
|
|
|
|
|
def test_class_method_dependencies() -> None:
    """Check the contexts extracted for ``Graph.topologicalSort`` of this file.

    The expected read-writable context keeps ``__init__``,
    ``topologicalSortUtil`` and ``topologicalSort``; the read-only context is
    expected to be empty.
    """
    this_file = Path(__file__).resolve()
    project_root = this_file.parent.resolve()
    rel_path = this_file.relative_to(project_root)

    function_to_optimize = FunctionToOptimize(
        function_name="topologicalSort",
        file_path=str(this_file),
        parents=[FunctionParent(name="Graph", type="ClassDef")],
        starting_line=None,
        ending_line=None,
    )

    code_ctx = get_code_optimization_context(function_to_optimize, project_root)
    read_write_context = code_ctx.read_writable_code
    read_only_context = code_ctx.read_only_context_code
    hashing_context = code_ctx.hashing_code_context

    expected_read_write_context = f"""
{get_code_block_splitter(rel_path)}
from __future__ import annotations
from collections import defaultdict

class Graph:
    def __init__(self, vertices):
        self.graph = defaultdict(list)
        self.V = vertices # No. of vertices

    def topologicalSortUtil(self, v, visited, stack):
        visited[v] = True

        for i in self.graph[v]:
            if visited[i] == False:
                self.topologicalSortUtil(i, visited, stack)

        stack.insert(0, v)

    def topologicalSort(self):
        visited = [False] * self.V
        stack = []

        for i in range(self.V):
            if visited[i] == False:
                self.topologicalSortUtil(i, visited, stack)

        # Print contents of stack
        return stack

"""
    expected_read_only_context = ""

    expected_hashing_context = f"""
```python:{rel_path}
class Graph:

    def topologicalSortUtil(self, v, visited, stack):
        visited[v] = True
        for i in self.graph[v]:
            if visited[i] == False:
                self.topologicalSortUtil(i, visited, stack)
        stack.insert(0, v)

    def topologicalSort(self):
        visited = [False] * self.V
        stack = []
        for i in range(self.V):
            if visited[i] == False:
                self.topologicalSortUtil(i, visited, stack)
        return stack
```
"""

    assert read_write_context.flat.strip() == expected_read_write_context.strip()
    assert read_only_context.strip() == expected_read_only_context.strip()
    assert hashing_context.strip() == expected_hashing_context.strip()
|
|
|
|
|
|
def test_bubble_sort_helper() -> None:
    """Check the contexts extracted for ``sort_from_another_file``.

    The target lives in ``bubble_sort_imported.py``; the expected contexts
    also contain ``sorter`` from ``bubble_sort_with_math.py``, each under its
    own per-file code-block header.
    """
    repo_root = Path(__file__).resolve().parent.parent
    path_to_fto = repo_root / "code_to_optimize" / "code_directories" / "retriever" / "bubble_sort_imported.py"

    function_to_optimize = FunctionToOptimize(
        function_name="sort_from_another_file",
        file_path=str(path_to_fto),
        parents=[],
        starting_line=None,
        ending_line=None,
    )

    code_ctx = get_code_optimization_context(function_to_optimize, repo_root)
    read_write_context = code_ctx.read_writable_code
    read_only_context = code_ctx.read_only_context_code
    hashing_context = code_ctx.hashing_code_context

    expected_read_write_context = f"""
{get_code_block_splitter("code_to_optimize/code_directories/retriever/bubble_sort_with_math.py")}
import math

def sorter(arr):
    arr.sort()
    x = math.sqrt(2)
    print(x)
    return arr

{get_code_block_splitter("code_to_optimize/code_directories/retriever/bubble_sort_imported.py")}
from bubble_sort_with_math import sorter

def sort_from_another_file(arr):
    sorted_arr = sorter(arr)
    return sorted_arr

"""
    expected_read_only_context = ""

    expected_hashing_context = """
```python:code_to_optimize/code_directories/retriever/bubble_sort_with_math.py
def sorter(arr):
    arr.sort()
    x = math.sqrt(2)
    print(x)
    return arr
```
```python:code_to_optimize/code_directories/retriever/bubble_sort_imported.py
def sort_from_another_file(arr):
    sorted_arr = sorter(arr)
    return sorted_arr
```
"""
    assert read_write_context.flat.strip() == expected_read_write_context.strip()
    assert read_only_context.strip() == expected_read_only_context.strip()
    assert hashing_context.strip() == expected_hashing_context.strip()
|
|
|
|
|
|
def test_flavio_typed_code_helper() -> None:
    """Check the contexts extracted for ``_PersistentCache.__call__``.

    A typed caching module is written to a temporary ``.py`` file and the
    read-writable, read-only and hashing contexts produced for ``__call__``
    are compared with literal expectations after stripping surrounding
    whitespace.
    """
    code = '''

_P = ParamSpec("_P")
_KEY_T = TypeVar("_KEY_T")
_STORE_T = TypeVar("_STORE_T")
class AbstractCacheBackend(CacheBackend, Protocol[_KEY_T, _STORE_T]):
    """Interface for cache backends used by the persistent cache decorator."""

    def __init__(self) -> None: ...

    def hash_key(
        self,
        *,
        func: Callable[_P, Any],
        args: tuple[Any, ...],
        kwargs: dict[str, Any],
    ) -> tuple[str, _KEY_T]: ...

    def encode(self, *, data: Any) -> _STORE_T: # noqa: ANN401
        ...

    def decode(self, *, data: _STORE_T) -> Any: # noqa: ANN401
        ...

    def get(self, *, key: tuple[str, _KEY_T]) -> tuple[datetime.datetime, _STORE_T] | None: ...

    def delete(self, *, key: tuple[str, _KEY_T]) -> None: ...

    def put(self, *, key: tuple[str, _KEY_T], data: _STORE_T) -> None: ...

    def get_cache_or_call(
        self,
        *,
        func: Callable[_P, Any],
        args: tuple[Any, ...],
        kwargs: dict[str, Any],
        lifespan: datetime.timedelta,
    ) -> Any: # noqa: ANN401
        """
        Retrieve the cached results for a function call.

        Args:
        ----
            func (Callable[..., _R]): The function to retrieve cached results for.
            args (tuple[Any, ...]): The positional arguments passed to the function.
            kwargs (dict[str, Any]): The keyword arguments passed to the function.
            lifespan (datetime.timedelta): The maximum age of the cached results.

        Returns:
        -------
            _R: The cached results, if available.

        """
        if os.environ.get("NO_CACHE"):
            return func(*args, **kwargs)

        try:
            key = self.hash_key(func=func, args=args, kwargs=kwargs)
        except: # noqa: E722
            # If we can't create a cache key, we should just call the function.
            logging.warning("Failed to hash cache key for function: %s", func)
            return func(*args, **kwargs)
        result_pair = self.get(key=key)

        if result_pair is not None:
            cached_time, result = result_pair
            if not os.environ.get("RE_CACHE") and (
                datetime.datetime.now() < (cached_time + lifespan) # noqa: DTZ005
            ):
                try:
                    return self.decode(data=result)
                except CacheBackendDecodeError as e:
                    logging.warning("Failed to decode cache data: %s", e)
                    # If decoding fails we will treat this as a cache miss.
                    # This might happens if underlying class definition of the data changes.
                    self.delete(key=key)
        result = func(*args, **kwargs)
        try:
            self.put(key=key, data=self.encode(data=result))
        except CacheBackendEncodeError as e:
            logging.warning("Failed to encode cache data: %s", e)
            # If encoding fails, we should still return the result.
        return result

_P = ParamSpec("_P")
_R = TypeVar("_R")
_CacheBackendT = TypeVar("_CacheBackendT", bound=CacheBackend)


class _PersistentCache(Generic[_P, _R, _CacheBackendT]):
    """
    A decorator class that provides persistent caching functionality for a function.

    Args:
    ----
        func (Callable[_P, _R]): The function to be decorated.
        duration (datetime.timedelta): The duration for which the cached results should be considered valid.
        backend (_backend): The backend storage for the cached results.

    Attributes:
    ----------
        __wrapped__ (Callable[_P, _R]): The wrapped function.
        __duration__ (datetime.timedelta): The duration for which the cached results should be considered valid.
        __backend__ (_backend): The backend storage for the cached results.

    """ # noqa: E501

    __wrapped__: Callable[_P, _R]
    __duration__: datetime.timedelta
    __backend__: _CacheBackendT

    def __init__(
        self,
        func: Callable[_P, _R],
        duration: datetime.timedelta,
    ) -> None:
        self.__wrapped__ = func
        self.__duration__ = duration
        self.__backend__ = AbstractCacheBackend()
        functools.update_wrapper(self, func)

    def cache_clear(self) -> None:
        """Clears the cache for the wrapped function."""
        self.__backend__.del_func_cache(func=self.__wrapped__)

    def no_cache_call(self, *args: _P.args, **kwargs: _P.kwargs) -> _R:
        """
        Calls the wrapped function without using the cache.

        Args:
        ----
            *args (_P.args): Positional arguments for the wrapped function.
            **kwargs (_P.kwargs): Keyword arguments for the wrapped function.

        Returns:
        -------
            _R: The result of the wrapped function.

        """
        return self.__wrapped__(*args, **kwargs)

    def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R:
        """
        Calls the wrapped function, either using the cache or bypassing it based on environment variables.

        Args:
        ----
            *args (_P.args): Positional arguments for the wrapped function.
            **kwargs (_P.kwargs): Keyword arguments for the wrapped function.

        Returns:
        -------
            _R: The result of the wrapped function.

        """ # noqa: E501
        if "NO_CACHE" in os.environ:
            return self.__wrapped__(*args, **kwargs)
        os.makedirs(DEFAULT_CACHE_LOCATION, exist_ok=True)
        return self.__backend__.get_cache_or_call(
            func=self.__wrapped__,
            args=args,
            kwargs=kwargs,
            lifespan=self.__duration__,
        )
'''
    # The temporary file must stay open (and therefore on disk) while the
    # context extractor reads it, so everything below runs inside the block.
    with tempfile.NamedTemporaryFile(mode="w", suffix=".py") as tmp:
        tmp.write(code)
        tmp.flush()
        file_path = Path(tmp.name).resolve()
        opt = Optimizer(
            Namespace(
                project_root=file_path.parent.resolve(),
                disable_telemetry=True,
                tests_root="tests",
                test_framework="pytest",
                pytest_cmd="pytest",
                experiment_id=None,
                test_project_root=Path().resolve(),
            )
        )
        project_root = opt.args.project_root
        rel_path = file_path.relative_to(project_root)
        function_to_optimize = FunctionToOptimize(
            function_name="__call__",
            file_path=file_path,
            parents=[FunctionParent(name="_PersistentCache", type="ClassDef")],
            starting_line=None,
            ending_line=None,
        )

        code_ctx = get_code_optimization_context(function_to_optimize, project_root)
        read_write_context = code_ctx.read_writable_code
        read_only_context = code_ctx.read_only_context_code
        hashing_context = code_ctx.hashing_code_context
        expected_read_write_context = f"""
{get_code_block_splitter(rel_path)}
class AbstractCacheBackend(CacheBackend, Protocol[_KEY_T, _STORE_T]):

    def __init__(self) -> None: ...

    def get_cache_or_call(
        self,
        *,
        func: Callable[_P, Any],
        args: tuple[Any, ...],
        kwargs: dict[str, Any],
        lifespan: datetime.timedelta,
    ) -> Any: # noqa: ANN401
        \"\"\"
        Retrieve the cached results for a function call.

        Args:
        ----
            func (Callable[..., _R]): The function to retrieve cached results for.
            args (tuple[Any, ...]): The positional arguments passed to the function.
            kwargs (dict[str, Any]): The keyword arguments passed to the function.
            lifespan (datetime.timedelta): The maximum age of the cached results.

        Returns:
        -------
            _R: The cached results, if available.

        \"\"\"
        if os.environ.get("NO_CACHE"):
            return func(*args, **kwargs)

        try:
            key = self.hash_key(func=func, args=args, kwargs=kwargs)
        except: # noqa: E722
            # If we can't create a cache key, we should just call the function.
            logging.warning("Failed to hash cache key for function: %s", func)
            return func(*args, **kwargs)
        result_pair = self.get(key=key)

        if result_pair is not None:
            cached_time, result = result_pair
            if not os.environ.get("RE_CACHE") and (
                datetime.datetime.now() < (cached_time + lifespan) # noqa: DTZ005
            ):
                try:
                    return self.decode(data=result)
                except CacheBackendDecodeError as e:
                    logging.warning("Failed to decode cache data: %s", e)
                    # If decoding fails we will treat this as a cache miss.
                    # This might happens if underlying class definition of the data changes.
                    self.delete(key=key)
        result = func(*args, **kwargs)
        try:
            self.put(key=key, data=self.encode(data=result))
        except CacheBackendEncodeError as e:
            logging.warning("Failed to encode cache data: %s", e)
            # If encoding fails, we should still return the result.
        return result


class _PersistentCache(Generic[_P, _R, _CacheBackendT]):

    def __init__(
        self,
        func: Callable[_P, _R],
        duration: datetime.timedelta,
    ) -> None:
        self.__wrapped__ = func
        self.__duration__ = duration
        self.__backend__ = AbstractCacheBackend()
        functools.update_wrapper(self, func)

    def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R:
        \"\"\"
        Calls the wrapped function, either using the cache or bypassing it based on environment variables.

        Args:
        ----
            *args (_P.args): Positional arguments for the wrapped function.
            **kwargs (_P.kwargs): Keyword arguments for the wrapped function.

        Returns:
        -------
            _R: The result of the wrapped function.

        \"\"\" # noqa: E501
        if "NO_CACHE" in os.environ:
            return self.__wrapped__(*args, **kwargs)
        os.makedirs(DEFAULT_CACHE_LOCATION, exist_ok=True)
        return self.__backend__.get_cache_or_call(
            func=self.__wrapped__,
            args=args,
            kwargs=kwargs,
            lifespan=self.__duration__,
        )
"""
        expected_read_only_context = f'''
```python:{rel_path}
_P = ParamSpec("_P")
_KEY_T = TypeVar("_KEY_T")
_STORE_T = TypeVar("_STORE_T")
class AbstractCacheBackend(CacheBackend, Protocol[_KEY_T, _STORE_T]):
    """Interface for cache backends used by the persistent cache decorator."""

    def hash_key(
        self,
        *,
        func: Callable[_P, Any],
        args: tuple[Any, ...],
        kwargs: dict[str, Any],
    ) -> tuple[str, _KEY_T]: ...

    def encode(self, *, data: Any) -> _STORE_T: # noqa: ANN401
        ...

    def decode(self, *, data: _STORE_T) -> Any: # noqa: ANN401
        ...

    def get(self, *, key: tuple[str, _KEY_T]) -> tuple[datetime.datetime, _STORE_T] | None: ...

    def delete(self, *, key: tuple[str, _KEY_T]) -> None: ...

    def put(self, *, key: tuple[str, _KEY_T], data: _STORE_T) -> None: ...

_P = ParamSpec("_P")
_R = TypeVar("_R")
_CacheBackendT = TypeVar("_CacheBackendT", bound=CacheBackend)


class _PersistentCache(Generic[_P, _R, _CacheBackendT]):
    """
    A decorator class that provides persistent caching functionality for a function.

    Args:
    ----
        func (Callable[_P, _R]): The function to be decorated.
        duration (datetime.timedelta): The duration for which the cached results should be considered valid.
        backend (_backend): The backend storage for the cached results.

    Attributes:
    ----------
        __wrapped__ (Callable[_P, _R]): The wrapped function.
        __duration__ (datetime.timedelta): The duration for which the cached results should be considered valid.
        __backend__ (_backend): The backend storage for the cached results.

    """ # noqa: E501

    __wrapped__: Callable[_P, _R]
    __duration__: datetime.timedelta
    __backend__: _CacheBackendT
```
'''
        expected_hashing_context = f"""
```python:{rel_path}
class AbstractCacheBackend(CacheBackend, Protocol[_KEY_T, _STORE_T]):

    def get_cache_or_call(self, *, func: Callable[_P, Any], args: tuple[Any, ...], kwargs: dict[str, Any], lifespan: datetime.timedelta) -> Any:
        if os.environ.get('NO_CACHE'):
            return func(*args, **kwargs)
        try:
            key = self.hash_key(func=func, args=args, kwargs=kwargs)
        except:
            logging.warning('Failed to hash cache key for function: %s', func)
            return func(*args, **kwargs)
        result_pair = self.get(key=key)
        if result_pair is not None:
            {"cached_time, result = result_pair" if sys.version_info >= (3, 11) else "(cached_time, result) = result_pair"}
            if not os.environ.get('RE_CACHE') and datetime.datetime.now() < cached_time + lifespan:
                try:
                    return self.decode(data=result)
                except CacheBackendDecodeError as e:
                    logging.warning('Failed to decode cache data: %s', e)
                    self.delete(key=key)
        result = func(*args, **kwargs)
        try:
            self.put(key=key, data=self.encode(data=result))
        except CacheBackendEncodeError as e:
            logging.warning('Failed to encode cache data: %s', e)
        return result

class _PersistentCache(Generic[_P, _R, _CacheBackendT]):

    def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R:
        if 'NO_CACHE' in os.environ:
            return self.__wrapped__(*args, **kwargs)
        os.makedirs(DEFAULT_CACHE_LOCATION, exist_ok=True)
        return self.__backend__.get_cache_or_call(func=self.__wrapped__, args=args, kwargs=kwargs, lifespan=self.__duration__)
```
"""
        assert read_write_context.flat.strip() == expected_read_write_context.strip()
        assert read_only_context.strip() == expected_read_only_context.strip()
        assert hashing_context.strip() == expected_hashing_context.strip()
|
|
|
|
|
|
def test_example_class() -> None:
    """Check the contexts extracted for ``MyClass.target_method``.

    The fixture is written to a temporary file. The expected read-writable
    context holds the target, both ``__init__`` methods and the helper method;
    the expected read-only context holds the class docstrings and
    ``HelperClass.__repr__``.
    """
    code = """
class MyClass:
    \"\"\"A class with a helper method.\"\"\"
    def __init__(self):
        self.x = 1
    def target_method(self):
        y = HelperClass().helper_method()

class HelperClass:
    \"\"\"A helper class for MyClass.\"\"\"
    def __init__(self):
        \"\"\"Initialize the HelperClass.\"\"\"
        self.x = 1
    def __repr__(self):
        \"\"\"Return a string representation of the HelperClass.\"\"\"
        return "HelperClass" + str(self.x)
    def helper_method(self):
        return self.x
"""
    # The temporary file must exist while the context extractor reads it.
    with tempfile.NamedTemporaryFile(mode="w") as tmp:
        tmp.write(code)
        tmp.flush()
        file_path = Path(tmp.name).resolve()
        opt = Optimizer(
            Namespace(
                project_root=file_path.parent.resolve(),
                disable_telemetry=True,
                tests_root="tests",
                test_framework="pytest",
                pytest_cmd="pytest",
                experiment_id=None,
                test_project_root=Path().resolve(),
            )
        )
        project_root = opt.args.project_root
        rel_path = file_path.relative_to(project_root)
        function_to_optimize = FunctionToOptimize(
            function_name="target_method",
            file_path=file_path,
            parents=[FunctionParent(name="MyClass", type="ClassDef")],
            starting_line=None,
            ending_line=None,
        )

        code_ctx = get_code_optimization_context(function_to_optimize, project_root)
        read_write_context = code_ctx.read_writable_code
        read_only_context = code_ctx.read_only_context_code
        hashing_context = code_ctx.hashing_code_context

        expected_read_write_context = f"""
{get_code_block_splitter(rel_path)}
class MyClass:
    def __init__(self):
        self.x = 1
    def target_method(self):
        y = HelperClass().helper_method()

class HelperClass:
    def __init__(self):
        \"\"\"Initialize the HelperClass.\"\"\"
        self.x = 1
    def helper_method(self):
        return self.x
"""
        expected_read_only_context = f"""
```python:{rel_path}
class MyClass:
    \"\"\"A class with a helper method.\"\"\"

class HelperClass:
    \"\"\"A helper class for MyClass.\"\"\"
    def __repr__(self):
        \"\"\"Return a string representation of the HelperClass.\"\"\"
        return "HelperClass" + str(self.x)
```
"""
        expected_hashing_context = f"""
```python:{rel_path}
class MyClass:

    def target_method(self):
        y = HelperClass().helper_method()

class HelperClass:

    def helper_method(self):
        return self.x
```
"""

        assert read_write_context.flat.strip() == expected_read_write_context.strip()
        assert read_only_context.strip() == expected_read_only_context.strip()
        assert hashing_context.strip() == expected_hashing_context.strip()
|
|
|
|
|
|
def test_example_class_token_limit_1() -> None:
    """Check context extraction when a class docstring is very long.

    ``MyClass`` gets a docstring padded with 1000 repeated sentences. The
    expected read-only context then contains no docstrings at all, while the
    read-writable context still carries the target method's own docstring.
    """
    filler_sentence = "This is a long docstring that will be used to fill up the token limit."
    docstring_filler = " ".join([filler_sentence] * 1000)
    code = f"""
class MyClass:
    \"\"\"A class with a helper method.
    {docstring_filler}\"\"\"
    def __init__(self):
        self.x = 1
    def target_method(self):
        \"\"\"Docstring for target method\"\"\"
        y = HelperClass().helper_method()

class HelperClass:
    \"\"\"A helper class for MyClass.\"\"\"
    def __init__(self):
        \"\"\"Initialize the HelperClass.\"\"\"
        self.x = 1
    def __repr__(self):
        \"\"\"Return a string representation of the HelperClass.\"\"\"
        return "HelperClass" + str(self.x)
    def helper_method(self):
        return self.x
"""
    # The temporary file must exist while the context extractor reads it.
    with tempfile.NamedTemporaryFile(mode="w") as tmp:
        tmp.write(code)
        tmp.flush()
        file_path = Path(tmp.name).resolve()
        opt = Optimizer(
            Namespace(
                project_root=file_path.parent.resolve(),
                disable_telemetry=True,
                tests_root="tests",
                test_framework="pytest",
                pytest_cmd="pytest",
                experiment_id=None,
                test_project_root=Path().resolve(),
            )
        )
        project_root = opt.args.project_root
        rel_path = file_path.relative_to(project_root)
        function_to_optimize = FunctionToOptimize(
            function_name="target_method",
            file_path=file_path,
            parents=[FunctionParent(name="MyClass", type="ClassDef")],
            starting_line=None,
            ending_line=None,
        )

        code_ctx = get_code_optimization_context(function_to_optimize, project_root)
        read_write_context = code_ctx.read_writable_code
        read_only_context = code_ctx.read_only_context_code
        hashing_context = code_ctx.hashing_code_context
        # In this scenario, the read-only code context is too long, so the read-only docstrings are removed.
        expected_read_write_context = f"""
{get_code_block_splitter(rel_path)}
class MyClass:
    def __init__(self):
        self.x = 1
    def target_method(self):
        \"\"\"Docstring for target method\"\"\"
        y = HelperClass().helper_method()

class HelperClass:
    def __init__(self):
        \"\"\"Initialize the HelperClass.\"\"\"
        self.x = 1
    def helper_method(self):
        return self.x
"""
        expected_read_only_context = f"""
```python:{rel_path}
class MyClass:
    pass

class HelperClass:
    def __repr__(self):
        return "HelperClass" + str(self.x)
```
"""
        expected_hashing_context = f"""
```python:{rel_path}
class MyClass:

    def target_method(self):
        y = HelperClass().helper_method()

class HelperClass:

    def helper_method(self):
        return self.x
```
"""
        assert read_write_context.flat.strip() == expected_read_write_context.strip()
        assert read_only_context.strip() == expected_read_only_context.strip()
        assert hashing_context.strip() == expected_hashing_context.strip()
|
|
|
|
|
|
def test_example_class_token_limit_2() -> None:
    """Check context extraction when the read-only context must be dropped.

    The fixture assigns a string of 1000 repeated sentences to a module-level
    ``x``. The extractor is called with explicit limits ``8000`` and
    ``100000``, and the expected read-only context is empty.
    """
    filler_sentence = "This is a long string that will be used to fill up the token limit."
    string_filler = " ".join([filler_sentence] * 1000)
    code = f"""
class MyClass:
    \"\"\"A class with a helper method. \"\"\"
    def __init__(self):
        self.x = 1
    def target_method(self):
        \"\"\"Docstring for target method\"\"\"
        y = HelperClass().helper_method()
x = '{string_filler}'

class HelperClass:
    \"\"\"A helper class for MyClass.\"\"\"
    def __init__(self):
        \"\"\"Initialize the HelperClass.\"\"\"
        self.x = 1
    def __repr__(self):
        \"\"\"Return a string representation of the HelperClass.\"\"\"
        return "HelperClass" + str(self.x)
    def helper_method(self):
        return self.x
"""
    # The temporary file must exist while the context extractor reads it.
    with tempfile.NamedTemporaryFile(mode="w") as tmp:
        tmp.write(code)
        tmp.flush()
        file_path = Path(tmp.name).resolve()
        opt = Optimizer(
            Namespace(
                project_root=file_path.parent.resolve(),
                disable_telemetry=True,
                tests_root="tests",
                test_framework="pytest",
                pytest_cmd="pytest",
                experiment_id=None,
                test_project_root=Path().resolve(),
            )
        )
        project_root = opt.args.project_root
        rel_path = file_path.relative_to(project_root)
        function_to_optimize = FunctionToOptimize(
            function_name="target_method",
            file_path=file_path,
            parents=[FunctionParent(name="MyClass", type="ClassDef")],
            starting_line=None,
            ending_line=None,
        )

        code_ctx = get_code_optimization_context(function_to_optimize, project_root, 8000, 100000)
        read_write_context = code_ctx.read_writable_code
        read_only_context = code_ctx.read_only_context_code
        hashing_context = code_ctx.hashing_code_context
        # In this scenario, the read-only code context is too long even after removing docstrings, hence we remove it completely.
        expected_read_write_context = f"""
{get_code_block_splitter(rel_path)}
class MyClass:
    def __init__(self):
        self.x = 1
    def target_method(self):
        \"\"\"Docstring for target method\"\"\"
        y = HelperClass().helper_method()

class HelperClass:
    def __init__(self):
        \"\"\"Initialize the HelperClass.\"\"\"
        self.x = 1
    def helper_method(self):
        return self.x
"""
        expected_read_only_context = ""
        expected_hashing_context = f"""
```python:{rel_path}
class MyClass:

    def target_method(self):
        y = HelperClass().helper_method()

class HelperClass:

    def helper_method(self):
        return self.x
```
"""
        assert read_write_context.flat.strip() == expected_read_write_context.strip()
        assert read_only_context.strip() == expected_read_only_context.strip()
        assert hashing_context.strip() == expected_hashing_context.strip()
|
|
|
|
|
|
def test_example_class_token_limit_3() -> None:
    """Check that an oversized target method aborts context extraction.

    The docstring of ``MyClass.target_method`` is a string of 1000 repeated
    sentences. Calling ``get_code_optimization_context`` with its default
    limits is expected to raise ``ValueError`` with the
    "Read-writable code has exceeded token limit" message.
    """
    string_filler = " ".join(
        ["This is a long string that will be used to fill up the token limit." for _ in range(1000)]
    )
    code = f"""
class MyClass:
    \"\"\"A class with a helper method. \"\"\"
    def __init__(self):
        self.x = 1
    def target_method(self):
        \"\"\"{string_filler}\"\"\"
        y = HelperClass().helper_method()

class HelperClass:
    \"\"\"A helper class for MyClass.\"\"\"
    def __init__(self):
        \"\"\"Initialize the HelperClass.\"\"\"
        self.x = 1
    def __repr__(self):
        \"\"\"Return a string representation of the HelperClass.\"\"\"
        return "HelperClass" + str(self.x)
    def helper_method(self):
        return self.x
"""
    # The temporary file must exist while the context extractor reads it.
    with tempfile.NamedTemporaryFile(mode="w") as f:
        f.write(code)
        f.flush()
        file_path = Path(f.name).resolve()
        opt = Optimizer(
            Namespace(
                project_root=file_path.parent.resolve(),
                disable_telemetry=True,
                tests_root="tests",
                test_framework="pytest",
                pytest_cmd="pytest",
                experiment_id=None,
                test_project_root=Path().resolve(),
            )
        )
        function_to_optimize = FunctionToOptimize(
            function_name="target_method",
            file_path=file_path,
            parents=[FunctionParent(name="MyClass", type="ClassDef")],
            starting_line=None,
            ending_line=None,
        )
        # In this scenario, the read-writable code is too long, so we abort.
        with pytest.raises(ValueError, match="Read-writable code has exceeded token limit, cannot proceed"):
            # The call is expected to raise, so its result is deliberately not bound.
            get_code_optimization_context(function_to_optimize, opt.args.project_root)
|
|
|
|
|
|
def test_example_class_token_limit_4() -> None:
    """Check that an oversized testgen context aborts context extraction.

    The fixture assigns a string of 1000 repeated sentences to a module-level
    ``x``. Calling ``get_code_optimization_context`` with its default limits
    is expected to raise ``ValueError`` with the
    "Testgen code context has exceeded token limit" message.
    """
    string_filler = " ".join(
        ["This is a long string that will be used to fill up the token limit." for _ in range(1000)]
    )
    code = f"""
class MyClass:
    \"\"\"A class with a helper method. \"\"\"
    def __init__(self):
        self.x = 1
    def target_method(self):
        \"\"\"Docstring for target method\"\"\"
        y = HelperClass().helper_method()
x = '{string_filler}'

class HelperClass:
    \"\"\"A helper class for MyClass.\"\"\"
    def __init__(self):
        \"\"\"Initialize the HelperClass.\"\"\"
        self.x = 1
    def __repr__(self):
        \"\"\"Return a string representation of the HelperClass.\"\"\"
        return "HelperClass" + str(self.x)
    def helper_method(self):
        return self.x
"""
    # The temporary file must exist while the context extractor reads it.
    with tempfile.NamedTemporaryFile(mode="w") as f:
        f.write(code)
        f.flush()
        file_path = Path(f.name).resolve()
        opt = Optimizer(
            Namespace(
                project_root=file_path.parent.resolve(),
                disable_telemetry=True,
                tests_root="tests",
                test_framework="pytest",
                pytest_cmd="pytest",
                experiment_id=None,
                test_project_root=Path().resolve(),
            )
        )
        function_to_optimize = FunctionToOptimize(
            function_name="target_method",
            file_path=file_path,
            parents=[FunctionParent(name="MyClass", type="ClassDef")],
            starting_line=None,
            ending_line=None,
        )

        # In this scenario, the testgen code context is too long, so we abort.
        with pytest.raises(ValueError, match="Testgen code context has exceeded token limit, cannot proceed"):
            # The call is expected to raise, so its result is deliberately not bound.
            get_code_optimization_context(function_to_optimize, opt.args.project_root)
|
|
|
|
|
|
def test_repo_helper() -> None:
    """Compare the contexts extracted for ``fetch_and_process_data`` with fixed expectations.

    The target is the module-level function in ``main.py`` of the ``retriever`` fixture
    directory; the read-writable, read-only and hashing contexts must each equal the
    literal expected text after stripping surrounding whitespace.
    """
    retriever_dir = Path(__file__).resolve().parent.parent / "code_to_optimize" / "code_directories" / "retriever"
    main_py = retriever_dir / "main.py"
    utils_py = retriever_dir / "utils.py"
    # Paths relative to the project root, as they appear in the code-block headers.
    main_rel = main_py.relative_to(retriever_dir)
    utils_rel = utils_py.relative_to(retriever_dir)

    target = FunctionToOptimize(
        function_name="fetch_and_process_data",
        file_path=str(main_py),
        parents=[],
        starting_line=None,
        ending_line=None,
    )
    context = get_code_optimization_context(target, retriever_dir)
    actual_read_write = context.read_writable_code
    actual_read_only = context.read_only_context_code
    actual_hashing = context.hashing_code_context

    wanted_read_write = f"""
{get_code_block_splitter(utils_rel)}
import math

class DataProcessor:

    def __init__(self, default_prefix: str = "PREFIX_"):
        \"\"\"Initialize the DataProcessor with a default prefix.\"\"\"
        self.default_prefix = default_prefix
        self.number += math.log(self.number)

    def process_data(self, raw_data: str) -> str:
        \"\"\"Process raw data by converting it to uppercase.\"\"\"
        return raw_data.upper()

    def add_prefix(self, data: str, prefix: str = "PREFIX_") -> str:
        \"\"\"Add a prefix to the processed data.\"\"\"
        return prefix + data

{get_code_block_splitter(main_rel)}
import requests
from globals import API_URL
from utils import DataProcessor

def fetch_and_process_data():
    # Use the global variable for the request
    response = requests.get(API_URL)
    response.raise_for_status()

    raw_data = response.text

    # Use code from another file (utils.py)
    processor = DataProcessor()
    processed = processor.process_data(raw_data)
    processed = processor.add_prefix(processed)

    return processed
"""
    wanted_read_only = f"""
```python:{utils_rel}
class DataProcessor:
    \"\"\"A class for processing data.\"\"\"

    number = 1

    def __repr__(self) -> str:
        \"\"\"Return a string representation of the DataProcessor.\"\"\"
        return f"DataProcessor(default_prefix={{self.default_prefix!r}})"
```
"""
    wanted_hashing = f"""
```python:{utils_rel}
class DataProcessor:

    def process_data(self, raw_data: str) -> str:
        return raw_data.upper()

    def add_prefix(self, data: str, prefix: str='PREFIX_') -> str:
        return prefix + data
```
```python:{main_rel}
def fetch_and_process_data():
    response = requests.get(API_URL)
    response.raise_for_status()
    raw_data = response.text
    processor = DataProcessor()
    processed = processor.process_data(raw_data)
    processed = processor.add_prefix(processed)
    return processed
```
"""
    assert actual_read_write.flat.strip() == wanted_read_write.strip()
    assert actual_read_only.strip() == wanted_read_only.strip()
    assert actual_hashing.strip() == wanted_hashing.strip()
|
|
|
|
|
|
def test_repo_helper_of_helper() -> None:
    """Compare the contexts extracted for ``fetch_and_transform_data`` with fixed expectations.

    The target lives in ``main.py`` of the ``retriever`` fixture directory. The expected
    texts span ``utils.py``, ``main.py`` and ``transform_utils.py``; each extracted context
    must equal its expectation after stripping surrounding whitespace.
    """
    retriever_dir = Path(__file__).resolve().parent.parent / "code_to_optimize" / "code_directories" / "retriever"
    main_py = retriever_dir / "main.py"
    utils_py = retriever_dir / "utils.py"
    transform_utils_py = retriever_dir / "transform_utils.py"
    # Paths relative to the project root, as they appear in the code-block headers.
    main_rel = main_py.relative_to(retriever_dir)
    utils_rel = utils_py.relative_to(retriever_dir)
    transform_rel = transform_utils_py.relative_to(retriever_dir)

    target = FunctionToOptimize(
        function_name="fetch_and_transform_data",
        file_path=str(main_py),
        parents=[],
        starting_line=None,
        ending_line=None,
    )
    context = get_code_optimization_context(target, retriever_dir)
    actual_read_write = context.read_writable_code
    actual_read_only = context.read_only_context_code
    actual_hashing = context.hashing_code_context

    wanted_read_write = f"""
{get_code_block_splitter(utils_rel)}
import math
from transform_utils import DataTransformer

class DataProcessor:

    def __init__(self, default_prefix: str = "PREFIX_"):
        \"\"\"Initialize the DataProcessor with a default prefix.\"\"\"
        self.default_prefix = default_prefix
        self.number += math.log(self.number)

    def process_data(self, raw_data: str) -> str:
        \"\"\"Process raw data by converting it to uppercase.\"\"\"
        return raw_data.upper()

    def transform_data(self, data: str) -> str:
        \"\"\"Transform the processed data\"\"\"
        return DataTransformer().transform(data)

{get_code_block_splitter(main_rel)}
import requests
from globals import API_URL
from utils import DataProcessor

def fetch_and_transform_data():
    # Use the global variable for the request
    response = requests.get(API_URL)

    raw_data = response.text

    # Use code from another file (utils.py)
    processor = DataProcessor()
    processed = processor.process_data(raw_data)
    transformed = processor.transform_data(processed)

    return transformed

"""
    wanted_read_only = f"""
```python:{utils_rel}
class DataProcessor:
    \"\"\"A class for processing data.\"\"\"

    number = 1

    def __repr__(self) -> str:
        \"\"\"Return a string representation of the DataProcessor.\"\"\"
        return f"DataProcessor(default_prefix={{self.default_prefix!r}})"
```
```python:{transform_rel}
class DataTransformer:

    def transform(self, data):
        self.data = data
        return self.data
```
"""
    wanted_hashing = f"""
```python:{utils_rel}
class DataProcessor:

    def process_data(self, raw_data: str) -> str:
        return raw_data.upper()

    def transform_data(self, data: str) -> str:
        return DataTransformer().transform(data)
```
```python:{main_rel}
def fetch_and_transform_data():
    response = requests.get(API_URL)
    raw_data = response.text
    processor = DataProcessor()
    processed = processor.process_data(raw_data)
    transformed = processor.transform_data(processed)
    return transformed
```
"""
    assert actual_read_write.flat.strip() == wanted_read_write.strip()
    assert actual_read_only.strip() == wanted_read_only.strip()
    assert actual_hashing.strip() == wanted_hashing.strip()
|
|
|
|
|
|
def test_repo_helper_of_helper_same_class() -> None:
    """Compare the contexts extracted for ``DataProcessor.transform_data_own_method``.

    The target is a method of ``DataProcessor`` in ``utils.py`` of the ``retriever`` fixture
    directory. Each extracted context must equal its expected text after stripping
    surrounding whitespace.
    """
    project_root = Path(__file__).resolve().parent.parent / "code_to_optimize" / "code_directories" / "retriever"
    path_to_utils = project_root / "utils.py"
    path_to_transform_utils = project_root / "transform_utils.py"
    function_to_optimize = FunctionToOptimize(
        function_name="transform_data_own_method",
        file_path=str(path_to_utils),
        parents=[FunctionParent(name="DataProcessor", type="ClassDef")],
        starting_line=None,
        ending_line=None,
    )

    code_ctx = get_code_optimization_context(function_to_optimize, project_root)
    read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
    hashing_context = code_ctx.hashing_code_context
    expected_read_write_context = f"""
{get_code_block_splitter(path_to_transform_utils.relative_to(project_root))}
class DataTransformer:
    def __init__(self):
        self.data = None

    def transform_using_own_method(self, data):
        return self.transform(data)

{get_code_block_splitter(path_to_utils.relative_to(project_root))}
import math
from transform_utils import DataTransformer

class DataProcessor:

    def __init__(self, default_prefix: str = "PREFIX_"):
        \"\"\"Initialize the DataProcessor with a default prefix.\"\"\"
        self.default_prefix = default_prefix
        self.number += math.log(self.number)

    def transform_data_own_method(self, data: str) -> str:
        \"\"\"Transform the processed data using own method\"\"\"
        return DataTransformer().transform_using_own_method(data)

"""
    expected_read_only_context = f"""
```python:{path_to_transform_utils.relative_to(project_root)}
class DataTransformer:

    def transform(self, data):
        self.data = data
        return self.data
```
```python:{path_to_utils.relative_to(project_root)}
class DataProcessor:
    \"\"\"A class for processing data.\"\"\"

    number = 1

    def __repr__(self) -> str:
        \"\"\"Return a string representation of the DataProcessor.\"\"\"
        return f"DataProcessor(default_prefix={{self.default_prefix!r}})"
```

"""
    # Both fence headers are derived from the path variables, like the other expectations in
    # this test; the rendered text is the same as the previously hard-coded file name.
    expected_hashing_context = f"""
```python:{path_to_transform_utils.relative_to(project_root)}
class DataTransformer:

    def transform_using_own_method(self, data):
        return self.transform(data)
```
```python:{path_to_utils.relative_to(project_root)}
class DataProcessor:

    def transform_data_own_method(self, data: str) -> str:
        return DataTransformer().transform_using_own_method(data)
```
"""

    assert read_write_context.flat.strip() == expected_read_write_context.strip()
    assert read_only_context.strip() == expected_read_only_context.strip()
    assert hashing_context.strip() == expected_hashing_context.strip()
|
|
|
|
|
|
def test_repo_helper_of_helper_same_file() -> None:
    """Compare the contexts extracted for ``DataProcessor.transform_data_same_file_function``.

    The target is a method of ``DataProcessor`` in ``utils.py`` of the ``retriever`` fixture
    directory. Each extracted context must equal its expected text after stripping
    surrounding whitespace.
    """
    project_root = Path(__file__).resolve().parent.parent / "code_to_optimize" / "code_directories" / "retriever"
    path_to_utils = project_root / "utils.py"
    path_to_transform_utils = project_root / "transform_utils.py"
    function_to_optimize = FunctionToOptimize(
        function_name="transform_data_same_file_function",
        file_path=str(path_to_utils),
        parents=[FunctionParent(name="DataProcessor", type="ClassDef")],
        starting_line=None,
        ending_line=None,
    )

    code_ctx = get_code_optimization_context(function_to_optimize, project_root)
    read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
    hashing_context = code_ctx.hashing_code_context
    expected_read_write_context = f"""
{get_code_block_splitter(path_to_transform_utils.relative_to(project_root))}
class DataTransformer:
    def __init__(self):
        self.data = None

    def transform_using_same_file_function(self, data):
        return update_data(data)

{get_code_block_splitter(path_to_utils.relative_to(project_root))}
import math
from transform_utils import DataTransformer

class DataProcessor:

    def __init__(self, default_prefix: str = "PREFIX_"):
        \"\"\"Initialize the DataProcessor with a default prefix.\"\"\"
        self.default_prefix = default_prefix
        self.number += math.log(self.number)

    def transform_data_same_file_function(self, data: str) -> str:
        \"\"\"Transform the processed data using a function from the same file\"\"\"
        return DataTransformer().transform_using_same_file_function(data)
"""
    expected_read_only_context = f"""
```python:{path_to_transform_utils.relative_to(project_root)}
def update_data(data):
    return data + " updated"
```
```python:{path_to_utils.relative_to(project_root)}
class DataProcessor:
    \"\"\"A class for processing data.\"\"\"

    number = 1

    def __repr__(self) -> str:
        \"\"\"Return a string representation of the DataProcessor.\"\"\"
        return f"DataProcessor(default_prefix={{self.default_prefix!r}})"
```
"""
    # Both fence headers are derived from the path variables, like the other expectations in
    # this test; the rendered text is the same as the previously hard-coded file name.
    expected_hashing_context = f"""
```python:{path_to_transform_utils.relative_to(project_root)}
class DataTransformer:

    def transform_using_same_file_function(self, data):
        return update_data(data)
```
```python:{path_to_utils.relative_to(project_root)}
class DataProcessor:

    def transform_data_same_file_function(self, data: str) -> str:
        return DataTransformer().transform_using_same_file_function(data)
```
"""

    assert read_write_context.flat.strip() == expected_read_write_context.strip()
    assert read_only_context.strip() == expected_read_only_context.strip()
    assert hashing_context.strip() == expected_hashing_context.strip()
|
|
|
|
|
|
def test_repo_helper_all_same_file() -> None:
    """Compare the contexts extracted for ``DataTransformer.transform_data_all_same_file``.

    The target is a method of ``DataTransformer`` in ``transform_utils.py`` of the
    ``retriever`` fixture directory, and every expected block refers to that same file.
    Each extracted context must equal its expected text after stripping surrounding whitespace.
    """
    retriever_dir = Path(__file__).resolve().parent.parent / "code_to_optimize" / "code_directories" / "retriever"
    transform_utils_py = retriever_dir / "transform_utils.py"
    # Path relative to the project root, as it appears in the code-block headers.
    transform_rel = transform_utils_py.relative_to(retriever_dir)

    target = FunctionToOptimize(
        function_name="transform_data_all_same_file",
        file_path=str(transform_utils_py),
        parents=[FunctionParent(name="DataTransformer", type="ClassDef")],
        starting_line=None,
        ending_line=None,
    )
    context = get_code_optimization_context(target, retriever_dir)
    actual_read_write = context.read_writable_code
    actual_read_only = context.read_only_context_code
    actual_hashing = context.hashing_code_context

    wanted_read_write = f"""
{get_code_block_splitter(transform_rel)}
class DataTransformer:
    def __init__(self):
        self.data = None

    def transform_using_own_method(self, data):
        return self.transform(data)

    def transform_data_all_same_file(self, data):
        new_data = update_data(data)
        return self.transform_using_own_method(new_data)


def update_data(data):
    return data + " updated"
"""
    wanted_read_only = f"""
```python:{transform_rel}
class DataTransformer:

    def transform(self, data):
        self.data = data
        return self.data
```

"""
    wanted_hashing = f"""
```python:{transform_rel}
class DataTransformer:

    def transform_using_own_method(self, data):
        return self.transform(data)

    def transform_data_all_same_file(self, data):
        new_data = update_data(data)
        return self.transform_using_own_method(new_data)

def update_data(data):
    return data + ' updated'
```
"""

    assert actual_read_write.flat.strip() == wanted_read_write.strip()
    assert actual_read_only.strip() == wanted_read_only.strip()
    assert actual_hashing.strip() == wanted_hashing.strip()
|
|
|
|
|
|
def test_repo_helper_circular_dependency() -> None:
    """Compare the contexts extracted for ``DataTransformer.circular_dependency``.

    In the expected text, ``DataTransformer.circular_dependency`` calls
    ``DataProcessor.circular_dependency`` and vice versa. Each extracted context must equal
    its expected text after stripping surrounding whitespace.
    """
    project_root = Path(__file__).resolve().parent.parent / "code_to_optimize" / "code_directories" / "retriever"
    path_to_utils = project_root / "utils.py"
    path_to_transform_utils = project_root / "transform_utils.py"
    function_to_optimize = FunctionToOptimize(
        function_name="circular_dependency",
        file_path=str(path_to_transform_utils),
        parents=[FunctionParent(name="DataTransformer", type="ClassDef")],
        starting_line=None,
        ending_line=None,
    )

    code_ctx = get_code_optimization_context(function_to_optimize, project_root)
    read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
    hashing_context = code_ctx.hashing_code_context
    expected_read_write_context = f"""
{get_code_block_splitter(path_to_utils.relative_to(project_root))}
import math
from transform_utils import DataTransformer

class DataProcessor:

    def __init__(self, default_prefix: str = "PREFIX_"):
        \"\"\"Initialize the DataProcessor with a default prefix.\"\"\"
        self.default_prefix = default_prefix
        self.number += math.log(self.number)

    def circular_dependency(self, data: str) -> str:
        \"\"\"Test circular dependency\"\"\"
        return DataTransformer().circular_dependency(data)

{get_code_block_splitter(path_to_transform_utils.relative_to(project_root))}
from code_to_optimize.code_directories.retriever.utils import DataProcessor

class DataTransformer:
    def __init__(self):
        self.data = None

    def circular_dependency(self, data):
        return DataProcessor().circular_dependency(data)


"""
    expected_read_only_context = f"""
```python:{path_to_utils.relative_to(project_root)}
class DataProcessor:
    \"\"\"A class for processing data.\"\"\"

    number = 1

    def __repr__(self) -> str:
        \"\"\"Return a string representation of the DataProcessor.\"\"\"
        return f"DataProcessor(default_prefix={{self.default_prefix!r}})"
```

"""
    # Both fence headers are derived from the path variables, like the other expectations in
    # this test; the rendered text is the same as the previously hard-coded file name.
    expected_hashing_context = f"""
```python:{path_to_utils.relative_to(project_root)}
class DataProcessor:

    def circular_dependency(self, data: str) -> str:
        return DataTransformer().circular_dependency(data)
```
```python:{path_to_transform_utils.relative_to(project_root)}
class DataTransformer:

    def circular_dependency(self, data):
        return DataProcessor().circular_dependency(data)
```
"""

    assert read_write_context.flat.strip() == expected_read_write_context.strip()
    assert read_only_context.strip() == expected_read_only_context.strip()
    assert hashing_context.strip() == expected_hashing_context.strip()
|
|
|
|
|
|
def test_indirect_init_helper() -> None:
    """Compare the contexts extracted for ``MyClass.target_method`` in a generated module.

    In the generated source only ``__init__`` calls ``outside_method``; ``target_method``
    merely reads ``self.x`` and ``self.y``. Each extracted context must equal its expected
    text after stripping surrounding whitespace.
    """
    source = """
class MyClass:
    def __init__(self):
        self.x = 1
        self.y = outside_method()
    def target_method(self):
        return self.x + self.y

def outside_method():
    return 1
"""
    # Everything runs inside the ``with`` block because the temporary file is removed on close.
    with tempfile.NamedTemporaryFile(mode="w") as tmp:
        tmp.write(source)
        tmp.flush()
        file_path = Path(tmp.name).resolve()
        cli_args = Namespace(
            project_root=file_path.parent.resolve(),
            disable_telemetry=True,
            tests_root="tests",
            test_framework="pytest",
            pytest_cmd="pytest",
            experiment_id=None,
            test_project_root=Path().resolve(),
        )
        opt = Optimizer(cli_args)
        target = FunctionToOptimize(
            function_name="target_method",
            file_path=file_path,
            parents=[FunctionParent(name="MyClass", type="ClassDef")],
            starting_line=None,
            ending_line=None,
        )

        context = get_code_optimization_context(target, opt.args.project_root)
        actual_read_write = context.read_writable_code
        actual_read_only = context.read_only_context_code
        actual_hashing = context.hashing_code_context

        # Path relative to the project root, as it appears in the code-block headers.
        rel_path = file_path.relative_to(opt.args.project_root)
        wanted_read_write = f"""
{get_code_block_splitter(rel_path)}
class MyClass:
    def __init__(self):
        self.x = 1
        self.y = outside_method()
    def target_method(self):
        return self.x + self.y
"""
        wanted_read_only = f"""
```python:{rel_path}
def outside_method():
    return 1
```
"""
        wanted_hashing = f"""
```python:{rel_path}
class MyClass:

    def target_method(self):
        return self.x + self.y
```
"""
        assert actual_read_write.flat.strip() == wanted_read_write.strip()
        assert actual_read_only.strip() == wanted_read_only.strip()
        assert actual_hashing.strip() == wanted_hashing.strip()
|
|
|
|
|
|
def test_direct_module_import() -> None:
    """Compare the contexts extracted for ``function_to_optimize`` in ``import_test.py``.

    In the expected text the target calls ``fetch_and_transform_data`` through the fully
    qualified module name ``code_to_optimize.code_directories.retriever.main``. Each
    extracted context must equal its expected text after stripping surrounding whitespace.
    """
    retriever_dir = Path(__file__).resolve().parent.parent / "code_to_optimize" / "code_directories" / "retriever"
    main_py = retriever_dir / "main.py"
    import_test_py = retriever_dir / "import_test.py"
    # Paths relative to the project root, as they appear in the code-block headers.
    main_rel = main_py.relative_to(retriever_dir)
    import_test_rel = import_test_py.relative_to(retriever_dir)

    target = FunctionToOptimize(
        function_name="function_to_optimize",
        file_path=str(import_test_py),
        parents=[],
        starting_line=None,
        ending_line=None,
    )
    context = get_code_optimization_context(target, retriever_dir)
    actual_read_write = context.read_writable_code
    actual_read_only = context.read_only_context_code
    actual_hashing = context.hashing_code_context

    wanted_read_write = f"""
{get_code_block_splitter(main_rel)}
import requests
from globals import API_URL
from utils import DataProcessor

def fetch_and_transform_data():
    # Use the global variable for the request
    response = requests.get(API_URL)

    raw_data = response.text

    # Use code from another file (utils.py)
    processor = DataProcessor()
    processed = processor.process_data(raw_data)
    transformed = processor.transform_data(processed)

    return transformed

{get_code_block_splitter(import_test_rel)}
import code_to_optimize.code_directories.retriever.main

def function_to_optimize():
    return code_to_optimize.code_directories.retriever.main.fetch_and_transform_data()
"""
    # Plain (non-f) string: the braces in the ``__repr__`` line are literal text.
    wanted_read_only = """
```python:utils.py
from transform_utils import DataTransformer

class DataProcessor:
    \"\"\"A class for processing data.\"\"\"

    number = 1

    def __repr__(self) -> str:
        \"\"\"Return a string representation of the DataProcessor.\"\"\"
        return f"DataProcessor(default_prefix={self.default_prefix!r})"

    def process_data(self, raw_data: str) -> str:
        \"\"\"Process raw data by converting it to uppercase.\"\"\"
        return raw_data.upper()

    def transform_data(self, data: str) -> str:
        \"\"\"Transform the processed data\"\"\"
        return DataTransformer().transform(data)
```"""
    wanted_hashing = """
```python:main.py
def fetch_and_transform_data():
    response = requests.get(API_URL)
    raw_data = response.text
    processor = DataProcessor()
    processed = processor.process_data(raw_data)
    transformed = processor.transform_data(processed)
    return transformed
```
```python:import_test.py
def function_to_optimize():
    return code_to_optimize.code_directories.retriever.main.fetch_and_transform_data()
```
"""
    assert actual_read_write.flat.strip() == wanted_read_write.strip()
    assert actual_read_only.strip() == wanted_read_only.strip()
    assert actual_hashing.strip() == wanted_hashing.strip()
|
|
|
|
|
|
def test_module_import_optimization() -> None:
    """Compare the contexts extracted for ``Calculator.calculate`` in a generated package.

    A temporary package is written with ``main_module.py`` (which does ``import utility_module``
    and uses it in ``Calculator.__init__``) and ``utility_module.py``. Each extracted context
    must equal its expected text after stripping surrounding whitespace.
    """
    main_code = """
import utility_module

class Calculator:
    def __init__(self, precision="high", fallback_precision=None, mode="standard"):
        # This is where we use the imported module
        self.precision = utility_module.select_precision(precision, fallback_precision)
        self.mode = mode

        # Using variables from the utility module
        self.backend = utility_module.CALCULATION_BACKEND
        self.system = utility_module.SYSTEM_TYPE
        self.default_precision = utility_module.DEFAULT_PRECISION

    def add(self, a, b):
        return a + b

    def subtract(self, a, b):
        return a - b

    def calculate(self, operation, x, y):
        if operation == "add":
            return self.add(x, y)
        elif operation == "subtract":
            return self.subtract(x, y)
        else:
            return None
"""

    utility_module_code = """
import sys
import platform
import logging

DEFAULT_PRECISION = "medium"
DEFAULT_MODE = "standard"

# Try-except block with variable definitions
try:
    import numpy as np
    # Used variable in try block
    CALCULATION_BACKEND = "numpy"
    # Unused variable in try block
    VECTOR_DIMENSIONS = 3
except ImportError:
    # Used variable in except block
    CALCULATION_BACKEND = "python"
    # Unused variable in except block
    FALLBACK_WARNING = "NumPy not available, using slower Python implementation"

# Nested if-else with variable definitions
if sys.platform.startswith('win'):
    # Used variable in outer if
    SYSTEM_TYPE = "windows"
    if platform.architecture()[0] == '64bit':
        # Unused variable in nested if
        MEMORY_MODEL = "x64"
    else:
        # Unused variable in nested else
        MEMORY_MODEL = "x86"
elif sys.platform.startswith('linux'):
    # Used variable in outer elif
    SYSTEM_TYPE = "linux"
    # Unused variable in outer elif
    KERNEL_VERSION = platform.release()
else:
    # Used variable in outer else
    SYSTEM_TYPE = "other"
    # Unused variable in outer else
    UNKNOWN_SYSTEM_MSG = "Running on an unrecognized platform"

# Function that will be used in the main code
def select_precision(precision, fallback_precision):
    if precision is None:
        return fallback_precision or DEFAULT_PRECISION

    # Using the variables defined above
    if CALCULATION_BACKEND == "numpy":
        # Higher precision available with NumPy
        precision_options = ["low", "medium", "high", "ultra"]
    else:
        # Limited precision without NumPy
        precision_options = ["low", "medium", "high"]

    if isinstance(precision, str):
        if precision.lower() not in precision_options:
            if fallback_precision:
                return fallback_precision
            else:
                return DEFAULT_PRECISION
        return precision.lower()
    else:
        return DEFAULT_PRECISION

# Function that won't be used
def get_system_details():
    return {
        "system": SYSTEM_TYPE,
        "backend": CALCULATION_BACKEND,
        "default_precision": DEFAULT_PRECISION,
        "python_version": sys.version
    }
"""

    # Create a temporary directory for the test
    with tempfile.TemporaryDirectory() as temp_dir:
        # Set up the package structure
        package_dir = Path(temp_dir) / "package"
        package_dir.mkdir()

        # Create the (empty) __init__.py file
        (package_dir / "__init__.py").write_text("", encoding="utf-8")

        # Write the utility_module.py file
        (package_dir / "utility_module.py").write_text(utility_module_code, encoding="utf-8")

        # Write the main code file
        main_file_path = package_dir / "main_module.py"
        main_file_path.write_text(main_code, encoding="utf-8")

        # Set up the optimizer
        file_path = main_file_path.resolve()
        opt = Optimizer(
            Namespace(
                project_root=package_dir.resolve(),
                disable_telemetry=True,
                tests_root="tests",
                test_framework="pytest",
                pytest_cmd="pytest",
                experiment_id=None,
                test_project_root=Path().resolve(),
            )
        )

        # Define the function to optimize
        function_to_optimize = FunctionToOptimize(
            function_name="calculate",
            file_path=file_path,
            parents=[FunctionParent(name="Calculator", type="ClassDef")],
            starting_line=None,
            ending_line=None,
        )

        # Get the code optimization context
        code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
        read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
        hashing_context = code_ctx.hashing_code_context
        # The expected contexts.
        # The header uses the *resolved* file path: project_root is resolved, so an unresolved
        # path would make relative_to() raise ValueError when the temp dir is behind a symlink.
        expected_read_write_context = f"""
{get_code_block_splitter(file_path.relative_to(opt.args.project_root))}
import utility_module

class Calculator:
    def __init__(self, precision="high", fallback_precision=None, mode="standard"):
        # This is where we use the imported module
        self.precision = utility_module.select_precision(precision, fallback_precision)
        self.mode = mode

        # Using variables from the utility module
        self.backend = utility_module.CALCULATION_BACKEND
        self.system = utility_module.SYSTEM_TYPE
        self.default_precision = utility_module.DEFAULT_PRECISION

    def add(self, a, b):
        return a + b

    def subtract(self, a, b):
        return a - b

    def calculate(self, operation, x, y):
        if operation == "add":
            return self.add(x, y)
        elif operation == "subtract":
            return self.subtract(x, y)
        else:
            return None
"""
        expected_read_only_context = """
```python:utility_module.py
DEFAULT_PRECISION = "medium"

# Try-except block with variable definitions
try:
    # Used variable in try block
    CALCULATION_BACKEND = "numpy"
except ImportError:
    # Used variable in except block
    CALCULATION_BACKEND = "python"

# Function that will be used in the main code
def select_precision(precision, fallback_precision):
    if precision is None:
        return fallback_precision or DEFAULT_PRECISION

    # Using the variables defined above
    if CALCULATION_BACKEND == "numpy":
        # Higher precision available with NumPy
        precision_options = ["low", "medium", "high", "ultra"]
    else:
        # Limited precision without NumPy
        precision_options = ["low", "medium", "high"]

    if isinstance(precision, str):
        if precision.lower() not in precision_options:
            if fallback_precision:
                return fallback_precision
            else:
                return DEFAULT_PRECISION
        return precision.lower()
    else:
        return DEFAULT_PRECISION
```
"""
        expected_hashing_context = """
```python:main_module.py
class Calculator:

    def add(self, a, b):
        return a + b

    def subtract(self, a, b):
        return a - b

    def calculate(self, operation, x, y):
        if operation == 'add':
            return self.add(x, y)
        elif operation == 'subtract':
            return self.subtract(x, y)
        else:
            return None
```
"""
        # Verify the contexts match the expected values
        assert read_write_context.flat.strip() == expected_read_write_context.strip()
        assert read_only_context.strip() == expected_read_only_context.strip()
        assert hashing_context.strip() == expected_hashing_context.strip()
|
|
|
|
|
|
def test_module_import_init_fto() -> None:
|
|
main_code = """
|
|
import utility_module
|
|
|
|
class Calculator:
|
|
def __init__(self, precision="high", fallback_precision=None, mode="standard"):
|
|
# This is where we use the imported module
|
|
self.precision = utility_module.select_precision(precision, fallback_precision)
|
|
self.mode = mode
|
|
|
|
# Using variables from the utility module
|
|
self.backend = utility_module.CALCULATION_BACKEND
|
|
self.system = utility_module.SYSTEM_TYPE
|
|
self.default_precision = utility_module.DEFAULT_PRECISION
|
|
|
|
def add(self, a, b):
|
|
return a + b
|
|
|
|
def subtract(self, a, b):
|
|
return a - b
|
|
|
|
def calculate(self, operation, x, y):
|
|
if operation == "add":
|
|
return self.add(x, y)
|
|
elif operation == "subtract":
|
|
return self.subtract(x, y)
|
|
else:
|
|
return None
|
|
"""
|
|
|
|
utility_module_code = """
|
|
import sys
|
|
import platform
|
|
import logging
|
|
|
|
DEFAULT_PRECISION = "medium"
|
|
DEFAULT_MODE = "standard"
|
|
|
|
# Try-except block with variable definitions
|
|
try:
|
|
import numpy as np
|
|
# Used variable in try block
|
|
CALCULATION_BACKEND = "numpy"
|
|
# Unused variable in try block
|
|
VECTOR_DIMENSIONS = 3
|
|
except ImportError:
|
|
# Used variable in except block
|
|
CALCULATION_BACKEND = "python"
|
|
# Unused variable in except block
|
|
FALLBACK_WARNING = "NumPy not available, using slower Python implementation"
|
|
|
|
# Nested if-else with variable definitions
|
|
if sys.platform.startswith('win'):
|
|
# Used variable in outer if
|
|
SYSTEM_TYPE = "windows"
|
|
if platform.architecture()[0] == '64bit':
|
|
# Unused variable in nested if
|
|
MEMORY_MODEL = "x64"
|
|
else:
|
|
# Unused variable in nested else
|
|
MEMORY_MODEL = "x86"
|
|
elif sys.platform.startswith('linux'):
|
|
# Used variable in outer elif
|
|
SYSTEM_TYPE = "linux"
|
|
# Unused variable in outer elif
|
|
KERNEL_VERSION = platform.release()
|
|
else:
|
|
# Used variable in outer else
|
|
SYSTEM_TYPE = "other"
|
|
# Unused variable in outer else
|
|
UNKNOWN_SYSTEM_MSG = "Running on an unrecognized platform"
|
|
|
|
# Function that will be used in the main code
|
|
def select_precision(precision, fallback_precision):
|
|
if precision is None:
|
|
return fallback_precision or DEFAULT_PRECISION
|
|
|
|
# Using the variables defined above
|
|
if CALCULATION_BACKEND == "numpy":
|
|
# Higher precision available with NumPy
|
|
precision_options = ["low", "medium", "high", "ultra"]
|
|
else:
|
|
# Limited precision without NumPy
|
|
precision_options = ["low", "medium", "high"]
|
|
|
|
if isinstance(precision, str):
|
|
if precision.lower() not in precision_options:
|
|
if fallback_precision:
|
|
return fallback_precision
|
|
else:
|
|
return DEFAULT_PRECISION
|
|
return precision.lower()
|
|
else:
|
|
return DEFAULT_PRECISION
|
|
|
|
# Function that won't be used
|
|
def get_system_details():
|
|
return {
|
|
"system": SYSTEM_TYPE,
|
|
"backend": CALCULATION_BACKEND,
|
|
"default_precision": DEFAULT_PRECISION,
|
|
"python_version": sys.version
|
|
}
|
|
"""
|
|
|
|
# Create a temporary directory for the test
|
|
with tempfile.TemporaryDirectory() as temp_dir:
|
|
# Set up the package structure
|
|
package_dir = Path(temp_dir) / "package"
|
|
package_dir.mkdir()
|
|
|
|
# Create the __init__.py file
|
|
with open(package_dir / "__init__.py", "w") as init_file:
|
|
init_file.write("")
|
|
|
|
# Write the utility_module.py file
|
|
with open(package_dir / "utility_module.py", "w") as utility_file:
|
|
utility_file.write(utility_module_code)
|
|
utility_file.flush()
|
|
|
|
# Write the main code file
|
|
main_file_path = package_dir / "main_module.py"
|
|
with open(main_file_path, "w") as main_file:
|
|
main_file.write(main_code)
|
|
main_file.flush()
|
|
|
|
# Set up the optimizer
|
|
file_path = main_file_path.resolve()
|
|
opt = Optimizer(
|
|
Namespace(
|
|
project_root=package_dir.resolve(),
|
|
disable_telemetry=True,
|
|
tests_root="tests",
|
|
test_framework="pytest",
|
|
pytest_cmd="pytest",
|
|
experiment_id=None,
|
|
test_project_root=Path().resolve(),
|
|
)
|
|
)
|
|
|
|
# Define the function to optimize
|
|
function_to_optimize = FunctionToOptimize(
|
|
function_name="__init__",
|
|
file_path=file_path,
|
|
parents=[FunctionParent(name="Calculator", type="ClassDef")],
|
|
starting_line=None,
|
|
ending_line=None,
|
|
)
|
|
|
|
# Get the code optimization context
|
|
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
|
|
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
|
|
# The expected contexts
|
|
expected_read_write_context = f"""
|
|
{get_code_block_splitter("utility_module.py")}
|
|
# Function that will be used in the main code
|
|
|
|
def select_precision(precision, fallback_precision):
|
|
if precision is None:
|
|
return fallback_precision or DEFAULT_PRECISION
|
|
|
|
# Using the variables defined above
|
|
if CALCULATION_BACKEND == "numpy":
|
|
# Higher precision available with NumPy
|
|
precision_options = ["low", "medium", "high", "ultra"]
|
|
else:
|
|
# Limited precision without NumPy
|
|
precision_options = ["low", "medium", "high"]
|
|
|
|
if isinstance(precision, str):
|
|
if precision.lower() not in precision_options:
|
|
if fallback_precision:
|
|
return fallback_precision
|
|
else:
|
|
return DEFAULT_PRECISION
|
|
return precision.lower()
|
|
else:
|
|
return DEFAULT_PRECISION
|
|
|
|
{get_code_block_splitter(main_file_path.relative_to(opt.args.project_root))}
|
|
import utility_module
|
|
|
|
class Calculator:
|
|
def __init__(self, precision="high", fallback_precision=None, mode="standard"):
|
|
# This is where we use the imported module
|
|
self.precision = utility_module.select_precision(precision, fallback_precision)
|
|
self.mode = mode
|
|
|
|
# Using variables from the utility module
|
|
self.backend = utility_module.CALCULATION_BACKEND
|
|
self.system = utility_module.SYSTEM_TYPE
|
|
self.default_precision = utility_module.DEFAULT_PRECISION
|
|
"""
|
|
expected_read_only_context = """
|
|
```python:utility_module.py
|
|
DEFAULT_PRECISION = "medium"
|
|
|
|
# Try-except block with variable definitions
|
|
try:
|
|
# Used variable in try block
|
|
CALCULATION_BACKEND = "numpy"
|
|
except ImportError:
|
|
# Used variable in except block
|
|
CALCULATION_BACKEND = "python"
|
|
```
|
|
"""
|
|
assert read_write_context.flat.strip() == expected_read_write_context.strip()
|
|
assert read_only_context.strip() == expected_read_only_context.strip()
|
|
|
|
|
|
def test_hashing_code_context_removes_imports_docstrings_and_init() -> None:
    """Check that the hashing context drops imports, docstrings and ``__init__`` methods.

    A small fixture module is written to a temporary file and ``MyClass.target_method`` is
    selected as the function to optimize. The resulting ``hashing_code_context`` is then
    inspected for what it must and must not contain.
    """
    code = '''
import os
import sys
from pathlib import Path

class MyClass:
    """A class with a docstring."""
    def __init__(self, value):
        """Initialize with a value."""
        self.value = value

    def target_method(self):
        """Target method with docstring."""
        result = self.helper_method()
        helper_cls = HelperClass()
        data = helper_cls.process_data()
        return self.value * 2

    def helper_method(self):
        """Helper method with docstring."""
        return self.value + 1

class HelperClass:
    """Helper class docstring."""
    def __init__(self):
        """Helper init method."""
        self.data = "test"

    def process_data(self):
        """Process data method."""
        return self.data.upper()

def standalone_function():
    """Standalone function."""
    return "standalone"
'''
    # Everything runs inside the ``with`` block: the temporary file has to exist while the
    # context is being extracted.
    with tempfile.NamedTemporaryFile(mode="w") as tmp_file:
        tmp_file.write(code)
        tmp_file.flush()
        file_path = Path(tmp_file.name).resolve()
        optimizer_args = Namespace(
            project_root=file_path.parent.resolve(),
            disable_telemetry=True,
            tests_root="tests",
            test_framework="pytest",
            pytest_cmd="pytest",
            experiment_id=None,
            test_project_root=Path().resolve(),
        )
        optimizer = Optimizer(optimizer_args)
        target = FunctionToOptimize(
            function_name="target_method",
            file_path=file_path,
            parents=[FunctionParent(name="MyClass", type="ClassDef")],
            starting_line=None,
            ending_line=None,
        )

        context = get_code_optimization_context(target, optimizer.args.project_root)
        hashing_context = context.hashing_code_context

        # The context is expected to be a single markdown code fence.
        assert hashing_context.startswith("```python:")
        assert hashing_context.endswith("```")

        # Structural requirements: no imports, no __init__, target kept, unused code dropped.
        assert "import" not in hashing_context
        assert "__init__" not in hashing_context
        assert "target_method" in hashing_context
        assert "standalone_function" not in hashing_context

        # Helpers that the target actually calls must be part of the context.
        assert "helper_method" in hashing_context
        assert "process_data" in hashing_context

        # None of the fixture's docstrings may survive in the context.
        forbidden_docstrings = (
            ('"""Target method with docstring."""', "Docstrings should be removed from target functions"),
            ('"""Helper method with docstring."""', "Docstrings should be removed from helper functions"),
            ('"""Process data method."""', "Docstrings should be removed from helper class methods"),
        )
        for docstring, failure_message in forbidden_docstrings:
            assert docstring not in hashing_context, failure_message
|
|
|
|
|
|
def test_hashing_code_context_with_nested_classes() -> None:
    """Check that a class nested inside the target's class is left out of the hashing context.

    ``OuterClass.target_method`` calls into ``OuterClass.NestedClass``; the call expression must
    remain in the context while the nested class definition itself must not appear.
    """
    code = '''
class OuterClass:
    """Outer class docstring."""
    def __init__(self):
        """Outer init."""
        self.value = 1

    def target_method(self):
        """Target method."""
        return self.NestedClass().nested_method()

    class NestedClass:
        """Nested class - should be excluded."""
        def __init__(self):
            self.nested_value = 2

        def nested_method(self):
            return self.nested_value
'''
    # The temporary file has to stay open while the context is extracted from it.
    with tempfile.NamedTemporaryFile(mode="w") as tmp_file:
        tmp_file.write(code)
        tmp_file.flush()
        file_path = Path(tmp_file.name).resolve()
        optimizer_args = Namespace(
            project_root=file_path.parent.resolve(),
            disable_telemetry=True,
            tests_root="tests",
            test_framework="pytest",
            pytest_cmd="pytest",
            experiment_id=None,
            test_project_root=Path().resolve(),
        )
        optimizer = Optimizer(optimizer_args)
        target = FunctionToOptimize(
            function_name="target_method",
            file_path=file_path,
            parents=[FunctionParent(name="OuterClass", type="ClassDef")],
            starting_line=None,
            ending_line=None,
        )

        context = get_code_optimization_context(target, optimizer.args.project_root)
        hashing_context = context.hashing_code_context

        # Basic shape: one markdown fence containing the target, without __init__ methods.
        assert hashing_context.startswith("```python:")
        assert hashing_context.endswith("```")
        assert "target_method" in hashing_context
        assert "__init__" not in hashing_context

        # The nested class definition must be absent (per the original note,
        # prune_cst_for_code_hashing is expected not to recurse into nested classes).
        assert "class NestedClass:" not in hashing_context

        # The target still references the nested class through its call expression ...
        assert "self.NestedClass().nested_method()" in hashing_context, (
            "The target method should contain the call to nested class"
        )

        # ... but the nested method's own definition must not be included.
        assert "def nested_method(self):" not in hashing_context, (
            "Nested method definition should not be present in hashing context"
        )
|
|
|
|
|
|
def test_hashing_code_context_hash_consistency() -> None:
    """Check that extracting the context twice for the same code yields the same hash.

    Also verifies that ``hashing_code_context_hash`` equals the SHA-256 hex digest of the
    UTF-8 encoded ``hashing_code_context``.
    """
    import hashlib

    code = """
class TestClass:
    def target_method(self):
        return "test"
"""
    # The temporary file has to stay open while the context is extracted from it.
    with tempfile.NamedTemporaryFile(mode="w") as tmp_file:
        tmp_file.write(code)
        tmp_file.flush()
        file_path = Path(tmp_file.name).resolve()
        optimizer_args = Namespace(
            project_root=file_path.parent.resolve(),
            disable_telemetry=True,
            tests_root="tests",
            test_framework="pytest",
            pytest_cmd="pytest",
            experiment_id=None,
            test_project_root=Path().resolve(),
        )
        optimizer = Optimizer(optimizer_args)
        target = FunctionToOptimize(
            function_name="target_method",
            file_path=file_path,
            parents=[FunctionParent(name="TestClass", type="ClassDef")],
            starting_line=None,
            ending_line=None,
        )

        # Extract the context two times from the very same input.
        first_ctx = get_code_optimization_context(target, optimizer.args.project_root)
        second_ctx = get_code_optimization_context(target, optimizer.args.project_root)

        # Both the hash and the hashed text must be reproducible.
        assert first_ctx.hashing_code_context_hash == second_ctx.hashing_code_context_hash
        assert first_ctx.hashing_code_context == second_ctx.hashing_code_context

        # The stored hash must be the SHA-256 of the context text.
        recomputed_digest = hashlib.sha256(first_ctx.hashing_code_context.encode("utf-8")).hexdigest()
        assert first_ctx.hashing_code_context_hash == recomputed_digest
|
|
|
|
|
|
def test_hashing_code_context_different_code_different_hash() -> None:
    """Check that two sources differing only in a returned literal hash differently."""
    code1 = """
class TestClass:
    def target_method(self):
        return "test1"
"""
    code2 = """
class TestClass:
    def target_method(self):
        return "test2"
"""

    # Both temporary files have to stay open while their contexts are extracted.
    with tempfile.NamedTemporaryFile(mode="w") as first_file, tempfile.NamedTemporaryFile(mode="w") as second_file:
        first_file.write(code1)
        first_file.flush()
        second_file.write(code2)
        second_file.flush()

        file_path1 = Path(first_file.name).resolve()
        file_path2 = Path(second_file.name).resolve()

        first_args = Namespace(
            project_root=file_path1.parent.resolve(),
            disable_telemetry=True,
            tests_root="tests",
            test_framework="pytest",
            pytest_cmd="pytest",
            experiment_id=None,
            test_project_root=Path().resolve(),
        )
        first_optimizer = Optimizer(first_args)
        second_args = Namespace(
            project_root=file_path2.parent.resolve(),
            disable_telemetry=True,
            tests_root="tests",
            test_framework="pytest",
            pytest_cmd="pytest",
            experiment_id=None,
            test_project_root=Path().resolve(),
        )
        second_optimizer = Optimizer(second_args)

        first_target = FunctionToOptimize(
            function_name="target_method",
            file_path=file_path1,
            parents=[FunctionParent(name="TestClass", type="ClassDef")],
            starting_line=None,
            ending_line=None,
        )
        second_target = FunctionToOptimize(
            function_name="target_method",
            file_path=file_path2,
            parents=[FunctionParent(name="TestClass", type="ClassDef")],
            starting_line=None,
            ending_line=None,
        )

        first_ctx = get_code_optimization_context(first_target, first_optimizer.args.project_root)
        second_ctx = get_code_optimization_context(second_target, second_optimizer.args.project_root)

        # Neither the hash nor the hashed text may coincide for the two sources.
        assert first_ctx.hashing_code_context_hash != second_ctx.hashing_code_context_hash
        assert first_ctx.hashing_code_context != second_ctx.hashing_code_context
|
|
|
|
|
|
def test_hashing_code_context_format_is_markdown() -> None:
    """Check that the hashing context is a markdown code fence wrapping the source.

    The fence header must start with ``python:`` and the context must mention the file's path
    relative to the project root; the lines between the fences must contain the class and
    method under test.
    """
    code = """
class SimpleClass:
    def simple_method(self):
        return 42
"""
    # The temporary file has to stay open while the context is extracted from it.
    with tempfile.NamedTemporaryFile(mode="w") as tmp_file:
        tmp_file.write(code)
        tmp_file.flush()
        file_path = Path(tmp_file.name).resolve()
        optimizer_args = Namespace(
            project_root=file_path.parent.resolve(),
            disable_telemetry=True,
            tests_root="tests",
            test_framework="pytest",
            pytest_cmd="pytest",
            experiment_id=None,
            test_project_root=Path().resolve(),
        )
        optimizer = Optimizer(optimizer_args)
        target = FunctionToOptimize(
            function_name="simple_method",
            file_path=file_path,
            parents=[FunctionParent(name="SimpleClass", type="ClassDef")],
            starting_line=None,
            ending_line=None,
        )

        context = get_code_optimization_context(target, optimizer.args.project_root)
        hashing_context = context.hashing_code_context

        # The whole context is one fenced markdown block.
        assert hashing_context.startswith("```python:")
        assert hashing_context.endswith("```")

        # The file's project-relative path must show up in the context.
        assert str(file_path.relative_to(optimizer.args.project_root)) in hashing_context

        # First and last lines are the fence markers.
        fenced_lines = hashing_context.strip().split("\n")
        assert fenced_lines[0].startswith("```python:")
        assert fenced_lines[-1] == "```"

        # Everything in between is the retained source code.
        code_content = "\n".join(fenced_lines[1:-1])
        assert "class SimpleClass:" in code_content
        assert "def simple_method(self):" in code_content
        assert "return 42" in code_content
|
|
|
|
|
|
# This shouldn't happen as we are now using a scoped optimization context, but keep it just in case
|
|
def test_circular_deps() -> None:
    """Check that applying an optimization does not introduce a circular import.

    Reads ``api_client.py`` and ``optimized.py`` from the ``circular_deps`` fixture directory,
    replaces ``ApiClient.get_console_url`` with the optimized implementation, and verifies that
    the rewritten module does not import ``ApiClient`` while still containing
    ``import urllib.parse``.
    """
    path_to_root = Path(__file__).resolve().parent.parent / "code_to_optimize" / "code_directories" / "circular_deps"
    file_abs_path = path_to_root / "api_client.py"
    # Both values above are already ``Path`` objects, so they are used directly.
    optimized_code = (path_to_root / "optimized.py").read_text(encoding="utf-8")
    content = file_abs_path.read_text(encoding="utf-8")
    new_code = replace_functions_and_add_imports(
        source_code=add_global_assignments(optimized_code, content),
        function_names=["ApiClient.get_console_url"],
        optimized_code=optimized_code,
        module_abspath=file_abs_path,
        preexisting_objects={
            ("ApiClient", ()),
            ("get_console_url", (FunctionParent(name="ApiClient", type="ClassDef"),)),
        },
        project_root_path=path_to_root,
    )
    assert "import ApiClient" not in new_code, "Error: Circular dependency found"

    assert "import urllib.parse" in new_code, "Make sure imports for optimization global assignments exist"
|