mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
GlobalAssignmentCollector only handled cst.Assign but not cst.AnnAssign (annotated assignments like `X: int = 1`). When the LLM generated optimizations with annotated module-level variables, these weren't copied to the target file, causing NameError at runtime. - Add visit_AnnAssign to GlobalAssignmentCollector - Add leave_AnnAssign to GlobalAssignmentTransformer - Update type hints to include cst.AnnAssign - Add test for annotated assignment handling
3537 lines
119 KiB
Python
3537 lines
119 KiB
Python
from __future__ import annotations
|
|
|
|
import sys
|
|
import tempfile
|
|
from argparse import Namespace
|
|
from collections import defaultdict
|
|
from pathlib import Path
|
|
|
|
import pytest
|
|
|
|
from codeflash.code_utils.code_extractor import GlobalAssignmentCollector, add_global_assignments
|
|
from codeflash.code_utils.code_replacer import replace_functions_and_add_imports
|
|
from codeflash.context.code_context_extractor import (
|
|
collect_names_from_annotation,
|
|
extract_imports_for_class,
|
|
get_code_optimization_context,
|
|
get_imported_class_definitions,
|
|
)
|
|
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
|
from codeflash.models.models import CodeString, CodeStringsMarkdown, FunctionParent
|
|
from codeflash.optimization.optimizer import Optimizer
|
|
|
|
|
|
class HelperClass:
|
|
def __init__(self, name):
|
|
self.name = name
|
|
|
|
def innocent_bystander(self):
|
|
pass
|
|
|
|
def helper_method(self):
|
|
return self.name
|
|
|
|
class NestedClass:
|
|
def __init__(self, name):
|
|
self.name = name
|
|
|
|
def nested_method(self):
|
|
return self.name
|
|
|
|
|
|
def main_method():
|
|
return "hello"
|
|
|
|
|
|
class MainClass:
|
|
def __init__(self, name):
|
|
self.name = name
|
|
|
|
def main_method(self):
|
|
self.name = HelperClass.NestedClass("test").nested_method()
|
|
return HelperClass(self.name).helper_method()
|
|
|
|
|
|
class Graph:
|
|
def __init__(self, vertices):
|
|
self.graph = defaultdict(list)
|
|
self.V = vertices # No. of vertices
|
|
|
|
def addEdge(self, u, v):
|
|
self.graph[u].append(v)
|
|
|
|
def topologicalSortUtil(self, v, visited, stack):
|
|
visited[v] = True
|
|
|
|
for i in self.graph[v]:
|
|
if visited[i] == False:
|
|
self.topologicalSortUtil(i, visited, stack)
|
|
|
|
stack.insert(0, v)
|
|
|
|
def topologicalSort(self):
|
|
visited = [False] * self.V
|
|
stack = []
|
|
|
|
for i in range(self.V):
|
|
if visited[i] == False:
|
|
self.topologicalSortUtil(i, visited, stack)
|
|
|
|
# Print contents of stack
|
|
return stack
|
|
|
|
|
|
def test_code_replacement10() -> None:
|
|
file_path = Path(__file__).resolve()
|
|
|
|
func_top_optimize = FunctionToOptimize(
|
|
function_name="main_method", file_path=file_path, parents=[FunctionParent("MainClass", "ClassDef")]
|
|
)
|
|
|
|
code_ctx = get_code_optimization_context(function_to_optimize=func_top_optimize, project_root_path=file_path.parent)
|
|
qualified_names = {func.qualified_name for func in code_ctx.helper_functions}
|
|
# HelperClass.__init__ is now tracked because HelperClass(self.name) instantiates the class
|
|
assert qualified_names == {
|
|
"HelperClass.helper_method",
|
|
"HelperClass.__init__",
|
|
} # Nested method should not be in here
|
|
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
|
|
hashing_context = code_ctx.hashing_code_context
|
|
|
|
expected_read_write_context = f"""
|
|
```python:{file_path.relative_to(file_path.parent)}
|
|
from __future__ import annotations
|
|
|
|
class HelperClass:
|
|
def __init__(self, name):
|
|
self.name = name
|
|
|
|
def helper_method(self):
|
|
return self.name
|
|
|
|
|
|
class MainClass:
|
|
def __init__(self, name):
|
|
self.name = name
|
|
|
|
def main_method(self):
|
|
self.name = HelperClass.NestedClass("test").nested_method()
|
|
return HelperClass(self.name).helper_method()
|
|
```
|
|
"""
|
|
expected_read_only_context = """
|
|
"""
|
|
|
|
expected_hashing_context = f"""
|
|
```python:{file_path.relative_to(file_path.parent)}
|
|
class HelperClass:
|
|
|
|
def helper_method(self):
|
|
return self.name
|
|
|
|
class MainClass:
|
|
|
|
def main_method(self):
|
|
self.name = HelperClass.NestedClass('test').nested_method()
|
|
return HelperClass(self.name).helper_method()
|
|
```
|
|
"""
|
|
|
|
assert read_write_context.markdown.strip() == expected_read_write_context.strip()
|
|
assert read_only_context.strip() == expected_read_only_context.strip()
|
|
assert hashing_context.strip() == expected_hashing_context.strip()
|
|
|
|
|
|
def test_class_method_dependencies() -> None:
|
|
file_path = Path(__file__).resolve()
|
|
|
|
function_to_optimize = FunctionToOptimize(
|
|
function_name="topologicalSort",
|
|
file_path=str(file_path),
|
|
parents=[FunctionParent(name="Graph", type="ClassDef")],
|
|
starting_line=None,
|
|
ending_line=None,
|
|
)
|
|
|
|
code_ctx = get_code_optimization_context(function_to_optimize, file_path.parent.resolve())
|
|
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
|
|
hashing_context = code_ctx.hashing_code_context
|
|
|
|
expected_read_write_context = f"""
|
|
```python:{file_path.relative_to(file_path.parent)}
|
|
from __future__ import annotations
|
|
from collections import defaultdict
|
|
|
|
class Graph:
|
|
def __init__(self, vertices):
|
|
self.graph = defaultdict(list)
|
|
self.V = vertices # No. of vertices
|
|
|
|
def topologicalSortUtil(self, v, visited, stack):
|
|
visited[v] = True
|
|
|
|
for i in self.graph[v]:
|
|
if visited[i] == False:
|
|
self.topologicalSortUtil(i, visited, stack)
|
|
|
|
stack.insert(0, v)
|
|
|
|
def topologicalSort(self):
|
|
visited = [False] * self.V
|
|
stack = []
|
|
|
|
for i in range(self.V):
|
|
if visited[i] == False:
|
|
self.topologicalSortUtil(i, visited, stack)
|
|
|
|
# Print contents of stack
|
|
return stack
|
|
```
|
|
"""
|
|
expected_read_only_context = ""
|
|
|
|
expected_hashing_context = f"""
|
|
```python:{file_path.relative_to(file_path.parent.resolve())}
|
|
class Graph:
|
|
|
|
def topologicalSortUtil(self, v, visited, stack):
|
|
visited[v] = True
|
|
for i in self.graph[v]:
|
|
if visited[i] == False:
|
|
self.topologicalSortUtil(i, visited, stack)
|
|
stack.insert(0, v)
|
|
|
|
def topologicalSort(self):
|
|
visited = [False] * self.V
|
|
stack = []
|
|
for i in range(self.V):
|
|
if visited[i] == False:
|
|
self.topologicalSortUtil(i, visited, stack)
|
|
return stack
|
|
```
|
|
"""
|
|
|
|
assert read_write_context.markdown.strip() == expected_read_write_context.strip()
|
|
assert read_only_context.strip() == expected_read_only_context.strip()
|
|
assert hashing_context.strip() == expected_hashing_context.strip()
|
|
|
|
|
|
def test_bubble_sort_helper() -> None:
|
|
path_to_fto = (
|
|
Path(__file__).resolve().parent.parent
|
|
/ "code_to_optimize"
|
|
/ "code_directories"
|
|
/ "retriever"
|
|
/ "bubble_sort_imported.py"
|
|
)
|
|
|
|
function_to_optimize = FunctionToOptimize(
|
|
function_name="sort_from_another_file",
|
|
file_path=str(path_to_fto),
|
|
parents=[],
|
|
starting_line=None,
|
|
ending_line=None,
|
|
)
|
|
|
|
code_ctx = get_code_optimization_context(function_to_optimize, Path(__file__).resolve().parent.parent)
|
|
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
|
|
hashing_context = code_ctx.hashing_code_context
|
|
|
|
expected_read_write_context = """
|
|
```python:code_to_optimize/code_directories/retriever/bubble_sort_with_math.py
|
|
import math
|
|
|
|
def sorter(arr):
|
|
arr.sort()
|
|
x = math.sqrt(2)
|
|
print(x)
|
|
return arr
|
|
```
|
|
```python:code_to_optimize/code_directories/retriever/bubble_sort_imported.py
|
|
from bubble_sort_with_math import sorter
|
|
|
|
def sort_from_another_file(arr):
|
|
sorted_arr = sorter(arr)
|
|
return sorted_arr
|
|
```
|
|
"""
|
|
expected_read_only_context = ""
|
|
|
|
expected_hashing_context = """
|
|
```python:code_to_optimize/code_directories/retriever/bubble_sort_with_math.py
|
|
def sorter(arr):
|
|
arr.sort()
|
|
x = math.sqrt(2)
|
|
print(x)
|
|
return arr
|
|
```
|
|
```python:code_to_optimize/code_directories/retriever/bubble_sort_imported.py
|
|
def sort_from_another_file(arr):
|
|
sorted_arr = sorter(arr)
|
|
return sorted_arr
|
|
```
|
|
"""
|
|
assert read_write_context.markdown.strip() == expected_read_write_context.strip()
|
|
assert read_only_context.strip() == expected_read_only_context.strip()
|
|
assert hashing_context.strip() == expected_hashing_context.strip()
|
|
|
|
|
|
def test_flavio_typed_code_helper(tmp_path: Path) -> None:
|
|
code = '''
|
|
|
|
_P = ParamSpec("_P")
|
|
_KEY_T = TypeVar("_KEY_T")
|
|
_STORE_T = TypeVar("_STORE_T")
|
|
class AbstractCacheBackend(CacheBackend, Protocol[_KEY_T, _STORE_T]):
|
|
"""Interface for cache backends used by the persistent cache decorator."""
|
|
|
|
def __init__(self) -> None: ...
|
|
|
|
def hash_key(
|
|
self,
|
|
*,
|
|
func: Callable[_P, Any],
|
|
args: tuple[Any, ...],
|
|
kwargs: dict[str, Any],
|
|
) -> tuple[str, _KEY_T]: ...
|
|
|
|
def encode(self, *, data: Any) -> _STORE_T: # noqa: ANN401
|
|
...
|
|
|
|
def decode(self, *, data: _STORE_T) -> Any: # noqa: ANN401
|
|
...
|
|
|
|
def get(self, *, key: tuple[str, _KEY_T]) -> tuple[datetime.datetime, _STORE_T] | None: ...
|
|
|
|
def delete(self, *, key: tuple[str, _KEY_T]) -> None: ...
|
|
|
|
def put(self, *, key: tuple[str, _KEY_T], data: _STORE_T) -> None: ...
|
|
|
|
def get_cache_or_call(
|
|
self,
|
|
*,
|
|
func: Callable[_P, Any],
|
|
args: tuple[Any, ...],
|
|
kwargs: dict[str, Any],
|
|
lifespan: datetime.timedelta,
|
|
) -> Any: # noqa: ANN401
|
|
"""
|
|
Retrieve the cached results for a function call.
|
|
|
|
Args:
|
|
----
|
|
func (Callable[..., _R]): The function to retrieve cached results for.
|
|
args (tuple[Any, ...]): The positional arguments passed to the function.
|
|
kwargs (dict[str, Any]): The keyword arguments passed to the function.
|
|
lifespan (datetime.timedelta): The maximum age of the cached results.
|
|
|
|
Returns:
|
|
-------
|
|
_R: The cached results, if available.
|
|
|
|
"""
|
|
if os.environ.get("NO_CACHE"):
|
|
return func(*args, **kwargs)
|
|
|
|
try:
|
|
key = self.hash_key(func=func, args=args, kwargs=kwargs)
|
|
except: # noqa: E722
|
|
# If we can't create a cache key, we should just call the function.
|
|
logging.warning("Failed to hash cache key for function: %s", func)
|
|
return func(*args, **kwargs)
|
|
result_pair = self.get(key=key)
|
|
|
|
if result_pair is not None:
|
|
cached_time, result = result_pair
|
|
if not os.environ.get("RE_CACHE") and (
|
|
datetime.datetime.now() < (cached_time + lifespan) # noqa: DTZ005
|
|
):
|
|
try:
|
|
return self.decode(data=result)
|
|
except CacheBackendDecodeError as e:
|
|
logging.warning("Failed to decode cache data: %s", e)
|
|
# If decoding fails we will treat this as a cache miss.
|
|
# This might happens if underlying class definition of the data changes.
|
|
self.delete(key=key)
|
|
result = func(*args, **kwargs)
|
|
try:
|
|
self.put(key=key, data=self.encode(data=result))
|
|
except CacheBackendEncodeError as e:
|
|
logging.warning("Failed to encode cache data: %s", e)
|
|
# If encoding fails, we should still return the result.
|
|
return result
|
|
|
|
_P = ParamSpec("_P")
|
|
_R = TypeVar("_R")
|
|
_CacheBackendT = TypeVar("_CacheBackendT", bound=CacheBackend)
|
|
|
|
|
|
class _PersistentCache(Generic[_P, _R, _CacheBackendT]):
|
|
"""
|
|
A decorator class that provides persistent caching functionality for a function.
|
|
|
|
Args:
|
|
----
|
|
func (Callable[_P, _R]): The function to be decorated.
|
|
duration (datetime.timedelta): The duration for which the cached results should be considered valid.
|
|
backend (_backend): The backend storage for the cached results.
|
|
|
|
Attributes:
|
|
----------
|
|
__wrapped__ (Callable[_P, _R]): The wrapped function.
|
|
__duration__ (datetime.timedelta): The duration for which the cached results should be considered valid.
|
|
__backend__ (_backend): The backend storage for the cached results.
|
|
|
|
""" # noqa: E501
|
|
|
|
__wrapped__: Callable[_P, _R]
|
|
__duration__: datetime.timedelta
|
|
__backend__: _CacheBackendT
|
|
|
|
def __init__(
|
|
self,
|
|
func: Callable[_P, _R],
|
|
duration: datetime.timedelta,
|
|
) -> None:
|
|
self.__wrapped__ = func
|
|
self.__duration__ = duration
|
|
self.__backend__ = AbstractCacheBackend()
|
|
functools.update_wrapper(self, func)
|
|
|
|
def cache_clear(self) -> None:
|
|
"""Clears the cache for the wrapped function."""
|
|
self.__backend__.del_func_cache(func=self.__wrapped__)
|
|
|
|
def no_cache_call(self, *args: _P.args, **kwargs: _P.kwargs) -> _R:
|
|
"""
|
|
Calls the wrapped function without using the cache.
|
|
|
|
Args:
|
|
----
|
|
*args (_P.args): Positional arguments for the wrapped function.
|
|
**kwargs (_P.kwargs): Keyword arguments for the wrapped function.
|
|
|
|
Returns:
|
|
-------
|
|
_R: The result of the wrapped function.
|
|
|
|
"""
|
|
return self.__wrapped__(*args, **kwargs)
|
|
|
|
def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R:
|
|
"""
|
|
Calls the wrapped function, either using the cache or bypassing it based on environment variables.
|
|
|
|
Args:
|
|
----
|
|
*args (_P.args): Positional arguments for the wrapped function.
|
|
**kwargs (_P.kwargs): Keyword arguments for the wrapped function.
|
|
|
|
Returns:
|
|
-------
|
|
_R: The result of the wrapped function.
|
|
|
|
""" # noqa: E501
|
|
if "NO_CACHE" in os.environ:
|
|
return self.__wrapped__(*args, **kwargs)
|
|
os.makedirs(DEFAULT_CACHE_LOCATION, exist_ok=True)
|
|
return self.__backend__.get_cache_or_call(
|
|
func=self.__wrapped__,
|
|
args=args,
|
|
kwargs=kwargs,
|
|
lifespan=self.__duration__,
|
|
)
|
|
'''
|
|
# Create a temporary Python file using pytest's tmp_path fixture
|
|
file_path = tmp_path / "test_code.py"
|
|
file_path.write_text(code, encoding="utf-8")
|
|
opt = Optimizer(
|
|
Namespace(
|
|
project_root=file_path.parent.resolve(),
|
|
disable_telemetry=True,
|
|
tests_root="tests",
|
|
test_framework="pytest",
|
|
pytest_cmd="pytest",
|
|
experiment_id=None,
|
|
test_project_root=Path().resolve(),
|
|
)
|
|
)
|
|
function_to_optimize = FunctionToOptimize(
|
|
function_name="__call__",
|
|
file_path=file_path,
|
|
parents=[FunctionParent(name="_PersistentCache", type="ClassDef")],
|
|
starting_line=None,
|
|
ending_line=None,
|
|
)
|
|
|
|
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
|
|
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
|
|
hashing_context = code_ctx.hashing_code_context
|
|
expected_read_write_context = f"""
|
|
```python:{file_path.relative_to(opt.args.project_root)}
|
|
_P = ParamSpec("_P")
|
|
_KEY_T = TypeVar("_KEY_T")
|
|
_STORE_T = TypeVar("_STORE_T")
|
|
class AbstractCacheBackend(CacheBackend, Protocol[_KEY_T, _STORE_T]):
|
|
|
|
def __init__(self) -> None: ...
|
|
|
|
def get_cache_or_call(
|
|
self,
|
|
*,
|
|
func: Callable[_P, Any],
|
|
args: tuple[Any, ...],
|
|
kwargs: dict[str, Any],
|
|
lifespan: datetime.timedelta,
|
|
) -> Any: # noqa: ANN401
|
|
\"\"\"
|
|
Retrieve the cached results for a function call.
|
|
|
|
Args:
|
|
----
|
|
func (Callable[..., _R]): The function to retrieve cached results for.
|
|
args (tuple[Any, ...]): The positional arguments passed to the function.
|
|
kwargs (dict[str, Any]): The keyword arguments passed to the function.
|
|
lifespan (datetime.timedelta): The maximum age of the cached results.
|
|
|
|
Returns:
|
|
-------
|
|
_R: The cached results, if available.
|
|
|
|
\"\"\"
|
|
if os.environ.get("NO_CACHE"):
|
|
return func(*args, **kwargs)
|
|
|
|
try:
|
|
key = self.hash_key(func=func, args=args, kwargs=kwargs)
|
|
except: # noqa: E722
|
|
# If we can't create a cache key, we should just call the function.
|
|
logging.warning("Failed to hash cache key for function: %s", func)
|
|
return func(*args, **kwargs)
|
|
result_pair = self.get(key=key)
|
|
|
|
if result_pair is not None:
|
|
cached_time, result = result_pair
|
|
if not os.environ.get("RE_CACHE") and (
|
|
datetime.datetime.now() < (cached_time + lifespan) # noqa: DTZ005
|
|
):
|
|
try:
|
|
return self.decode(data=result)
|
|
except CacheBackendDecodeError as e:
|
|
logging.warning("Failed to decode cache data: %s", e)
|
|
# If decoding fails we will treat this as a cache miss.
|
|
# This might happens if underlying class definition of the data changes.
|
|
self.delete(key=key)
|
|
result = func(*args, **kwargs)
|
|
try:
|
|
self.put(key=key, data=self.encode(data=result))
|
|
except CacheBackendEncodeError as e:
|
|
logging.warning("Failed to encode cache data: %s", e)
|
|
# If encoding fails, we should still return the result.
|
|
return result
|
|
|
|
_P = ParamSpec("_P")
|
|
_R = TypeVar("_R")
|
|
_CacheBackendT = TypeVar("_CacheBackendT", bound=CacheBackend)
|
|
|
|
|
|
class _PersistentCache(Generic[_P, _R, _CacheBackendT]):
|
|
|
|
def __init__(
|
|
self,
|
|
func: Callable[_P, _R],
|
|
duration: datetime.timedelta,
|
|
) -> None:
|
|
self.__wrapped__ = func
|
|
self.__duration__ = duration
|
|
self.__backend__ = AbstractCacheBackend()
|
|
functools.update_wrapper(self, func)
|
|
|
|
def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R:
|
|
\"\"\"
|
|
Calls the wrapped function, either using the cache or bypassing it based on environment variables.
|
|
|
|
Args:
|
|
----
|
|
*args (_P.args): Positional arguments for the wrapped function.
|
|
**kwargs (_P.kwargs): Keyword arguments for the wrapped function.
|
|
|
|
Returns:
|
|
-------
|
|
_R: The result of the wrapped function.
|
|
|
|
\"\"\" # noqa: E501
|
|
if "NO_CACHE" in os.environ:
|
|
return self.__wrapped__(*args, **kwargs)
|
|
os.makedirs(DEFAULT_CACHE_LOCATION, exist_ok=True)
|
|
return self.__backend__.get_cache_or_call(
|
|
func=self.__wrapped__,
|
|
args=args,
|
|
kwargs=kwargs,
|
|
lifespan=self.__duration__,
|
|
)
|
|
```
|
|
"""
|
|
expected_read_only_context = f'''
|
|
```python:{file_path.relative_to(opt.args.project_root)}
|
|
_P = ParamSpec("_P")
|
|
_KEY_T = TypeVar("_KEY_T")
|
|
_STORE_T = TypeVar("_STORE_T")
|
|
class AbstractCacheBackend(CacheBackend, Protocol[_KEY_T, _STORE_T]):
|
|
"""Interface for cache backends used by the persistent cache decorator."""
|
|
|
|
def __init__(self) -> None: ...
|
|
|
|
def hash_key(
|
|
self,
|
|
*,
|
|
func: Callable[_P, Any],
|
|
args: tuple[Any, ...],
|
|
kwargs: dict[str, Any],
|
|
) -> tuple[str, _KEY_T]: ...
|
|
|
|
def encode(self, *, data: Any) -> _STORE_T: # noqa: ANN401
|
|
...
|
|
|
|
def decode(self, *, data: _STORE_T) -> Any: # noqa: ANN401
|
|
...
|
|
|
|
def get(self, *, key: tuple[str, _KEY_T]) -> tuple[datetime.datetime, _STORE_T] | None: ...
|
|
|
|
def delete(self, *, key: tuple[str, _KEY_T]) -> None: ...
|
|
|
|
def put(self, *, key: tuple[str, _KEY_T], data: _STORE_T) -> None: ...
|
|
|
|
_P = ParamSpec("_P")
|
|
_R = TypeVar("_R")
|
|
_CacheBackendT = TypeVar("_CacheBackendT", bound=CacheBackend)
|
|
|
|
|
|
class _PersistentCache(Generic[_P, _R, _CacheBackendT]):
|
|
"""
|
|
A decorator class that provides persistent caching functionality for a function.
|
|
|
|
Args:
|
|
----
|
|
func (Callable[_P, _R]): The function to be decorated.
|
|
duration (datetime.timedelta): The duration for which the cached results should be considered valid.
|
|
backend (_backend): The backend storage for the cached results.
|
|
|
|
Attributes:
|
|
----------
|
|
__wrapped__ (Callable[_P, _R]): The wrapped function.
|
|
__duration__ (datetime.timedelta): The duration for which the cached results should be considered valid.
|
|
__backend__ (_backend): The backend storage for the cached results.
|
|
|
|
""" # noqa: E501
|
|
|
|
__wrapped__: Callable[_P, _R]
|
|
__duration__: datetime.timedelta
|
|
__backend__: _CacheBackendT
|
|
```
|
|
'''
|
|
expected_hashing_context = f"""
|
|
```python:{file_path.relative_to(opt.args.project_root)}
|
|
class AbstractCacheBackend(CacheBackend, Protocol[_KEY_T, _STORE_T]):
|
|
|
|
def get_cache_or_call(self, *, func: Callable[_P, Any], args: tuple[Any, ...], kwargs: dict[str, Any], lifespan: datetime.timedelta) -> Any:
|
|
if os.environ.get('NO_CACHE'):
|
|
return func(*args, **kwargs)
|
|
try:
|
|
key = self.hash_key(func=func, args=args, kwargs=kwargs)
|
|
except:
|
|
logging.warning('Failed to hash cache key for function: %s', func)
|
|
return func(*args, **kwargs)
|
|
result_pair = self.get(key=key)
|
|
if result_pair is not None:
|
|
{"cached_time, result = result_pair" if sys.version_info >= (3, 11) else "(cached_time, result) = result_pair"}
|
|
if not os.environ.get('RE_CACHE') and datetime.datetime.now() < cached_time + lifespan:
|
|
try:
|
|
return self.decode(data=result)
|
|
except CacheBackendDecodeError as e:
|
|
logging.warning('Failed to decode cache data: %s', e)
|
|
self.delete(key=key)
|
|
result = func(*args, **kwargs)
|
|
try:
|
|
self.put(key=key, data=self.encode(data=result))
|
|
except CacheBackendEncodeError as e:
|
|
logging.warning('Failed to encode cache data: %s', e)
|
|
return result
|
|
|
|
class _PersistentCache(Generic[_P, _R, _CacheBackendT]):
|
|
|
|
def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R:
|
|
if 'NO_CACHE' in os.environ:
|
|
return self.__wrapped__(*args, **kwargs)
|
|
os.makedirs(DEFAULT_CACHE_LOCATION, exist_ok=True)
|
|
return self.__backend__.get_cache_or_call(func=self.__wrapped__, args=args, kwargs=kwargs, lifespan=self.__duration__)
|
|
```
|
|
"""
|
|
assert read_write_context.markdown.strip() == expected_read_write_context.strip()
|
|
assert read_only_context.strip() == expected_read_only_context.strip()
|
|
assert hashing_context.strip() == expected_hashing_context.strip()
|
|
|
|
|
|
def test_example_class(tmp_path: Path) -> None:
|
|
code = """
|
|
class MyClass:
|
|
\"\"\"A class with a helper method.\"\"\"
|
|
def __init__(self):
|
|
self.x = 1
|
|
def target_method(self):
|
|
y = HelperClass().helper_method()
|
|
|
|
class HelperClass:
|
|
\"\"\"A helper class for MyClass.\"\"\"
|
|
def __init__(self):
|
|
\"\"\"Initialize the HelperClass.\"\"\"
|
|
self.x = 1
|
|
def __repr__(self):
|
|
\"\"\"Return a string representation of the HelperClass.\"\"\"
|
|
return "HelperClass" + str(self.x)
|
|
def helper_method(self):
|
|
return self.x
|
|
"""
|
|
# Create a temporary Python file using pytest's tmp_path fixture
|
|
file_path = tmp_path / "test_code.py"
|
|
file_path.write_text(code, encoding="utf-8")
|
|
opt = Optimizer(
|
|
Namespace(
|
|
project_root=file_path.parent.resolve(),
|
|
disable_telemetry=True,
|
|
tests_root="tests",
|
|
test_framework="pytest",
|
|
pytest_cmd="pytest",
|
|
experiment_id=None,
|
|
test_project_root=Path().resolve(),
|
|
)
|
|
)
|
|
function_to_optimize = FunctionToOptimize(
|
|
function_name="target_method",
|
|
file_path=file_path,
|
|
parents=[FunctionParent(name="MyClass", type="ClassDef")],
|
|
starting_line=None,
|
|
ending_line=None,
|
|
)
|
|
|
|
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
|
|
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
|
|
hashing_context = code_ctx.hashing_code_context
|
|
|
|
expected_read_write_context = f"""
|
|
```python:{file_path.relative_to(opt.args.project_root)}
|
|
class MyClass:
|
|
def __init__(self):
|
|
self.x = 1
|
|
def target_method(self):
|
|
y = HelperClass().helper_method()
|
|
|
|
class HelperClass:
|
|
def __init__(self):
|
|
\"\"\"Initialize the HelperClass.\"\"\"
|
|
self.x = 1
|
|
def helper_method(self):
|
|
return self.x
|
|
```
|
|
"""
|
|
expected_read_only_context = f"""
|
|
```python:{file_path.relative_to(opt.args.project_root)}
|
|
class MyClass:
|
|
\"\"\"A class with a helper method.\"\"\"
|
|
|
|
class HelperClass:
|
|
\"\"\"A helper class for MyClass.\"\"\"
|
|
def __repr__(self):
|
|
\"\"\"Return a string representation of the HelperClass.\"\"\"
|
|
return "HelperClass" + str(self.x)
|
|
```
|
|
"""
|
|
expected_hashing_context = f"""
|
|
```python:{file_path.relative_to(opt.args.project_root)}
|
|
class MyClass:
|
|
|
|
def target_method(self):
|
|
y = HelperClass().helper_method()
|
|
|
|
class HelperClass:
|
|
|
|
def helper_method(self):
|
|
return self.x
|
|
```
|
|
"""
|
|
|
|
assert read_write_context.markdown.strip() == expected_read_write_context.strip()
|
|
assert read_only_context.strip() == expected_read_only_context.strip()
|
|
assert hashing_context.strip() == expected_hashing_context.strip()
|
|
|
|
|
|
def test_example_class_token_limit_1(tmp_path: Path) -> None:
|
|
docstring_filler = " ".join(
|
|
["This is a long docstring that will be used to fill up the token limit." for _ in range(1000)]
|
|
)
|
|
code = f"""
|
|
class MyClass:
|
|
\"\"\"A class with a helper method.
|
|
{docstring_filler}\"\"\"
|
|
def __init__(self):
|
|
self.x = 1
|
|
def target_method(self):
|
|
\"\"\"Docstring for target method\"\"\"
|
|
y = HelperClass().helper_method()
|
|
|
|
class HelperClass:
|
|
\"\"\"A helper class for MyClass.\"\"\"
|
|
def __init__(self):
|
|
\"\"\"Initialize the HelperClass.\"\"\"
|
|
self.x = 1
|
|
def __repr__(self):
|
|
\"\"\"Return a string representation of the HelperClass.\"\"\"
|
|
return "HelperClass" + str(self.x)
|
|
def helper_method(self):
|
|
return self.x
|
|
"""
|
|
# Create a temporary Python file using pytest's tmp_path fixture
|
|
file_path = tmp_path / "test_code.py"
|
|
file_path.write_text(code, encoding="utf-8")
|
|
opt = Optimizer(
|
|
Namespace(
|
|
project_root=file_path.parent.resolve(),
|
|
disable_telemetry=True,
|
|
tests_root="tests",
|
|
test_framework="pytest",
|
|
pytest_cmd="pytest",
|
|
experiment_id=None,
|
|
test_project_root=Path().resolve(),
|
|
)
|
|
)
|
|
function_to_optimize = FunctionToOptimize(
|
|
function_name="target_method",
|
|
file_path=file_path,
|
|
parents=[FunctionParent(name="MyClass", type="ClassDef")],
|
|
starting_line=None,
|
|
ending_line=None,
|
|
)
|
|
|
|
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
|
|
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
|
|
hashing_context = code_ctx.hashing_code_context
|
|
# In this scenario, the read-only code context is too long, so the read-only docstrings are removed.
|
|
expected_read_write_context = f"""
|
|
```python:{file_path.relative_to(opt.args.project_root)}
|
|
class MyClass:
|
|
def __init__(self):
|
|
self.x = 1
|
|
def target_method(self):
|
|
\"\"\"Docstring for target method\"\"\"
|
|
y = HelperClass().helper_method()
|
|
|
|
class HelperClass:
|
|
def __init__(self):
|
|
\"\"\"Initialize the HelperClass.\"\"\"
|
|
self.x = 1
|
|
def helper_method(self):
|
|
return self.x
|
|
```
|
|
"""
|
|
expected_read_only_context = f"""
|
|
```python:{file_path.relative_to(opt.args.project_root)}
|
|
class MyClass:
|
|
pass
|
|
|
|
class HelperClass:
|
|
def __repr__(self):
|
|
return "HelperClass" + str(self.x)
|
|
```
|
|
"""
|
|
expected_hashing_context = f"""
|
|
```python:{file_path.relative_to(opt.args.project_root)}
|
|
class MyClass:
|
|
|
|
def target_method(self):
|
|
y = HelperClass().helper_method()
|
|
|
|
class HelperClass:
|
|
|
|
def helper_method(self):
|
|
return self.x
|
|
```
|
|
"""
|
|
assert read_write_context.markdown.strip() == expected_read_write_context.strip()
|
|
assert read_only_context.strip() == expected_read_only_context.strip()
|
|
assert hashing_context.strip() == expected_hashing_context.strip()
|
|
|
|
|
|
def test_example_class_token_limit_2(tmp_path: Path) -> None:
|
|
string_filler = " ".join(
|
|
["This is a long string that will be used to fill up the token limit." for _ in range(1000)]
|
|
)
|
|
code = f"""
|
|
class MyClass:
|
|
\"\"\"A class with a helper method. \"\"\"
|
|
def __init__(self):
|
|
self.x = 1
|
|
def target_method(self):
|
|
\"\"\"Docstring for target method\"\"\"
|
|
y = HelperClass().helper_method()
|
|
x = '{string_filler}'
|
|
|
|
class HelperClass:
|
|
\"\"\"A helper class for MyClass.\"\"\"
|
|
def __init__(self):
|
|
\"\"\"Initialize the HelperClass.\"\"\"
|
|
self.x = 1
|
|
def __repr__(self):
|
|
\"\"\"Return a string representation of the HelperClass.\"\"\"
|
|
return "HelperClass" + str(self.x)
|
|
def helper_method(self):
|
|
return self.x
|
|
"""
|
|
# Create a temporary Python file using pytest's tmp_path fixture
|
|
file_path = tmp_path / "test_code.py"
|
|
file_path.write_text(code, encoding="utf-8")
|
|
opt = Optimizer(
|
|
Namespace(
|
|
project_root=file_path.parent.resolve(),
|
|
disable_telemetry=True,
|
|
tests_root="tests",
|
|
test_framework="pytest",
|
|
pytest_cmd="pytest",
|
|
experiment_id=None,
|
|
test_project_root=Path().resolve(),
|
|
)
|
|
)
|
|
function_to_optimize = FunctionToOptimize(
|
|
function_name="target_method",
|
|
file_path=file_path,
|
|
parents=[FunctionParent(name="MyClass", type="ClassDef")],
|
|
starting_line=None,
|
|
ending_line=None,
|
|
)
|
|
|
|
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root, 8000, 100000)
|
|
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
|
|
hashing_context = code_ctx.hashing_code_context
|
|
# In this scenario, the read-only code context is too long even after removing docstrings, hence we remove it completely.
|
|
expected_read_write_context = f"""
|
|
```python:{file_path.relative_to(opt.args.project_root)}
|
|
class MyClass:
|
|
def __init__(self):
|
|
self.x = 1
|
|
def target_method(self):
|
|
\"\"\"Docstring for target method\"\"\"
|
|
y = HelperClass().helper_method()
|
|
|
|
class HelperClass:
|
|
def __init__(self):
|
|
\"\"\"Initialize the HelperClass.\"\"\"
|
|
self.x = 1
|
|
def helper_method(self):
|
|
return self.x
|
|
```
|
|
"""
|
|
expected_read_only_context = f'''```python:{file_path.relative_to(opt.args.project_root)}
|
|
class MyClass:
|
|
"""A class with a helper method. """
|
|
|
|
class HelperClass:
|
|
"""A helper class for MyClass."""
|
|
def __repr__(self):
|
|
"""Return a string representation of the HelperClass."""
|
|
return "HelperClass" + str(self.x)
|
|
```
|
|
'''
|
|
expected_hashing_context = f"""
|
|
```python:{file_path.relative_to(opt.args.project_root)}
|
|
class MyClass:
|
|
|
|
def target_method(self):
|
|
y = HelperClass().helper_method()
|
|
|
|
class HelperClass:
|
|
|
|
def helper_method(self):
|
|
return self.x
|
|
```
|
|
"""
|
|
assert read_write_context.markdown.strip() == expected_read_write_context.strip()
|
|
assert read_only_context.strip() == expected_read_only_context.strip()
|
|
assert hashing_context.strip() == expected_hashing_context.strip()
|
|
|
|
|
|
def test_example_class_token_limit_3(tmp_path: Path) -> None:
|
|
string_filler = " ".join(
|
|
["This is a long string that will be used to fill up the token limit." for _ in range(1000)]
|
|
)
|
|
code = f"""
|
|
class MyClass:
|
|
\"\"\"A class with a helper method. \"\"\"
|
|
def __init__(self):
|
|
self.x = 1
|
|
def target_method(self):
|
|
\"\"\"{string_filler}\"\"\"
|
|
y = HelperClass().helper_method()
|
|
|
|
class HelperClass:
|
|
\"\"\"A helper class for MyClass.\"\"\"
|
|
def __init__(self):
|
|
\"\"\"Initialize the HelperClass.\"\"\"
|
|
self.x = 1
|
|
def __repr__(self):
|
|
\"\"\"Return a string representation of the HelperClass.\"\"\"
|
|
return "HelperClass" + str(self.x)
|
|
def helper_method(self):
|
|
return self.x
|
|
"""
|
|
# Create a temporary Python file using pytest's tmp_path fixture
|
|
file_path = tmp_path / "test_code.py"
|
|
file_path.write_text(code, encoding="utf-8")
|
|
opt = Optimizer(
|
|
Namespace(
|
|
project_root=file_path.parent.resolve(),
|
|
disable_telemetry=True,
|
|
tests_root="tests",
|
|
test_framework="pytest",
|
|
pytest_cmd="pytest",
|
|
experiment_id=None,
|
|
test_project_root=Path().resolve(),
|
|
)
|
|
)
|
|
function_to_optimize = FunctionToOptimize(
|
|
function_name="target_method",
|
|
file_path=file_path,
|
|
parents=[FunctionParent(name="MyClass", type="ClassDef")],
|
|
starting_line=None,
|
|
ending_line=None,
|
|
)
|
|
# In this scenario, the read-writable code is too long, so we abort.
|
|
with pytest.raises(ValueError, match="Read-writable code has exceeded token limit, cannot proceed"):
|
|
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
|
|
|
|
|
|
def test_example_class_token_limit_4(tmp_path: Path) -> None:
|
|
string_filler = " ".join(
|
|
["This is a long string that will be used to fill up the token limit." for _ in range(1000)]
|
|
)
|
|
code = f"""
|
|
class MyClass:
|
|
\"\"\"A class with a helper method. \"\"\"
|
|
def __init__(self):
|
|
global x
|
|
x = 1
|
|
def target_method(self):
|
|
\"\"\"Docstring for target method\"\"\"
|
|
y = HelperClass().helper_method()
|
|
x = '{string_filler}'
|
|
|
|
class HelperClass:
|
|
\"\"\"A helper class for MyClass.\"\"\"
|
|
def __init__(self):
|
|
\"\"\"Initialize the HelperClass.\"\"\"
|
|
self.x = 1
|
|
def __repr__(self):
|
|
\"\"\"Return a string representation of the HelperClass.\"\"\"
|
|
return "HelperClass" + str(self.x)
|
|
def helper_method(self):
|
|
return self.x
|
|
"""
|
|
# Create a temporary Python file using pytest's tmp_path fixture
|
|
file_path = tmp_path / "test_code.py"
|
|
file_path.write_text(code, encoding="utf-8")
|
|
opt = Optimizer(
|
|
Namespace(
|
|
project_root=file_path.parent.resolve(),
|
|
disable_telemetry=True,
|
|
tests_root="tests",
|
|
test_framework="pytest",
|
|
pytest_cmd="pytest",
|
|
experiment_id=None,
|
|
test_project_root=Path().resolve(),
|
|
)
|
|
)
|
|
function_to_optimize = FunctionToOptimize(
|
|
function_name="target_method",
|
|
file_path=file_path,
|
|
parents=[FunctionParent(name="MyClass", type="ClassDef")],
|
|
starting_line=None,
|
|
ending_line=None,
|
|
)
|
|
|
|
# In this scenario, the read-writable code context becomes too large because the __init__ function is referencing the global x variable instead of the class attribute self.x, so we abort.
|
|
with pytest.raises(ValueError, match="Read-writable code has exceeded token limit, cannot proceed"):
|
|
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
|
|
|
|
|
|
def test_example_class_token_limit_5(tmp_path: Path) -> None:
|
|
string_filler = " ".join(
|
|
["This is a long string that will be used to fill up the token limit." for _ in range(1000)]
|
|
)
|
|
code = f"""
|
|
class MyClass:
|
|
\"\"\"A class with a helper method. \"\"\"
|
|
def __init__(self):
|
|
self.x = 1
|
|
def target_method(self):
|
|
\"\"\"Docstring for target method\"\"\"
|
|
y = HelperClass().helper_method()
|
|
x = '{string_filler}'
|
|
|
|
class HelperClass:
|
|
\"\"\"A helper class for MyClass.\"\"\"
|
|
def __init__(self):
|
|
\"\"\"Initialize the HelperClass.\"\"\"
|
|
self.x = 1
|
|
def __repr__(self):
|
|
\"\"\"Return a string representation of the HelperClass.\"\"\"
|
|
return "HelperClass" + str(self.x)
|
|
def helper_method(self):
|
|
return self.x
|
|
"""
|
|
# Create a temporary Python file using pytest's tmp_path fixture
|
|
file_path = tmp_path / "test_code.py"
|
|
file_path.write_text(code, encoding="utf-8")
|
|
opt = Optimizer(
|
|
Namespace(
|
|
project_root=file_path.parent.resolve(),
|
|
disable_telemetry=True,
|
|
tests_root="tests",
|
|
test_framework="pytest",
|
|
pytest_cmd="pytest",
|
|
experiment_id=None,
|
|
test_project_root=Path().resolve(),
|
|
)
|
|
)
|
|
function_to_optimize = FunctionToOptimize(
|
|
function_name="target_method",
|
|
file_path=file_path,
|
|
parents=[FunctionParent(name="MyClass", type="ClassDef")],
|
|
starting_line=None,
|
|
ending_line=None,
|
|
)
|
|
|
|
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
|
|
|
|
# the global x variable shouldn't be included in any context type
|
|
assert (
|
|
code_ctx.read_writable_code.flat
|
|
== '''# file: test_code.py
|
|
class MyClass:
|
|
def __init__(self):
|
|
self.x = 1
|
|
def target_method(self):
|
|
"""Docstring for target method"""
|
|
y = HelperClass().helper_method()
|
|
|
|
class HelperClass:
|
|
def __init__(self):
|
|
"""Initialize the HelperClass."""
|
|
self.x = 1
|
|
def helper_method(self):
|
|
return self.x
|
|
'''
|
|
)
|
|
assert (
|
|
code_ctx.testgen_context.flat
|
|
== '''# file: test_code.py
|
|
class MyClass:
|
|
"""A class with a helper method. """
|
|
def __init__(self):
|
|
self.x = 1
|
|
def target_method(self):
|
|
"""Docstring for target method"""
|
|
y = HelperClass().helper_method()
|
|
|
|
class HelperClass:
|
|
"""A helper class for MyClass."""
|
|
def __init__(self):
|
|
"""Initialize the HelperClass."""
|
|
self.x = 1
|
|
def __repr__(self):
|
|
"""Return a string representation of the HelperClass."""
|
|
return "HelperClass" + str(self.x)
|
|
def helper_method(self):
|
|
return self.x
|
|
'''
|
|
)
|
|
|
|
|
|
def test_repo_helper() -> None:
|
|
project_root = Path(__file__).resolve().parent.parent / "code_to_optimize" / "code_directories" / "retriever"
|
|
path_to_file = project_root / "main.py"
|
|
path_to_utils = project_root / "utils.py"
|
|
function_to_optimize = FunctionToOptimize(
|
|
function_name="fetch_and_process_data",
|
|
file_path=str(path_to_file),
|
|
parents=[],
|
|
starting_line=None,
|
|
ending_line=None,
|
|
)
|
|
|
|
code_ctx = get_code_optimization_context(function_to_optimize, project_root)
|
|
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
|
|
hashing_context = code_ctx.hashing_code_context
|
|
expected_read_write_context = f"""
|
|
```python:{path_to_utils.relative_to(project_root)}
|
|
import math
|
|
|
|
class DataProcessor:
|
|
|
|
def __init__(self, default_prefix: str = "PREFIX_"):
|
|
\"\"\"Initialize the DataProcessor with a default prefix.\"\"\"
|
|
self.default_prefix = default_prefix
|
|
self.number += math.log(self.number)
|
|
|
|
def process_data(self, raw_data: str) -> str:
|
|
\"\"\"Process raw data by converting it to uppercase.\"\"\"
|
|
return raw_data.upper()
|
|
|
|
def add_prefix(self, data: str, prefix: str = "PREFIX_") -> str:
|
|
\"\"\"Add a prefix to the processed data.\"\"\"
|
|
return prefix + data
|
|
```
|
|
```python:{path_to_file.relative_to(project_root)}
|
|
import requests
|
|
from globals import API_URL
|
|
from utils import DataProcessor
|
|
|
|
def fetch_and_process_data():
|
|
# Use the global variable for the request
|
|
response = requests.get(API_URL)
|
|
response.raise_for_status()
|
|
|
|
raw_data = response.text
|
|
|
|
# Use code from another file (utils.py)
|
|
processor = DataProcessor()
|
|
processed = processor.process_data(raw_data)
|
|
processed = processor.add_prefix(processed)
|
|
|
|
return processed
|
|
```
|
|
"""
|
|
expected_read_only_context = f"""
|
|
```python:{path_to_utils.relative_to(project_root)}
|
|
class DataProcessor:
|
|
\"\"\"A class for processing data.\"\"\"
|
|
|
|
number = 1
|
|
|
|
def __repr__(self) -> str:
|
|
\"\"\"Return a string representation of the DataProcessor.\"\"\"
|
|
return f"DataProcessor(default_prefix={{self.default_prefix!r}})"
|
|
```
|
|
"""
|
|
expected_hashing_context = f"""
|
|
```python:{path_to_utils.relative_to(project_root)}
|
|
class DataProcessor:
|
|
|
|
def process_data(self, raw_data: str) -> str:
|
|
return raw_data.upper()
|
|
|
|
def add_prefix(self, data: str, prefix: str='PREFIX_') -> str:
|
|
return prefix + data
|
|
```
|
|
```python:{path_to_file.relative_to(project_root)}
|
|
def fetch_and_process_data():
|
|
response = requests.get(API_URL)
|
|
response.raise_for_status()
|
|
raw_data = response.text
|
|
processor = DataProcessor()
|
|
processed = processor.process_data(raw_data)
|
|
processed = processor.add_prefix(processed)
|
|
return processed
|
|
```
|
|
"""
|
|
assert read_write_context.markdown.strip() == expected_read_write_context.strip()
|
|
assert read_only_context.strip() == expected_read_only_context.strip()
|
|
assert hashing_context.strip() == expected_hashing_context.strip()
|
|
|
|
|
|
def test_repo_helper_of_helper() -> None:
|
|
project_root = Path(__file__).resolve().parent.parent / "code_to_optimize" / "code_directories" / "retriever"
|
|
path_to_file = project_root / "main.py"
|
|
path_to_utils = project_root / "utils.py"
|
|
path_to_transform_utils = project_root / "transform_utils.py"
|
|
function_to_optimize = FunctionToOptimize(
|
|
function_name="fetch_and_transform_data",
|
|
file_path=str(path_to_file),
|
|
parents=[],
|
|
starting_line=None,
|
|
ending_line=None,
|
|
)
|
|
|
|
code_ctx = get_code_optimization_context(function_to_optimize, project_root)
|
|
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
|
|
hashing_context = code_ctx.hashing_code_context
|
|
expected_read_write_context = f"""
|
|
```python:{path_to_utils.relative_to(project_root)}
|
|
import math
|
|
from transform_utils import DataTransformer
|
|
|
|
class DataProcessor:
|
|
|
|
def __init__(self, default_prefix: str = "PREFIX_"):
|
|
\"\"\"Initialize the DataProcessor with a default prefix.\"\"\"
|
|
self.default_prefix = default_prefix
|
|
self.number += math.log(self.number)
|
|
|
|
def process_data(self, raw_data: str) -> str:
|
|
\"\"\"Process raw data by converting it to uppercase.\"\"\"
|
|
return raw_data.upper()
|
|
|
|
def transform_data(self, data: str) -> str:
|
|
\"\"\"Transform the processed data\"\"\"
|
|
return DataTransformer().transform(data)
|
|
```
|
|
```python:{path_to_file.relative_to(project_root)}
|
|
import requests
|
|
from globals import API_URL
|
|
from utils import DataProcessor
|
|
|
|
def fetch_and_transform_data():
|
|
# Use the global variable for the request
|
|
response = requests.get(API_URL)
|
|
|
|
raw_data = response.text
|
|
|
|
# Use code from another file (utils.py)
|
|
processor = DataProcessor()
|
|
processed = processor.process_data(raw_data)
|
|
transformed = processor.transform_data(processed)
|
|
|
|
return transformed
|
|
```
|
|
"""
|
|
expected_read_only_context = f"""
|
|
```python:{path_to_utils.relative_to(project_root)}
|
|
class DataProcessor:
|
|
\"\"\"A class for processing data.\"\"\"
|
|
|
|
number = 1
|
|
|
|
def __repr__(self) -> str:
|
|
\"\"\"Return a string representation of the DataProcessor.\"\"\"
|
|
return f"DataProcessor(default_prefix={{self.default_prefix!r}})"
|
|
```
|
|
```python:{path_to_transform_utils.relative_to(project_root)}
|
|
class DataTransformer:
|
|
def __init__(self):
|
|
self.data = None
|
|
|
|
def transform(self, data):
|
|
self.data = data
|
|
return self.data
|
|
```
|
|
"""
|
|
expected_hashing_context = f"""
|
|
```python:{path_to_utils.relative_to(project_root)}
|
|
class DataProcessor:
|
|
|
|
def process_data(self, raw_data: str) -> str:
|
|
return raw_data.upper()
|
|
|
|
def transform_data(self, data: str) -> str:
|
|
return DataTransformer().transform(data)
|
|
```
|
|
```python:{path_to_file.relative_to(project_root)}
|
|
def fetch_and_transform_data():
|
|
response = requests.get(API_URL)
|
|
raw_data = response.text
|
|
processor = DataProcessor()
|
|
processed = processor.process_data(raw_data)
|
|
transformed = processor.transform_data(processed)
|
|
return transformed
|
|
```
|
|
"""
|
|
assert read_write_context.markdown.strip() == expected_read_write_context.strip()
|
|
assert read_only_context.strip() == expected_read_only_context.strip()
|
|
assert hashing_context.strip() == expected_hashing_context.strip()
|
|
|
|
|
|
def test_repo_helper_of_helper_same_class() -> None:
|
|
project_root = Path(__file__).resolve().parent.parent / "code_to_optimize" / "code_directories" / "retriever"
|
|
path_to_utils = project_root / "utils.py"
|
|
path_to_transform_utils = project_root / "transform_utils.py"
|
|
function_to_optimize = FunctionToOptimize(
|
|
function_name="transform_data_own_method",
|
|
file_path=str(path_to_utils),
|
|
parents=[FunctionParent(name="DataProcessor", type="ClassDef")],
|
|
starting_line=None,
|
|
ending_line=None,
|
|
)
|
|
|
|
code_ctx = get_code_optimization_context(function_to_optimize, project_root)
|
|
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
|
|
hashing_context = code_ctx.hashing_code_context
|
|
expected_read_write_context = f"""
|
|
```python:{path_to_transform_utils.relative_to(project_root)}
|
|
class DataTransformer:
|
|
def __init__(self):
|
|
self.data = None
|
|
|
|
def transform_using_own_method(self, data):
|
|
return self.transform(data)
|
|
```
|
|
```python:{path_to_utils.relative_to(project_root)}
|
|
import math
|
|
from transform_utils import DataTransformer
|
|
|
|
class DataProcessor:
|
|
|
|
def __init__(self, default_prefix: str = "PREFIX_"):
|
|
\"\"\"Initialize the DataProcessor with a default prefix.\"\"\"
|
|
self.default_prefix = default_prefix
|
|
self.number += math.log(self.number)
|
|
|
|
def transform_data_own_method(self, data: str) -> str:
|
|
\"\"\"Transform the processed data using own method\"\"\"
|
|
return DataTransformer().transform_using_own_method(data)
|
|
```
|
|
"""
|
|
expected_read_only_context = f"""
|
|
```python:{path_to_transform_utils.relative_to(project_root)}
|
|
class DataTransformer:
|
|
|
|
def transform(self, data):
|
|
self.data = data
|
|
return self.data
|
|
```
|
|
```python:{path_to_utils.relative_to(project_root)}
|
|
class DataProcessor:
|
|
\"\"\"A class for processing data.\"\"\"
|
|
|
|
number = 1
|
|
|
|
def __repr__(self) -> str:
|
|
\"\"\"Return a string representation of the DataProcessor.\"\"\"
|
|
return f"DataProcessor(default_prefix={{self.default_prefix!r}})"
|
|
```
|
|
|
|
"""
|
|
expected_hashing_context = f"""
|
|
```python:transform_utils.py
|
|
class DataTransformer:
|
|
|
|
def transform_using_own_method(self, data):
|
|
return self.transform(data)
|
|
```
|
|
```python:{path_to_utils.relative_to(project_root)}
|
|
class DataProcessor:
|
|
|
|
def transform_data_own_method(self, data: str) -> str:
|
|
return DataTransformer().transform_using_own_method(data)
|
|
```
|
|
"""
|
|
|
|
assert read_write_context.markdown.strip() == expected_read_write_context.strip()
|
|
assert read_only_context.strip() == expected_read_only_context.strip()
|
|
assert hashing_context.strip() == expected_hashing_context.strip()
|
|
|
|
|
|
def test_repo_helper_of_helper_same_file() -> None:
|
|
project_root = Path(__file__).resolve().parent.parent / "code_to_optimize" / "code_directories" / "retriever"
|
|
path_to_utils = project_root / "utils.py"
|
|
path_to_transform_utils = project_root / "transform_utils.py"
|
|
function_to_optimize = FunctionToOptimize(
|
|
function_name="transform_data_same_file_function",
|
|
file_path=str(path_to_utils),
|
|
parents=[FunctionParent(name="DataProcessor", type="ClassDef")],
|
|
starting_line=None,
|
|
ending_line=None,
|
|
)
|
|
|
|
code_ctx = get_code_optimization_context(function_to_optimize, project_root)
|
|
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
|
|
hashing_context = code_ctx.hashing_code_context
|
|
expected_read_write_context = f"""
|
|
```python:{path_to_transform_utils.relative_to(project_root)}
|
|
class DataTransformer:
|
|
def __init__(self):
|
|
self.data = None
|
|
|
|
def transform_using_same_file_function(self, data):
|
|
return update_data(data)
|
|
```
|
|
```python:{path_to_utils.relative_to(project_root)}
|
|
import math
|
|
from transform_utils import DataTransformer
|
|
|
|
class DataProcessor:
|
|
|
|
def __init__(self, default_prefix: str = "PREFIX_"):
|
|
\"\"\"Initialize the DataProcessor with a default prefix.\"\"\"
|
|
self.default_prefix = default_prefix
|
|
self.number += math.log(self.number)
|
|
|
|
def transform_data_same_file_function(self, data: str) -> str:
|
|
\"\"\"Transform the processed data using a function from the same file\"\"\"
|
|
return DataTransformer().transform_using_same_file_function(data)
|
|
```
|
|
"""
|
|
expected_read_only_context = f"""
|
|
```python:{path_to_transform_utils.relative_to(project_root)}
|
|
def update_data(data):
|
|
return data + " updated"
|
|
```
|
|
```python:{path_to_utils.relative_to(project_root)}
|
|
class DataProcessor:
|
|
\"\"\"A class for processing data.\"\"\"
|
|
|
|
number = 1
|
|
|
|
def __repr__(self) -> str:
|
|
\"\"\"Return a string representation of the DataProcessor.\"\"\"
|
|
return f"DataProcessor(default_prefix={{self.default_prefix!r}})"
|
|
```
|
|
"""
|
|
expected_hashing_context = f"""
|
|
```python:transform_utils.py
|
|
class DataTransformer:
|
|
|
|
def transform_using_same_file_function(self, data):
|
|
return update_data(data)
|
|
```
|
|
```python:{path_to_utils.relative_to(project_root)}
|
|
class DataProcessor:
|
|
|
|
def transform_data_same_file_function(self, data: str) -> str:
|
|
return DataTransformer().transform_using_same_file_function(data)
|
|
```
|
|
"""
|
|
|
|
assert read_write_context.markdown.strip() == expected_read_write_context.strip()
|
|
assert read_only_context.strip() == expected_read_only_context.strip()
|
|
assert hashing_context.strip() == expected_hashing_context.strip()
|
|
|
|
|
|
def test_repo_helper_all_same_file() -> None:
|
|
project_root = Path(__file__).resolve().parent.parent / "code_to_optimize" / "code_directories" / "retriever"
|
|
path_to_transform_utils = project_root / "transform_utils.py"
|
|
function_to_optimize = FunctionToOptimize(
|
|
function_name="transform_data_all_same_file",
|
|
file_path=str(path_to_transform_utils),
|
|
parents=[FunctionParent(name="DataTransformer", type="ClassDef")],
|
|
starting_line=None,
|
|
ending_line=None,
|
|
)
|
|
|
|
code_ctx = get_code_optimization_context(function_to_optimize, project_root)
|
|
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
|
|
hashing_context = code_ctx.hashing_code_context
|
|
expected_read_write_context = f"""
|
|
```python:{path_to_transform_utils.relative_to(project_root)}
|
|
class DataTransformer:
|
|
def __init__(self):
|
|
self.data = None
|
|
|
|
def transform_using_own_method(self, data):
|
|
return self.transform(data)
|
|
|
|
def transform_data_all_same_file(self, data):
|
|
new_data = update_data(data)
|
|
return self.transform_using_own_method(new_data)
|
|
|
|
|
|
def update_data(data):
|
|
return data + " updated"
|
|
```
|
|
"""
|
|
expected_read_only_context = f"""
|
|
```python:{path_to_transform_utils.relative_to(project_root)}
|
|
class DataTransformer:
|
|
|
|
def transform(self, data):
|
|
self.data = data
|
|
return self.data
|
|
```
|
|
|
|
"""
|
|
expected_hashing_context = f"""
|
|
```python:{path_to_transform_utils.relative_to(project_root)}
|
|
class DataTransformer:
|
|
|
|
def transform_using_own_method(self, data):
|
|
return self.transform(data)
|
|
|
|
def transform_data_all_same_file(self, data):
|
|
new_data = update_data(data)
|
|
return self.transform_using_own_method(new_data)
|
|
|
|
def update_data(data):
|
|
return data + ' updated'
|
|
```
|
|
"""
|
|
|
|
assert read_write_context.markdown.strip() == expected_read_write_context.strip()
|
|
assert read_only_context.strip() == expected_read_only_context.strip()
|
|
assert hashing_context.strip() == expected_hashing_context.strip()
|
|
|
|
|
|
def test_repo_helper_circular_dependency() -> None:
|
|
project_root = Path(__file__).resolve().parent.parent / "code_to_optimize" / "code_directories" / "retriever"
|
|
path_to_utils = project_root / "utils.py"
|
|
path_to_transform_utils = project_root / "transform_utils.py"
|
|
function_to_optimize = FunctionToOptimize(
|
|
function_name="circular_dependency",
|
|
file_path=str(path_to_transform_utils),
|
|
parents=[FunctionParent(name="DataTransformer", type="ClassDef")],
|
|
starting_line=None,
|
|
ending_line=None,
|
|
)
|
|
|
|
code_ctx = get_code_optimization_context(function_to_optimize, project_root)
|
|
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
|
|
hashing_context = code_ctx.hashing_code_context
|
|
expected_read_write_context = f"""
|
|
```python:{path_to_utils.relative_to(project_root)}
|
|
import math
|
|
from transform_utils import DataTransformer
|
|
|
|
class DataProcessor:
|
|
|
|
def __init__(self, default_prefix: str = "PREFIX_"):
|
|
\"\"\"Initialize the DataProcessor with a default prefix.\"\"\"
|
|
self.default_prefix = default_prefix
|
|
self.number += math.log(self.number)
|
|
|
|
def circular_dependency(self, data: str) -> str:
|
|
\"\"\"Test circular dependency\"\"\"
|
|
return DataTransformer().circular_dependency(data)
|
|
```
|
|
```python:{path_to_transform_utils.relative_to(project_root)}
|
|
from code_to_optimize.code_directories.retriever.utils import DataProcessor
|
|
|
|
class DataTransformer:
|
|
def __init__(self):
|
|
self.data = None
|
|
|
|
def circular_dependency(self, data):
|
|
return DataProcessor().circular_dependency(data)
|
|
```
|
|
"""
|
|
expected_read_only_context = f"""
|
|
```python:{path_to_utils.relative_to(project_root)}
|
|
class DataProcessor:
|
|
\"\"\"A class for processing data.\"\"\"
|
|
|
|
number = 1
|
|
|
|
def __repr__(self) -> str:
|
|
\"\"\"Return a string representation of the DataProcessor.\"\"\"
|
|
return f"DataProcessor(default_prefix={{self.default_prefix!r}})"
|
|
```
|
|
```python:{path_to_transform_utils.relative_to(project_root)}
|
|
class DataTransformer:
|
|
def __init__(self):
|
|
self.data = None
|
|
```
|
|
"""
|
|
expected_hashing_context = f"""
|
|
```python:utils.py
|
|
class DataProcessor:
|
|
|
|
def circular_dependency(self, data: str) -> str:
|
|
return DataTransformer().circular_dependency(data)
|
|
```
|
|
```python:{path_to_transform_utils.relative_to(project_root)}
|
|
class DataTransformer:
|
|
|
|
def circular_dependency(self, data):
|
|
return DataProcessor().circular_dependency(data)
|
|
```
|
|
"""
|
|
|
|
assert read_write_context.markdown.strip() == expected_read_write_context.strip()
|
|
assert read_only_context.strip() == expected_read_only_context.strip()
|
|
assert hashing_context.strip() == expected_hashing_context.strip()
|
|
|
|
|
|
def test_indirect_init_helper(tmp_path: Path) -> None:
|
|
code = """
|
|
class MyClass:
|
|
def __init__(self):
|
|
self.x = 1
|
|
self.y = outside_method()
|
|
def target_method(self):
|
|
return self.x + self.y
|
|
|
|
def outside_method():
|
|
return 1
|
|
"""
|
|
# Create a temporary Python file using pytest's tmp_path fixture
|
|
file_path = tmp_path / "test_code.py"
|
|
file_path.write_text(code, encoding="utf-8")
|
|
opt = Optimizer(
|
|
Namespace(
|
|
project_root=file_path.parent.resolve(),
|
|
disable_telemetry=True,
|
|
tests_root="tests",
|
|
test_framework="pytest",
|
|
pytest_cmd="pytest",
|
|
experiment_id=None,
|
|
test_project_root=Path().resolve(),
|
|
)
|
|
)
|
|
function_to_optimize = FunctionToOptimize(
|
|
function_name="target_method",
|
|
file_path=file_path,
|
|
parents=[FunctionParent(name="MyClass", type="ClassDef")],
|
|
starting_line=None,
|
|
ending_line=None,
|
|
)
|
|
|
|
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
|
|
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
|
|
hashing_context = code_ctx.hashing_code_context
|
|
expected_read_write_context = f"""
|
|
```python:{file_path.relative_to(opt.args.project_root)}
|
|
class MyClass:
|
|
def __init__(self):
|
|
self.x = 1
|
|
self.y = outside_method()
|
|
def target_method(self):
|
|
return self.x + self.y
|
|
```
|
|
"""
|
|
expected_read_only_context = f"""
|
|
```python:{file_path.relative_to(opt.args.project_root)}
|
|
def outside_method():
|
|
return 1
|
|
```
|
|
"""
|
|
expected_hashing_context = f"""
|
|
```python:{file_path.relative_to(opt.args.project_root)}
|
|
class MyClass:
|
|
|
|
def target_method(self):
|
|
return self.x + self.y
|
|
```
|
|
"""
|
|
assert read_write_context.markdown.strip() == expected_read_write_context.strip()
|
|
assert read_only_context.strip() == expected_read_only_context.strip()
|
|
assert hashing_context.strip() == expected_hashing_context.strip()
|
|
|
|
|
|
def test_direct_module_import() -> None:
|
|
project_root = Path(__file__).resolve().parent.parent / "code_to_optimize" / "code_directories" / "retriever"
|
|
path_to_main = project_root / "main.py"
|
|
path_to_fto = project_root / "import_test.py"
|
|
function_to_optimize = FunctionToOptimize(
|
|
function_name="function_to_optimize",
|
|
file_path=str(path_to_fto),
|
|
parents=[],
|
|
starting_line=None,
|
|
ending_line=None,
|
|
)
|
|
|
|
code_ctx = get_code_optimization_context(function_to_optimize, project_root)
|
|
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
|
|
hashing_context = code_ctx.hashing_code_context
|
|
|
|
expected_read_only_context = """
|
|
```python:utils.py
|
|
import math
|
|
from transform_utils import DataTransformer
|
|
|
|
class DataProcessor:
|
|
\"\"\"A class for processing data.\"\"\"
|
|
|
|
number = 1
|
|
|
|
def __init__(self, default_prefix: str = "PREFIX_"):
|
|
\"\"\"Initialize the DataProcessor with a default prefix.\"\"\"
|
|
self.default_prefix = default_prefix
|
|
self.number += math.log(self.number)
|
|
|
|
def __repr__(self) -> str:
|
|
\"\"\"Return a string representation of the DataProcessor.\"\"\"
|
|
return f"DataProcessor(default_prefix={self.default_prefix!r})"
|
|
|
|
def process_data(self, raw_data: str) -> str:
|
|
\"\"\"Process raw data by converting it to uppercase.\"\"\"
|
|
return raw_data.upper()
|
|
|
|
def transform_data(self, data: str) -> str:
|
|
\"\"\"Transform the processed data\"\"\"
|
|
return DataTransformer().transform(data)
|
|
```"""
|
|
expected_hashing_context = """
|
|
```python:main.py
|
|
def fetch_and_transform_data():
|
|
response = requests.get(API_URL)
|
|
raw_data = response.text
|
|
processor = DataProcessor()
|
|
processed = processor.process_data(raw_data)
|
|
transformed = processor.transform_data(processed)
|
|
return transformed
|
|
```
|
|
```python:import_test.py
|
|
def function_to_optimize():
|
|
return code_to_optimize.code_directories.retriever.main.fetch_and_transform_data()
|
|
```
|
|
"""
|
|
expected_read_write_context = f"""
|
|
```python:{path_to_main.relative_to(project_root)}
|
|
import requests
|
|
from globals import API_URL
|
|
from utils import DataProcessor
|
|
|
|
def fetch_and_transform_data():
|
|
# Use the global variable for the request
|
|
response = requests.get(API_URL)
|
|
|
|
raw_data = response.text
|
|
|
|
# Use code from another file (utils.py)
|
|
processor = DataProcessor()
|
|
processed = processor.process_data(raw_data)
|
|
transformed = processor.transform_data(processed)
|
|
|
|
return transformed
|
|
```
|
|
```python:{path_to_fto.relative_to(project_root)}
|
|
import code_to_optimize.code_directories.retriever.main
|
|
|
|
def function_to_optimize():
|
|
return code_to_optimize.code_directories.retriever.main.fetch_and_transform_data()
|
|
```
|
|
"""
|
|
assert read_write_context.markdown.strip() == expected_read_write_context.strip()
|
|
assert read_only_context.strip() == expected_read_only_context.strip()
|
|
assert hashing_context.strip() == expected_hashing_context.strip()
|
|
|
|
|
|
def test_module_import_optimization() -> None:
|
|
main_code = """
|
|
import utility_module
|
|
|
|
class Calculator:
|
|
def __init__(self, precision="high", fallback_precision=None, mode="standard"):
|
|
# This is where we use the imported module
|
|
self.precision = utility_module.select_precision(precision, fallback_precision)
|
|
self.mode = mode
|
|
|
|
# Using variables from the utility module
|
|
self.backend = utility_module.CALCULATION_BACKEND
|
|
self.system = utility_module.SYSTEM_TYPE
|
|
self.default_precision = utility_module.DEFAULT_PRECISION
|
|
|
|
def add(self, a, b):
|
|
return a + b
|
|
|
|
def subtract(self, a, b):
|
|
return a - b
|
|
|
|
def calculate(self, operation, x, y):
|
|
if operation == "add":
|
|
return self.add(x, y)
|
|
elif operation == "subtract":
|
|
return self.subtract(x, y)
|
|
else:
|
|
return None
|
|
"""
|
|
|
|
utility_module_code = """
|
|
import sys
|
|
import platform
|
|
import logging
|
|
|
|
DEFAULT_PRECISION = "medium"
|
|
DEFAULT_MODE = "standard"
|
|
|
|
# Try-except block with variable definitions
|
|
try:
|
|
import numpy as np
|
|
# Used variable in try block
|
|
CALCULATION_BACKEND = "numpy"
|
|
# Unused variable in try block
|
|
VECTOR_DIMENSIONS = 3
|
|
except ImportError:
|
|
# Used variable in except block
|
|
CALCULATION_BACKEND = "python"
|
|
# Unused variable in except block
|
|
FALLBACK_WARNING = "NumPy not available, using slower Python implementation"
|
|
|
|
# Nested if-else with variable definitions
|
|
if sys.platform.startswith('win'):
|
|
# Used variable in outer if
|
|
SYSTEM_TYPE = "windows"
|
|
if platform.architecture()[0] == '64bit':
|
|
# Unused variable in nested if
|
|
MEMORY_MODEL = "x64"
|
|
else:
|
|
# Unused variable in nested else
|
|
MEMORY_MODEL = "x86"
|
|
elif sys.platform.startswith('linux'):
|
|
# Used variable in outer elif
|
|
SYSTEM_TYPE = "linux"
|
|
# Unused variable in outer elif
|
|
KERNEL_VERSION = platform.release()
|
|
else:
|
|
# Used variable in outer else
|
|
SYSTEM_TYPE = "other"
|
|
# Unused variable in outer else
|
|
UNKNOWN_SYSTEM_MSG = "Running on an unrecognized platform"
|
|
|
|
# Function that will be used in the main code
|
|
def select_precision(precision, fallback_precision):
|
|
if precision is None:
|
|
return fallback_precision or DEFAULT_PRECISION
|
|
|
|
# Using the variables defined above
|
|
if CALCULATION_BACKEND == "numpy":
|
|
# Higher precision available with NumPy
|
|
precision_options = ["low", "medium", "high", "ultra"]
|
|
else:
|
|
# Limited precision without NumPy
|
|
precision_options = ["low", "medium", "high"]
|
|
|
|
if isinstance(precision, str):
|
|
if precision.lower() not in precision_options:
|
|
if fallback_precision:
|
|
return fallback_precision
|
|
else:
|
|
return DEFAULT_PRECISION
|
|
return precision.lower()
|
|
else:
|
|
return DEFAULT_PRECISION
|
|
|
|
# Function that won't be used
|
|
def get_system_details():
|
|
return {
|
|
"system": SYSTEM_TYPE,
|
|
"backend": CALCULATION_BACKEND,
|
|
"default_precision": DEFAULT_PRECISION,
|
|
"python_version": sys.version
|
|
}
|
|
"""
|
|
|
|
# Create a temporary directory for the test
|
|
with tempfile.TemporaryDirectory() as temp_dir:
|
|
# Set up the package structure
|
|
package_dir = Path(temp_dir) / "package"
|
|
package_dir.mkdir()
|
|
|
|
# Create the __init__.py file
|
|
with open(package_dir / "__init__.py", "w") as init_file:
|
|
init_file.write("")
|
|
|
|
# Write the utility_module.py file
|
|
with open(package_dir / "utility_module.py", "w") as utility_file:
|
|
utility_file.write(utility_module_code)
|
|
utility_file.flush()
|
|
|
|
# Write the main code file
|
|
main_file_path = package_dir / "main_module.py"
|
|
with open(main_file_path, "w") as main_file:
|
|
main_file.write(main_code)
|
|
main_file.flush()
|
|
|
|
# Set up the optimizer
|
|
file_path = main_file_path.resolve()
|
|
project_root = package_dir.resolve()
|
|
opt = Optimizer(
|
|
Namespace(
|
|
project_root=project_root,
|
|
disable_telemetry=True,
|
|
tests_root="tests",
|
|
test_framework="pytest",
|
|
pytest_cmd="pytest",
|
|
experiment_id=None,
|
|
test_project_root=Path().resolve(),
|
|
)
|
|
)
|
|
|
|
# Define the function to optimize
|
|
function_to_optimize = FunctionToOptimize(
|
|
function_name="calculate",
|
|
file_path=file_path,
|
|
parents=[FunctionParent(name="Calculator", type="ClassDef")],
|
|
starting_line=None,
|
|
ending_line=None,
|
|
)
|
|
|
|
# Get the code optimization context
|
|
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
|
|
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
|
|
hashing_context = code_ctx.hashing_code_context
|
|
# The expected contexts
|
|
# Resolve both paths to handle symlink issues on macOS
|
|
relative_path = file_path.relative_to(project_root)
|
|
expected_read_write_context = f"""
|
|
```python:{main_file_path.resolve().relative_to(opt.args.project_root.resolve())}
|
|
import utility_module
|
|
|
|
class Calculator:
|
|
def __init__(self, precision="high", fallback_precision=None, mode="standard"):
|
|
# This is where we use the imported module
|
|
self.precision = utility_module.select_precision(precision, fallback_precision)
|
|
self.mode = mode
|
|
|
|
# Using variables from the utility module
|
|
self.backend = utility_module.CALCULATION_BACKEND
|
|
self.system = utility_module.SYSTEM_TYPE
|
|
self.default_precision = utility_module.DEFAULT_PRECISION
|
|
|
|
def add(self, a, b):
|
|
return a + b
|
|
|
|
def subtract(self, a, b):
|
|
return a - b
|
|
|
|
def calculate(self, operation, x, y):
|
|
if operation == "add":
|
|
return self.add(x, y)
|
|
elif operation == "subtract":
|
|
return self.subtract(x, y)
|
|
else:
|
|
return None
|
|
```
|
|
"""
|
|
expected_read_only_context = """
|
|
```python:utility_module.py
|
|
DEFAULT_PRECISION = "medium"
|
|
|
|
# Try-except block with variable definitions
|
|
try:
|
|
# Used variable in try block
|
|
CALCULATION_BACKEND = "numpy"
|
|
except ImportError:
|
|
# Used variable in except block
|
|
CALCULATION_BACKEND = "python"
|
|
|
|
# Function that will be used in the main code
|
|
def select_precision(precision, fallback_precision):
|
|
if precision is None:
|
|
return fallback_precision or DEFAULT_PRECISION
|
|
|
|
# Using the variables defined above
|
|
if CALCULATION_BACKEND == "numpy":
|
|
# Higher precision available with NumPy
|
|
precision_options = ["low", "medium", "high", "ultra"]
|
|
else:
|
|
# Limited precision without NumPy
|
|
precision_options = ["low", "medium", "high"]
|
|
|
|
if isinstance(precision, str):
|
|
if precision.lower() not in precision_options:
|
|
if fallback_precision:
|
|
return fallback_precision
|
|
else:
|
|
return DEFAULT_PRECISION
|
|
return precision.lower()
|
|
else:
|
|
return DEFAULT_PRECISION
|
|
```
|
|
"""
|
|
expected_hashing_context = """
|
|
```python:main_module.py
|
|
class Calculator:
|
|
|
|
def add(self, a, b):
|
|
return a + b
|
|
|
|
def subtract(self, a, b):
|
|
return a - b
|
|
|
|
def calculate(self, operation, x, y):
|
|
if operation == 'add':
|
|
return self.add(x, y)
|
|
elif operation == 'subtract':
|
|
return self.subtract(x, y)
|
|
else:
|
|
return None
|
|
```
|
|
"""
|
|
# Verify the contexts match the expected values
|
|
assert read_write_context.markdown.strip() == expected_read_write_context.strip()
|
|
assert read_only_context.strip() == expected_read_only_context.strip()
|
|
assert hashing_context.strip() == expected_hashing_context.strip()
|
|
|
|
|
|
def test_module_import_init_fto() -> None:
|
|
main_code = """
|
|
import utility_module
|
|
|
|
class Calculator:
|
|
def __init__(self, precision="high", fallback_precision=None, mode="standard"):
|
|
# This is where we use the imported module
|
|
self.precision = utility_module.select_precision(precision, fallback_precision)
|
|
self.mode = mode
|
|
|
|
# Using variables from the utility module
|
|
self.backend = utility_module.CALCULATION_BACKEND
|
|
self.system = utility_module.SYSTEM_TYPE
|
|
self.default_precision = utility_module.DEFAULT_PRECISION
|
|
|
|
def add(self, a, b):
|
|
return a + b
|
|
|
|
def subtract(self, a, b):
|
|
return a - b
|
|
|
|
def calculate(self, operation, x, y):
|
|
if operation == "add":
|
|
return self.add(x, y)
|
|
elif operation == "subtract":
|
|
return self.subtract(x, y)
|
|
else:
|
|
return None
|
|
"""
|
|
|
|
utility_module_code = """
|
|
import sys
|
|
import platform
|
|
import logging
|
|
|
|
DEFAULT_PRECISION = "medium"
|
|
DEFAULT_MODE = "standard"
|
|
|
|
# Try-except block with variable definitions
|
|
try:
|
|
import numpy as np
|
|
# Used variable in try block
|
|
CALCULATION_BACKEND = "numpy"
|
|
# Unused variable in try block
|
|
VECTOR_DIMENSIONS = 3
|
|
except ImportError:
|
|
# Used variable in except block
|
|
CALCULATION_BACKEND = "python"
|
|
# Unused variable in except block
|
|
FALLBACK_WARNING = "NumPy not available, using slower Python implementation"
|
|
|
|
# Nested if-else with variable definitions
|
|
if sys.platform.startswith('win'):
|
|
# Used variable in outer if
|
|
SYSTEM_TYPE = "windows"
|
|
if platform.architecture()[0] == '64bit':
|
|
# Unused variable in nested if
|
|
MEMORY_MODEL = "x64"
|
|
else:
|
|
# Unused variable in nested else
|
|
MEMORY_MODEL = "x86"
|
|
elif sys.platform.startswith('linux'):
|
|
# Used variable in outer elif
|
|
SYSTEM_TYPE = "linux"
|
|
# Unused variable in outer elif
|
|
KERNEL_VERSION = platform.release()
|
|
else:
|
|
# Used variable in outer else
|
|
SYSTEM_TYPE = "other"
|
|
# Unused variable in outer else
|
|
UNKNOWN_SYSTEM_MSG = "Running on an unrecognized platform"
|
|
|
|
# Function that will be used in the main code
|
|
def select_precision(precision, fallback_precision):
|
|
if precision is None:
|
|
return fallback_precision or DEFAULT_PRECISION
|
|
|
|
# Using the variables defined above
|
|
if CALCULATION_BACKEND == "numpy":
|
|
# Higher precision available with NumPy
|
|
precision_options = ["low", "medium", "high", "ultra"]
|
|
else:
|
|
# Limited precision without NumPy
|
|
precision_options = ["low", "medium", "high"]
|
|
|
|
if isinstance(precision, str):
|
|
if precision.lower() not in precision_options:
|
|
if fallback_precision:
|
|
return fallback_precision
|
|
else:
|
|
return DEFAULT_PRECISION
|
|
return precision.lower()
|
|
else:
|
|
return DEFAULT_PRECISION
|
|
|
|
# Function that won't be used
|
|
def get_system_details():
|
|
return {
|
|
"system": SYSTEM_TYPE,
|
|
"backend": CALCULATION_BACKEND,
|
|
"default_precision": DEFAULT_PRECISION,
|
|
"python_version": sys.version
|
|
}
|
|
"""
|
|
|
|
# Create a temporary directory for the test
|
|
with tempfile.TemporaryDirectory() as temp_dir:
|
|
# Set up the package structure
|
|
package_dir = Path(temp_dir) / "package"
|
|
package_dir.mkdir()
|
|
|
|
# Create the __init__.py file
|
|
with open(package_dir / "__init__.py", "w") as init_file:
|
|
init_file.write("")
|
|
|
|
# Write the utility_module.py file
|
|
with open(package_dir / "utility_module.py", "w") as utility_file:
|
|
utility_file.write(utility_module_code)
|
|
utility_file.flush()
|
|
|
|
# Write the main code file
|
|
main_file_path = package_dir / "main_module.py"
|
|
with open(main_file_path, "w") as main_file:
|
|
main_file.write(main_code)
|
|
main_file.flush()
|
|
|
|
# Set up the optimizer
|
|
file_path = main_file_path.resolve()
|
|
project_root = package_dir.resolve()
|
|
opt = Optimizer(
|
|
Namespace(
|
|
project_root=project_root,
|
|
disable_telemetry=True,
|
|
tests_root="tests",
|
|
test_framework="pytest",
|
|
pytest_cmd="pytest",
|
|
experiment_id=None,
|
|
test_project_root=Path().resolve(),
|
|
)
|
|
)
|
|
|
|
# Define the function to optimize
|
|
function_to_optimize = FunctionToOptimize(
|
|
function_name="__init__",
|
|
file_path=file_path,
|
|
parents=[FunctionParent(name="Calculator", type="ClassDef")],
|
|
starting_line=None,
|
|
ending_line=None,
|
|
)
|
|
|
|
# Get the code optimization context
|
|
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
|
|
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
|
|
# The expected contexts
|
|
relative_path = file_path.relative_to(project_root)
|
|
expected_read_write_context = f"""
|
|
```python:utility_module.py
|
|
DEFAULT_PRECISION = "medium"
|
|
|
|
# Try-except block with variable definitions
|
|
try:
|
|
# Used variable in try block
|
|
CALCULATION_BACKEND = "numpy"
|
|
except ImportError:
|
|
# Used variable in except block
|
|
CALCULATION_BACKEND = "python"
|
|
|
|
# Function that will be used in the main code
|
|
def select_precision(precision, fallback_precision):
|
|
if precision is None:
|
|
return fallback_precision or DEFAULT_PRECISION
|
|
|
|
# Using the variables defined above
|
|
if CALCULATION_BACKEND == "numpy":
|
|
# Higher precision available with NumPy
|
|
precision_options = ["low", "medium", "high", "ultra"]
|
|
else:
|
|
# Limited precision without NumPy
|
|
precision_options = ["low", "medium", "high"]
|
|
|
|
if isinstance(precision, str):
|
|
if precision.lower() not in precision_options:
|
|
if fallback_precision:
|
|
return fallback_precision
|
|
else:
|
|
return DEFAULT_PRECISION
|
|
return precision.lower()
|
|
else:
|
|
return DEFAULT_PRECISION
|
|
```
|
|
```python:{main_file_path.resolve().relative_to(opt.args.project_root.resolve())}
|
|
import utility_module
|
|
|
|
class Calculator:
|
|
def __init__(self, precision="high", fallback_precision=None, mode="standard"):
|
|
# This is where we use the imported module
|
|
self.precision = utility_module.select_precision(precision, fallback_precision)
|
|
self.mode = mode
|
|
|
|
# Using variables from the utility module
|
|
self.backend = utility_module.CALCULATION_BACKEND
|
|
self.system = utility_module.SYSTEM_TYPE
|
|
self.default_precision = utility_module.DEFAULT_PRECISION
|
|
```
|
|
"""
|
|
expected_read_only_context = """
|
|
```python:utility_module.py
|
|
DEFAULT_PRECISION = "medium"
|
|
|
|
# Try-except block with variable definitions
|
|
try:
|
|
# Used variable in try block
|
|
CALCULATION_BACKEND = "numpy"
|
|
except ImportError:
|
|
# Used variable in except block
|
|
CALCULATION_BACKEND = "python"
|
|
```
|
|
"""
|
|
assert read_write_context.markdown.strip() == expected_read_write_context.strip()
|
|
assert read_only_context.strip() == expected_read_only_context.strip()
|
|
|
|
|
|
def test_hashing_code_context_removes_imports_docstrings_and_init(tmp_path: Path) -> None:
|
|
"""Test that hashing context removes imports, docstrings, and __init__ methods properly."""
|
|
code = '''
|
|
import os
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
class MyClass:
|
|
"""A class with a docstring."""
|
|
def __init__(self, value):
|
|
"""Initialize with a value."""
|
|
self.value = value
|
|
|
|
def target_method(self):
|
|
"""Target method with docstring."""
|
|
result = self.helper_method()
|
|
helper_cls = HelperClass()
|
|
data = helper_cls.process_data()
|
|
return self.value * 2
|
|
|
|
def helper_method(self):
|
|
"""Helper method with docstring."""
|
|
return self.value + 1
|
|
|
|
class HelperClass:
|
|
"""Helper class docstring."""
|
|
def __init__(self):
|
|
"""Helper init method."""
|
|
self.data = "test"
|
|
|
|
def process_data(self):
|
|
"""Process data method."""
|
|
return self.data.upper()
|
|
|
|
def standalone_function():
|
|
"""Standalone function."""
|
|
return "standalone"
|
|
'''
|
|
# Create a temporary Python file using pytest's tmp_path fixture
|
|
file_path = tmp_path / "test_code.py"
|
|
file_path.write_text(code, encoding="utf-8")
|
|
opt = Optimizer(
|
|
Namespace(
|
|
project_root=file_path.parent.resolve(),
|
|
disable_telemetry=True,
|
|
tests_root="tests",
|
|
test_framework="pytest",
|
|
pytest_cmd="pytest",
|
|
experiment_id=None,
|
|
test_project_root=Path().resolve(),
|
|
)
|
|
)
|
|
function_to_optimize = FunctionToOptimize(
|
|
function_name="target_method",
|
|
file_path=file_path,
|
|
parents=[FunctionParent(name="MyClass", type="ClassDef")],
|
|
starting_line=None,
|
|
ending_line=None,
|
|
)
|
|
|
|
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
|
|
hashing_context = code_ctx.hashing_code_context
|
|
|
|
# Expected behavior based on current implementation:
|
|
# - Should not contain imports
|
|
# - Should remove docstrings from target functions (but currently doesn't - this is a bug)
|
|
# - Should not contain __init__ methods
|
|
# - Should contain target function and helper methods that are actually called
|
|
# - Should be formatted as markdown
|
|
|
|
# Test that it's formatted as markdown
|
|
assert hashing_context.startswith("```python:")
|
|
assert hashing_context.endswith("```")
|
|
|
|
# Test basic structure requirements
|
|
assert "import" not in hashing_context # Should not contain imports
|
|
assert "__init__" not in hashing_context # Should not contain __init__ methods
|
|
assert "target_method" in hashing_context # Should contain target function
|
|
assert "standalone_function" not in hashing_context # Should not contain unused functions
|
|
|
|
# Test that helper functions are included when they're called
|
|
assert "helper_method" in hashing_context # Should contain called helper method
|
|
assert "process_data" in hashing_context # Should contain called helper method
|
|
|
|
# Test for docstring removal (this should pass when implementation is fixed)
|
|
# Currently this will fail because docstrings are not being removed properly
|
|
assert '"""Target method with docstring."""' not in hashing_context, (
|
|
"Docstrings should be removed from target functions"
|
|
)
|
|
assert '"""Helper method with docstring."""' not in hashing_context, (
|
|
"Docstrings should be removed from helper functions"
|
|
)
|
|
assert '"""Process data method."""' not in hashing_context, "Docstrings should be removed from helper class methods"
|
|
|
|
|
|
def test_hashing_code_context_with_nested_classes(tmp_path: Path) -> None:
|
|
"""Test that hashing context handles nested classes properly (should exclude them)."""
|
|
code = '''
|
|
class OuterClass:
|
|
"""Outer class docstring."""
|
|
def __init__(self):
|
|
"""Outer init."""
|
|
self.value = 1
|
|
|
|
def target_method(self):
|
|
"""Target method."""
|
|
return self.NestedClass().nested_method()
|
|
|
|
class NestedClass:
|
|
"""Nested class - should be excluded."""
|
|
def __init__(self):
|
|
self.nested_value = 2
|
|
|
|
def nested_method(self):
|
|
return self.nested_value
|
|
'''
|
|
# Create a temporary Python file using pytest's tmp_path fixture
|
|
file_path = tmp_path / "test_code.py"
|
|
file_path.write_text(code, encoding="utf-8")
|
|
opt = Optimizer(
|
|
Namespace(
|
|
project_root=file_path.parent.resolve(),
|
|
disable_telemetry=True,
|
|
tests_root="tests",
|
|
test_framework="pytest",
|
|
pytest_cmd="pytest",
|
|
experiment_id=None,
|
|
test_project_root=Path().resolve(),
|
|
)
|
|
)
|
|
function_to_optimize = FunctionToOptimize(
|
|
function_name="target_method",
|
|
file_path=file_path,
|
|
parents=[FunctionParent(name="OuterClass", type="ClassDef")],
|
|
starting_line=None,
|
|
ending_line=None,
|
|
)
|
|
|
|
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
|
|
hashing_context = code_ctx.hashing_code_context
|
|
|
|
# Test basic requirements
|
|
assert hashing_context.startswith("```python:")
|
|
assert hashing_context.endswith("```")
|
|
assert "target_method" in hashing_context
|
|
assert "__init__" not in hashing_context # Should not contain __init__ methods
|
|
|
|
# Verify nested classes are excluded from the hashing context
|
|
# The prune_cst_for_code_hashing function should not recurse into nested classes
|
|
assert "class NestedClass:" not in hashing_context # Nested class definition should not be present
|
|
|
|
# The target method will reference NestedClass, but the actual nested class definition should not be included
|
|
# The call to self.NestedClass().nested_method() should be in the target method but the nested class itself excluded
|
|
target_method_call_present = "self.NestedClass().nested_method()" in hashing_context
|
|
assert target_method_call_present, "The target method should contain the call to nested class"
|
|
|
|
# But the actual nested method definition should not be present
|
|
nested_method_definition_present = "def nested_method(self):" in hashing_context
|
|
assert not nested_method_definition_present, "Nested method definition should not be present in hashing context"
|
|
|
|
|
|
def test_hashing_code_context_hash_consistency(tmp_path: Path) -> None:
|
|
"""Test that the same code produces the same hash."""
|
|
code = """
|
|
class TestClass:
|
|
def target_method(self):
|
|
return "test"
|
|
"""
|
|
# Create a temporary Python file using pytest's tmp_path fixture
|
|
file_path = tmp_path / "test_code.py"
|
|
file_path.write_text(code, encoding="utf-8")
|
|
opt = Optimizer(
|
|
Namespace(
|
|
project_root=file_path.parent.resolve(),
|
|
disable_telemetry=True,
|
|
tests_root="tests",
|
|
test_framework="pytest",
|
|
pytest_cmd="pytest",
|
|
experiment_id=None,
|
|
test_project_root=Path().resolve(),
|
|
)
|
|
)
|
|
function_to_optimize = FunctionToOptimize(
|
|
function_name="target_method",
|
|
file_path=file_path,
|
|
parents=[FunctionParent(name="TestClass", type="ClassDef")],
|
|
starting_line=None,
|
|
ending_line=None,
|
|
)
|
|
|
|
# Generate context twice
|
|
code_ctx1 = get_code_optimization_context(function_to_optimize, opt.args.project_root)
|
|
code_ctx2 = get_code_optimization_context(function_to_optimize, opt.args.project_root)
|
|
|
|
# Hash should be consistent
|
|
assert code_ctx1.hashing_code_context_hash == code_ctx2.hashing_code_context_hash
|
|
assert code_ctx1.hashing_code_context == code_ctx2.hashing_code_context
|
|
|
|
# Hash should be valid SHA256
|
|
import hashlib
|
|
|
|
expected_hash = hashlib.sha256(code_ctx1.hashing_code_context.encode("utf-8")).hexdigest()
|
|
assert code_ctx1.hashing_code_context_hash == expected_hash
|
|
|
|
|
|
def test_hashing_code_context_different_code_different_hash(tmp_path: Path) -> None:
|
|
"""Test that different code produces different hashes."""
|
|
code1 = """
|
|
class TestClass:
|
|
def target_method(self):
|
|
return "test1"
|
|
"""
|
|
code2 = """
|
|
class TestClass:
|
|
def target_method(self):
|
|
return "test2"
|
|
"""
|
|
|
|
# Create two temporary Python files using pytest's tmp_path fixture
|
|
file_path1 = tmp_path / "test_code1.py"
|
|
file_path2 = tmp_path / "test_code2.py"
|
|
file_path1.write_text(code1, encoding="utf-8")
|
|
file_path2.write_text(code2, encoding="utf-8")
|
|
|
|
opt1 = Optimizer(
|
|
Namespace(
|
|
project_root=file_path1.parent.resolve(),
|
|
disable_telemetry=True,
|
|
tests_root="tests",
|
|
test_framework="pytest",
|
|
pytest_cmd="pytest",
|
|
experiment_id=None,
|
|
test_project_root=Path().resolve(),
|
|
)
|
|
)
|
|
opt2 = Optimizer(
|
|
Namespace(
|
|
project_root=file_path2.parent.resolve(),
|
|
disable_telemetry=True,
|
|
tests_root="tests",
|
|
test_framework="pytest",
|
|
pytest_cmd="pytest",
|
|
experiment_id=None,
|
|
test_project_root=Path().resolve(),
|
|
)
|
|
)
|
|
|
|
function_to_optimize1 = FunctionToOptimize(
|
|
function_name="target_method",
|
|
file_path=file_path1,
|
|
parents=[FunctionParent(name="TestClass", type="ClassDef")],
|
|
starting_line=None,
|
|
ending_line=None,
|
|
)
|
|
function_to_optimize2 = FunctionToOptimize(
|
|
function_name="target_method",
|
|
file_path=file_path2,
|
|
parents=[FunctionParent(name="TestClass", type="ClassDef")],
|
|
starting_line=None,
|
|
ending_line=None,
|
|
)
|
|
|
|
code_ctx1 = get_code_optimization_context(function_to_optimize1, opt1.args.project_root)
|
|
code_ctx2 = get_code_optimization_context(function_to_optimize2, opt2.args.project_root)
|
|
|
|
# Different code should produce different hashes
|
|
assert code_ctx1.hashing_code_context_hash != code_ctx2.hashing_code_context_hash
|
|
assert code_ctx1.hashing_code_context != code_ctx2.hashing_code_context
|
|
|
|
|
|
def test_hashing_code_context_format_is_markdown(tmp_path: Path) -> None:
|
|
"""Test that hashing context is formatted as markdown."""
|
|
code = """
|
|
class SimpleClass:
|
|
def simple_method(self):
|
|
return 42
|
|
"""
|
|
# Create a temporary Python file using pytest's tmp_path fixture
|
|
file_path = tmp_path / "test_code.py"
|
|
file_path.write_text(code, encoding="utf-8")
|
|
opt = Optimizer(
|
|
Namespace(
|
|
project_root=file_path.parent.resolve(),
|
|
disable_telemetry=True,
|
|
tests_root="tests",
|
|
test_framework="pytest",
|
|
pytest_cmd="pytest",
|
|
experiment_id=None,
|
|
test_project_root=Path().resolve(),
|
|
)
|
|
)
|
|
function_to_optimize = FunctionToOptimize(
|
|
function_name="simple_method",
|
|
file_path=file_path,
|
|
parents=[FunctionParent(name="SimpleClass", type="ClassDef")],
|
|
starting_line=None,
|
|
ending_line=None,
|
|
)
|
|
|
|
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
|
|
hashing_context = code_ctx.hashing_code_context
|
|
|
|
# Should be formatted as markdown code block
|
|
assert hashing_context.startswith("```python:")
|
|
assert hashing_context.endswith("```")
|
|
|
|
# Should contain the relative file path in the markdown header
|
|
relative_path = file_path.relative_to(opt.args.project_root)
|
|
assert str(relative_path) in hashing_context
|
|
|
|
# Should contain the actual code between the markdown markers
|
|
lines = hashing_context.strip().split("\n")
|
|
assert lines[0].startswith("```python:")
|
|
assert lines[-1] == "```"
|
|
|
|
# Code should be between the markers
|
|
code_lines = lines[1:-1]
|
|
code_content = "\n".join(code_lines)
|
|
assert "class SimpleClass:" in code_content
|
|
assert "def simple_method(self):" in code_content
|
|
assert "return 42" in code_content
|
|
|
|
|
|
# This shouldn't happen as we are now using a scoped optimization context, but keep it just in case
|
|
def test_circular_deps():
|
|
path_to_root = Path(__file__).resolve().parent.parent / "code_to_optimize" / "code_directories" / "circular_deps"
|
|
file_abs_path = path_to_root / "api_client.py"
|
|
optimized_code = Path(path_to_root / "optimized.py").read_text(encoding="utf-8")
|
|
content = Path(file_abs_path).read_text(encoding="utf-8")
|
|
new_code = replace_functions_and_add_imports(
|
|
source_code=add_global_assignments(optimized_code, content),
|
|
function_names=["ApiClient.get_console_url"],
|
|
optimized_code=optimized_code,
|
|
module_abspath=Path(file_abs_path),
|
|
preexisting_objects={
|
|
("ApiClient", ()),
|
|
("get_console_url", (FunctionParent(name="ApiClient", type="ClassDef"),)),
|
|
},
|
|
project_root_path=Path(path_to_root),
|
|
)
|
|
assert "import ApiClient" not in new_code, "Error: Circular dependency found"
|
|
|
|
assert "import urllib.parse" in new_code, "Make sure imports for optimization global assignments exist"
|
|
|
|
|
|
def test_global_assignment_collector_with_async_function():
|
|
"""Test GlobalAssignmentCollector correctly identifies global assignments outside async functions."""
|
|
import libcst as cst
|
|
|
|
source_code = """
|
|
# Global assignment
|
|
GLOBAL_VAR = "global_value"
|
|
OTHER_GLOBAL = 42
|
|
|
|
async def async_function():
|
|
# This should not be collected (inside async function)
|
|
local_var = "local_value"
|
|
INNER_ASSIGNMENT = "should_not_be_global"
|
|
return local_var
|
|
|
|
# Another global assignment
|
|
ANOTHER_GLOBAL = "another_global"
|
|
"""
|
|
|
|
tree = cst.parse_module(source_code)
|
|
collector = GlobalAssignmentCollector()
|
|
tree.visit(collector)
|
|
|
|
# Should collect global assignments but not the ones inside async function
|
|
assert len(collector.assignments) == 3
|
|
assert "GLOBAL_VAR" in collector.assignments
|
|
assert "OTHER_GLOBAL" in collector.assignments
|
|
assert "ANOTHER_GLOBAL" in collector.assignments
|
|
|
|
# Should not collect assignments from inside async function
|
|
assert "local_var" not in collector.assignments
|
|
assert "INNER_ASSIGNMENT" not in collector.assignments
|
|
|
|
# Verify assignment order
|
|
expected_order = ["GLOBAL_VAR", "OTHER_GLOBAL", "ANOTHER_GLOBAL"]
|
|
assert collector.assignment_order == expected_order
|
|
|
|
|
|
def test_global_assignment_collector_nested_async_functions():
|
|
"""Test GlobalAssignmentCollector handles nested async functions correctly."""
|
|
import libcst as cst
|
|
|
|
source_code = """
|
|
# Global assignment
|
|
CONFIG = {"key": "value"}
|
|
|
|
def sync_function():
|
|
# Inside sync function - should not be collected
|
|
sync_local = "sync"
|
|
|
|
async def nested_async():
|
|
# Inside nested async function - should not be collected
|
|
nested_var = "nested"
|
|
return nested_var
|
|
|
|
return sync_local
|
|
|
|
async def async_function():
|
|
# Inside async function - should not be collected
|
|
async_local = "async"
|
|
|
|
def nested_sync():
|
|
# Inside nested function - should not be collected
|
|
deeply_nested = "deep"
|
|
return deeply_nested
|
|
|
|
return async_local
|
|
|
|
# Another global assignment
|
|
FINAL_GLOBAL = "final"
|
|
"""
|
|
|
|
tree = cst.parse_module(source_code)
|
|
collector = GlobalAssignmentCollector()
|
|
tree.visit(collector)
|
|
|
|
# Should only collect global-level assignments
|
|
assert len(collector.assignments) == 2
|
|
assert "CONFIG" in collector.assignments
|
|
assert "FINAL_GLOBAL" in collector.assignments
|
|
|
|
# Should not collect any assignments from inside functions
|
|
assert "sync_local" not in collector.assignments
|
|
assert "nested_var" not in collector.assignments
|
|
assert "async_local" not in collector.assignments
|
|
assert "deeply_nested" not in collector.assignments
|
|
|
|
|
|
def test_global_assignment_collector_mixed_async_sync_with_classes():
|
|
"""Test GlobalAssignmentCollector with async functions, sync functions, and classes."""
|
|
import libcst as cst
|
|
|
|
source_code = """
|
|
# Global assignments
|
|
GLOBAL_CONSTANT = "constant"
|
|
|
|
class TestClass:
|
|
# Class-level assignment - should not be collected
|
|
class_var = "class_value"
|
|
|
|
def sync_method(self):
|
|
# Method assignment - should not be collected
|
|
method_var = "method"
|
|
return method_var
|
|
|
|
async def async_method(self):
|
|
# Async method assignment - should not be collected
|
|
async_method_var = "async_method"
|
|
return async_method_var
|
|
|
|
def sync_function():
|
|
# Function assignment - should not be collected
|
|
func_var = "function"
|
|
return func_var
|
|
|
|
async def async_function():
|
|
# Async function assignment - should not be collected
|
|
async_func_var = "async_function"
|
|
return async_func_var
|
|
|
|
# More global assignments
|
|
ANOTHER_CONSTANT = 100
|
|
FINAL_ASSIGNMENT = {"data": "value"}
|
|
"""
|
|
|
|
tree = cst.parse_module(source_code)
|
|
collector = GlobalAssignmentCollector()
|
|
tree.visit(collector)
|
|
|
|
# Should only collect global-level assignments
|
|
assert len(collector.assignments) == 3
|
|
assert "GLOBAL_CONSTANT" in collector.assignments
|
|
assert "ANOTHER_CONSTANT" in collector.assignments
|
|
assert "FINAL_ASSIGNMENT" in collector.assignments
|
|
|
|
# Should not collect assignments from inside any scoped blocks
|
|
assert "class_var" not in collector.assignments
|
|
assert "method_var" not in collector.assignments
|
|
assert "async_method_var" not in collector.assignments
|
|
assert "func_var" not in collector.assignments
|
|
assert "async_func_var" not in collector.assignments
|
|
|
|
# Verify correct order
|
|
expected_order = ["GLOBAL_CONSTANT", "ANOTHER_CONSTANT", "FINAL_ASSIGNMENT"]
|
|
assert collector.assignment_order == expected_order
|
|
|
|
|
|
def test_global_assignment_collector_annotated_assignments():
|
|
"""Test GlobalAssignmentCollector correctly handles annotated assignments (AnnAssign)."""
|
|
import libcst as cst
|
|
|
|
source_code = """
|
|
# Regular global assignment
|
|
REGULAR_VAR = "regular"
|
|
|
|
# Annotated global assignments
|
|
TYPED_VAR: str = "typed"
|
|
CACHE: dict[str, int] = {}
|
|
SENTINEL: object = object()
|
|
|
|
# Annotated without value (type declaration only) - should NOT be collected
|
|
DECLARED_ONLY: int
|
|
|
|
def some_function():
|
|
# Annotated assignment inside function - should not be collected
|
|
local_typed: str = "local"
|
|
return local_typed
|
|
|
|
class SomeClass:
|
|
# Class-level annotated assignment - should not be collected
|
|
class_attr: str = "class"
|
|
|
|
# Another regular assignment
|
|
FINAL_VAR = 123
|
|
"""
|
|
|
|
tree = cst.parse_module(source_code)
|
|
collector = GlobalAssignmentCollector()
|
|
tree.visit(collector)
|
|
|
|
# Should collect both regular and annotated global assignments with values
|
|
assert len(collector.assignments) == 5
|
|
assert "REGULAR_VAR" in collector.assignments
|
|
assert "TYPED_VAR" in collector.assignments
|
|
assert "CACHE" in collector.assignments
|
|
assert "SENTINEL" in collector.assignments
|
|
assert "FINAL_VAR" in collector.assignments
|
|
|
|
# Should not collect type declarations without values
|
|
assert "DECLARED_ONLY" not in collector.assignments
|
|
|
|
# Should not collect assignments from inside functions or classes
|
|
assert "local_typed" not in collector.assignments
|
|
assert "class_attr" not in collector.assignments
|
|
|
|
# Verify correct order
|
|
expected_order = ["REGULAR_VAR", "TYPED_VAR", "CACHE", "SENTINEL", "FINAL_VAR"]
|
|
assert collector.assignment_order == expected_order
|
|
|
|
|
|
def test_class_instantiation_includes_init_as_helper(tmp_path: Path) -> None:
|
|
"""Test that when a class is instantiated, its __init__ method is tracked as a helper.
|
|
|
|
This test verifies the fix for the bug where class constructors were not
|
|
included in the context when only the class instantiation was called
|
|
(not any other methods). This caused LLMs to not know the constructor
|
|
signatures when generating tests.
|
|
"""
|
|
code = '''
|
|
class DataDumper:
|
|
"""A class that dumps data."""
|
|
|
|
def __init__(self, data):
|
|
"""Initialize with data."""
|
|
self.data = data
|
|
|
|
def dump(self):
|
|
"""Dump the data."""
|
|
return self.data
|
|
|
|
|
|
def target_function():
|
|
# Only instantiates DataDumper, doesn't call any other methods
|
|
dumper = DataDumper({"key": "value"})
|
|
return dumper
|
|
'''
|
|
file_path = tmp_path / "test_code.py"
|
|
file_path.write_text(code, encoding="utf-8")
|
|
opt = Optimizer(
|
|
Namespace(
|
|
project_root=file_path.parent.resolve(),
|
|
disable_telemetry=True,
|
|
tests_root="tests",
|
|
test_framework="pytest",
|
|
pytest_cmd="pytest",
|
|
experiment_id=None,
|
|
test_project_root=Path().resolve(),
|
|
)
|
|
)
|
|
function_to_optimize = FunctionToOptimize(
|
|
function_name="target_function", file_path=file_path, parents=[], starting_line=None, ending_line=None
|
|
)
|
|
|
|
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
|
|
|
|
# The __init__ method should be tracked as a helper since DataDumper() instantiates the class
|
|
qualified_names = {func.qualified_name for func in code_ctx.helper_functions}
|
|
assert "DataDumper.__init__" in qualified_names, (
|
|
"DataDumper.__init__ should be tracked as a helper when the class is instantiated"
|
|
)
|
|
|
|
# The testgen context should contain the class with __init__ (critical for LLM to know constructor)
|
|
testgen_context = code_ctx.testgen_context.markdown
|
|
assert "class DataDumper:" in testgen_context, "DataDumper class should be in testgen context"
|
|
assert "def __init__(self, data):" in testgen_context, "__init__ method should be included in testgen context"
|
|
|
|
# The hashing context should NOT contain __init__ (excluded for stability)
|
|
hashing_context = code_ctx.hashing_code_context
|
|
assert "__init__" not in hashing_context, "__init__ should NOT be in hashing context (excluded for hash stability)"
|
|
|
|
|
|
def test_class_instantiation_preserves_full_class_in_testgen(tmp_path: Path) -> None:
|
|
"""Test that instantiated classes are fully preserved in testgen context.
|
|
|
|
This is specifically for the unstructured LayoutDumper bug where helper classes
|
|
that were instantiated but had no other methods called were being excluded
|
|
from the testgen context.
|
|
"""
|
|
code = '''
|
|
class LayoutDumper:
|
|
"""Base class for layout dumpers."""
|
|
layout_source: str = "unknown"
|
|
|
|
def __init__(self, layout):
|
|
self._layout = layout
|
|
|
|
def dump(self) -> dict:
|
|
raise NotImplementedError()
|
|
|
|
|
|
class ObjectDetectionLayoutDumper(LayoutDumper):
|
|
"""Specific dumper for object detection layouts."""
|
|
|
|
def __init__(self, layout):
|
|
super().__init__(layout)
|
|
|
|
def dump(self) -> dict:
|
|
return {"type": "object_detection", "layout": self._layout}
|
|
|
|
|
|
def dump_layout(layout_type, layout):
|
|
"""Dump a layout based on its type."""
|
|
if layout_type == "object_detection":
|
|
dumper = ObjectDetectionLayoutDumper(layout)
|
|
else:
|
|
dumper = LayoutDumper(layout)
|
|
return dumper.dump()
|
|
'''
|
|
file_path = tmp_path / "test_code.py"
|
|
file_path.write_text(code, encoding="utf-8")
|
|
opt = Optimizer(
|
|
Namespace(
|
|
project_root=file_path.parent.resolve(),
|
|
disable_telemetry=True,
|
|
tests_root="tests",
|
|
test_framework="pytest",
|
|
pytest_cmd="pytest",
|
|
experiment_id=None,
|
|
test_project_root=Path().resolve(),
|
|
)
|
|
)
|
|
function_to_optimize = FunctionToOptimize(
|
|
function_name="dump_layout", file_path=file_path, parents=[], starting_line=None, ending_line=None
|
|
)
|
|
|
|
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
|
|
qualified_names = {func.qualified_name for func in code_ctx.helper_functions}
|
|
|
|
# Both class __init__ methods should be tracked as helpers
|
|
assert "ObjectDetectionLayoutDumper.__init__" in qualified_names, (
|
|
"ObjectDetectionLayoutDumper.__init__ should be tracked"
|
|
)
|
|
assert "LayoutDumper.__init__" in qualified_names, "LayoutDumper.__init__ should be tracked"
|
|
|
|
# The testgen context should include both classes with their __init__ methods
|
|
testgen_context = code_ctx.testgen_context.markdown
|
|
assert "class LayoutDumper:" in testgen_context, "LayoutDumper should be in testgen context"
|
|
assert "class ObjectDetectionLayoutDumper" in testgen_context, (
|
|
"ObjectDetectionLayoutDumper should be in testgen context"
|
|
)
|
|
|
|
# Both __init__ methods should be in the testgen context (so LLM knows constructor signatures)
|
|
assert testgen_context.count("def __init__") >= 2, "Both __init__ methods should be in testgen context"
|
|
|
|
|
|
def test_get_imported_class_definitions_extracts_project_classes(tmp_path: Path) -> None:
|
|
"""Test that get_imported_class_definitions extracts class definitions from project modules."""
|
|
# Create a package structure with two modules
|
|
package_dir = tmp_path / "mypackage"
|
|
package_dir.mkdir()
|
|
(package_dir / "__init__.py").write_text("", encoding="utf-8")
|
|
|
|
# Create a module with a class definition (simulating Element-like class)
|
|
elements_code = '''
|
|
import abc
|
|
|
|
class Element(abc.ABC):
|
|
"""An element in the document."""
|
|
|
|
def __init__(self, element_id: str = None):
|
|
self._element_id = element_id
|
|
self.text = ""
|
|
|
|
def __str__(self):
|
|
return self.text
|
|
|
|
|
|
class Text(Element):
|
|
"""A text element."""
|
|
|
|
def __init__(self, text: str, element_id: str = None):
|
|
super().__init__(element_id)
|
|
self.text = text
|
|
'''
|
|
elements_path = package_dir / "elements.py"
|
|
elements_path.write_text(elements_code, encoding="utf-8")
|
|
|
|
# Create another module that imports from elements
|
|
chunking_code = """
|
|
from mypackage.elements import Element
|
|
|
|
class PreChunk:
|
|
def __init__(self, elements: list[Element]):
|
|
self._elements = elements
|
|
|
|
class Accumulator:
|
|
def will_fit(self, chunk: PreChunk) -> bool:
|
|
return True
|
|
"""
|
|
chunking_path = package_dir / "chunking.py"
|
|
chunking_path.write_text(chunking_code, encoding="utf-8")
|
|
|
|
# Create CodeStringsMarkdown from the chunking module (simulating testgen context)
|
|
context = CodeStringsMarkdown(code_strings=[CodeString(code=chunking_code, file_path=chunking_path)])
|
|
|
|
# Call get_imported_class_definitions
|
|
result = get_imported_class_definitions(context, tmp_path)
|
|
|
|
# Verify Element class was extracted
|
|
assert len(result.code_strings) == 1, "Should extract exactly one class (Element)"
|
|
extracted_code = result.code_strings[0].code
|
|
|
|
# Verify the extracted code contains the Element class
|
|
assert "class Element" in extracted_code, "Should contain Element class definition"
|
|
assert "def __init__" in extracted_code, "Should contain __init__ method"
|
|
assert "element_id" in extracted_code, "Should contain constructor parameter"
|
|
assert "import abc" in extracted_code, "Should include necessary imports for base class"
|
|
|
|
|
|
def test_get_imported_class_definitions_skips_existing_definitions(tmp_path: Path) -> None:
|
|
"""Test that get_imported_class_definitions skips classes already defined in context."""
|
|
# Create a package structure
|
|
package_dir = tmp_path / "mypackage"
|
|
package_dir.mkdir()
|
|
(package_dir / "__init__.py").write_text("", encoding="utf-8")
|
|
|
|
# Create a module with a class definition
|
|
elements_code = """
|
|
class Element:
|
|
def __init__(self, text: str):
|
|
self.text = text
|
|
"""
|
|
elements_path = package_dir / "elements.py"
|
|
elements_path.write_text(elements_code, encoding="utf-8")
|
|
|
|
# Create code that imports Element but also redefines it locally
|
|
code_with_local_def = """
|
|
from mypackage.elements import Element
|
|
|
|
# Local redefinition (this happens when LLM redefines classes)
|
|
class Element:
|
|
def __init__(self, text: str):
|
|
self.text = text
|
|
|
|
class User:
|
|
def process(self, elem: Element):
|
|
pass
|
|
"""
|
|
code_path = package_dir / "user.py"
|
|
code_path.write_text(code_with_local_def, encoding="utf-8")
|
|
|
|
context = CodeStringsMarkdown(code_strings=[CodeString(code=code_with_local_def, file_path=code_path)])
|
|
|
|
# Call get_imported_class_definitions
|
|
result = get_imported_class_definitions(context, tmp_path)
|
|
|
|
# Should NOT extract Element since it's already defined locally
|
|
assert len(result.code_strings) == 0, "Should not extract classes already defined in context"
|
|
|
|
|
|
def test_get_imported_class_definitions_skips_third_party(tmp_path: Path) -> None:
|
|
"""Test that get_imported_class_definitions skips third-party/stdlib imports."""
|
|
# Create a simple package
|
|
package_dir = tmp_path / "mypackage"
|
|
package_dir.mkdir()
|
|
(package_dir / "__init__.py").write_text("", encoding="utf-8")
|
|
|
|
# Code with stdlib/third-party imports
|
|
code = """
|
|
from pathlib import Path
|
|
from typing import Optional
|
|
from dataclasses import dataclass
|
|
|
|
class MyClass:
|
|
def __init__(self, path: Path):
|
|
self.path = path
|
|
"""
|
|
code_path = package_dir / "main.py"
|
|
code_path.write_text(code, encoding="utf-8")
|
|
|
|
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
|
|
|
|
# Call get_imported_class_definitions
|
|
result = get_imported_class_definitions(context, tmp_path)
|
|
|
|
# Should not extract any classes (Path, Optional, dataclass are stdlib/third-party)
|
|
assert len(result.code_strings) == 0, "Should not extract stdlib/third-party classes"
|
|
|
|
|
|
def test_get_imported_class_definitions_handles_multiple_imports(tmp_path: Path) -> None:
|
|
"""Test that get_imported_class_definitions handles multiple class imports."""
|
|
# Create a package structure
|
|
package_dir = tmp_path / "mypackage"
|
|
package_dir.mkdir()
|
|
(package_dir / "__init__.py").write_text("", encoding="utf-8")
|
|
|
|
# Create a module with multiple class definitions
|
|
types_code = """
|
|
class TypeA:
|
|
def __init__(self, value: int):
|
|
self.value = value
|
|
|
|
class TypeB:
|
|
def __init__(self, name: str):
|
|
self.name = name
|
|
|
|
class TypeC:
|
|
def __init__(self):
|
|
pass
|
|
"""
|
|
types_path = package_dir / "types.py"
|
|
types_path.write_text(types_code, encoding="utf-8")
|
|
|
|
# Create code that imports multiple classes
|
|
code = """
|
|
from mypackage.types import TypeA, TypeB
|
|
|
|
class Processor:
|
|
def process(self, a: TypeA, b: TypeB):
|
|
pass
|
|
"""
|
|
code_path = package_dir / "processor.py"
|
|
code_path.write_text(code, encoding="utf-8")
|
|
|
|
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
|
|
|
|
# Call get_imported_class_definitions
|
|
result = get_imported_class_definitions(context, tmp_path)
|
|
|
|
# Should extract both TypeA and TypeB (but not TypeC since it's not imported)
|
|
assert len(result.code_strings) == 2, "Should extract exactly two classes (TypeA, TypeB)"
|
|
|
|
all_extracted_code = "\n".join(cs.code for cs in result.code_strings)
|
|
assert "class TypeA" in all_extracted_code, "Should contain TypeA class"
|
|
assert "class TypeB" in all_extracted_code, "Should contain TypeB class"
|
|
assert "class TypeC" not in all_extracted_code, "Should NOT contain TypeC (not imported)"
|
|
|
|
|
|
def test_get_imported_class_definitions_includes_dataclass_decorators(tmp_path: Path) -> None:
|
|
"""Test that get_imported_class_definitions includes decorators when extracting dataclasses."""
|
|
# Create a package structure
|
|
package_dir = tmp_path / "mypackage"
|
|
package_dir.mkdir()
|
|
(package_dir / "__init__.py").write_text("", encoding="utf-8")
|
|
|
|
# Create a module with dataclass definitions (like LLMConfig in skyvern)
|
|
models_code = """from dataclasses import dataclass, field
|
|
from typing import Optional
|
|
|
|
@dataclass(frozen=True)
|
|
class LLMConfigBase:
|
|
model_name: str
|
|
required_env_vars: list[str]
|
|
supports_vision: bool
|
|
add_assistant_prefix: bool
|
|
|
|
@dataclass(frozen=True)
|
|
class LLMConfig(LLMConfigBase):
|
|
litellm_params: Optional[dict] = field(default=None)
|
|
max_tokens: int | None = None
|
|
"""
|
|
models_path = package_dir / "models.py"
|
|
models_path.write_text(models_code, encoding="utf-8")
|
|
|
|
# Create code that imports the dataclass
|
|
code = """from mypackage.models import LLMConfig
|
|
|
|
class ConfigRegistry:
|
|
def get_config(self) -> LLMConfig:
|
|
pass
|
|
"""
|
|
code_path = package_dir / "registry.py"
|
|
code_path.write_text(code, encoding="utf-8")
|
|
|
|
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
|
|
|
|
# Call get_imported_class_definitions
|
|
result = get_imported_class_definitions(context, tmp_path)
|
|
|
|
# Should extract both LLMConfigBase (base class) and LLMConfig
|
|
assert len(result.code_strings) == 2, "Should extract both LLMConfig and its base class LLMConfigBase"
|
|
|
|
# Combine extracted code to check for all required elements
|
|
all_extracted_code = "\n".join(cs.code for cs in result.code_strings)
|
|
|
|
# Verify the base class is extracted first (for proper inheritance understanding)
|
|
base_class_idx = all_extracted_code.find("class LLMConfigBase")
|
|
derived_class_idx = all_extracted_code.find("class LLMConfig(")
|
|
assert base_class_idx < derived_class_idx, "Base class should appear before derived class"
|
|
|
|
# Verify both classes include @dataclass decorators
|
|
assert all_extracted_code.count("@dataclass(frozen=True)") == 2, (
|
|
"Should include @dataclass decorator for both classes"
|
|
)
|
|
assert "class LLMConfig" in all_extracted_code, "Should contain LLMConfig class definition"
|
|
assert "class LLMConfigBase" in all_extracted_code, "Should contain LLMConfigBase class definition"
|
|
|
|
# Verify imports are included for dataclass-related items
|
|
assert "from dataclasses import" in all_extracted_code, "Should include dataclasses import"
|
|
|
|
|
|
def test_get_imported_class_definitions_extracts_imports_for_decorated_classes(tmp_path: Path) -> None:
|
|
"""Test that extract_imports_for_class includes decorator and type annotation imports."""
|
|
# Create a package structure
|
|
package_dir = tmp_path / "mypackage"
|
|
package_dir.mkdir()
|
|
(package_dir / "__init__.py").write_text("", encoding="utf-8")
|
|
|
|
# Create a module with decorated class that uses field() and various type annotations
|
|
models_code = """from dataclasses import dataclass, field
|
|
from typing import Optional, List
|
|
|
|
@dataclass
|
|
class Config:
|
|
name: str
|
|
values: List[int] = field(default_factory=list)
|
|
description: Optional[str] = None
|
|
"""
|
|
models_path = package_dir / "models.py"
|
|
models_path.write_text(models_code, encoding="utf-8")
|
|
|
|
# Create code that imports the class
|
|
code = """from mypackage.models import Config
|
|
|
|
def create_config() -> Config:
|
|
return Config(name="test")
|
|
"""
|
|
code_path = package_dir / "main.py"
|
|
code_path.write_text(code, encoding="utf-8")
|
|
|
|
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
|
|
|
|
result = get_imported_class_definitions(context, tmp_path)
|
|
|
|
assert len(result.code_strings) == 1, "Should extract Config class"
|
|
extracted_code = result.code_strings[0].code
|
|
|
|
# The extracted code should include the decorator
|
|
assert "@dataclass" in extracted_code, "Should include @dataclass decorator"
|
|
# The imports should include dataclass and field
|
|
assert "from dataclasses import" in extracted_code, "Should include dataclasses import for decorator"
|
|
|
|
|
|
class TestCollectNamesFromAnnotation:
|
|
"""Tests for the collect_names_from_annotation helper function."""
|
|
|
|
def test_simple_name(self):
|
|
"""Test extracting a simple type name."""
|
|
import ast
|
|
|
|
code = "def f(x: MyClass): pass"
|
|
annotation = ast.parse(code).body[0].args.args[0].annotation
|
|
names: set[str] = set()
|
|
collect_names_from_annotation(annotation, names)
|
|
assert "MyClass" in names
|
|
|
|
def test_subscript_type(self):
|
|
"""Test extracting names from generic types like List[int]."""
|
|
import ast
|
|
|
|
code = "def f(x: List[int]): pass"
|
|
annotation = ast.parse(code).body[0].args.args[0].annotation
|
|
names: set[str] = set()
|
|
collect_names_from_annotation(annotation, names)
|
|
assert "List" in names
|
|
assert "int" in names
|
|
|
|
def test_optional_type(self):
|
|
"""Test extracting names from Optional[MyClass]."""
|
|
import ast
|
|
|
|
code = "def f(x: Optional[MyClass]): pass"
|
|
annotation = ast.parse(code).body[0].args.args[0].annotation
|
|
names: set[str] = set()
|
|
collect_names_from_annotation(annotation, names)
|
|
assert "Optional" in names
|
|
assert "MyClass" in names
|
|
|
|
def test_union_type_with_pipe(self):
|
|
"""Test extracting names from union types with | syntax."""
|
|
import ast
|
|
|
|
code = "def f(x: int | str | None): pass"
|
|
annotation = ast.parse(code).body[0].args.args[0].annotation
|
|
names: set[str] = set()
|
|
collect_names_from_annotation(annotation, names)
|
|
# int | str | None becomes BinOp nodes
|
|
assert "int" in names
|
|
assert "str" in names
|
|
|
|
def test_nested_generic_types(self):
|
|
"""Test extracting names from nested generics like Dict[str, List[MyClass]]."""
|
|
import ast
|
|
|
|
code = "def f(x: Dict[str, List[MyClass]]): pass"
|
|
annotation = ast.parse(code).body[0].args.args[0].annotation
|
|
names: set[str] = set()
|
|
collect_names_from_annotation(annotation, names)
|
|
assert "Dict" in names
|
|
assert "str" in names
|
|
assert "List" in names
|
|
assert "MyClass" in names
|
|
|
|
def test_tuple_annotation(self):
|
|
"""Test extracting names from tuple type hints."""
|
|
import ast
|
|
|
|
code = "def f(x: tuple[int, str, MyClass]): pass"
|
|
annotation = ast.parse(code).body[0].args.args[0].annotation
|
|
names: set[str] = set()
|
|
collect_names_from_annotation(annotation, names)
|
|
assert "tuple" in names
|
|
assert "int" in names
|
|
assert "str" in names
|
|
assert "MyClass" in names
|
|
|
|
|
|
class TestExtractImportsForClass:
|
|
"""Tests for the extract_imports_for_class helper function."""
|
|
|
|
def test_extracts_base_class_imports(self):
|
|
"""Test that base class imports are extracted."""
|
|
import ast
|
|
|
|
module_source = """from abc import ABC
|
|
from mypackage import BaseClass
|
|
|
|
class MyClass(BaseClass, ABC):
|
|
pass
|
|
"""
|
|
tree = ast.parse(module_source)
|
|
class_node = next(n for n in ast.walk(tree) if isinstance(n, ast.ClassDef))
|
|
result = extract_imports_for_class(tree, class_node, module_source)
|
|
assert "from abc import ABC" in result
|
|
assert "from mypackage import BaseClass" in result
|
|
|
|
def test_extracts_decorator_imports(self):
|
|
"""Test that decorator imports are extracted."""
|
|
import ast
|
|
|
|
module_source = """from dataclasses import dataclass
|
|
from functools import lru_cache
|
|
|
|
@dataclass
|
|
class MyClass:
|
|
name: str
|
|
"""
|
|
tree = ast.parse(module_source)
|
|
class_node = next(n for n in ast.walk(tree) if isinstance(n, ast.ClassDef))
|
|
result = extract_imports_for_class(tree, class_node, module_source)
|
|
assert "from dataclasses import dataclass" in result
|
|
|
|
def test_extracts_type_annotation_imports(self):
|
|
"""Test that type annotation imports are extracted."""
|
|
import ast
|
|
|
|
module_source = """from typing import Optional, List
|
|
from mypackage.models import Config
|
|
|
|
@dataclass
|
|
class MyClass:
|
|
config: Optional[Config]
|
|
items: List[str]
|
|
"""
|
|
tree = ast.parse(module_source)
|
|
class_node = next(n for n in ast.walk(tree) if isinstance(n, ast.ClassDef))
|
|
result = extract_imports_for_class(tree, class_node, module_source)
|
|
assert "from typing import Optional, List" in result
|
|
assert "from mypackage.models import Config" in result
|
|
|
|
def test_extracts_field_function_imports(self):
|
|
"""Test that field() function imports are extracted for dataclasses."""
|
|
import ast
|
|
|
|
module_source = """from dataclasses import dataclass, field
|
|
from typing import List
|
|
|
|
@dataclass
|
|
class MyClass:
|
|
items: List[str] = field(default_factory=list)
|
|
"""
|
|
tree = ast.parse(module_source)
|
|
class_node = next(n for n in ast.walk(tree) if isinstance(n, ast.ClassDef))
|
|
result = extract_imports_for_class(tree, class_node, module_source)
|
|
assert "from dataclasses import dataclass, field" in result
|
|
|
|
def test_no_duplicate_imports(self):
|
|
"""Test that duplicate imports are not included."""
|
|
import ast
|
|
|
|
module_source = """from typing import Optional
|
|
|
|
@dataclass
|
|
class MyClass:
|
|
field1: Optional[str]
|
|
field2: Optional[int]
|
|
"""
|
|
tree = ast.parse(module_source)
|
|
class_node = next(n for n in ast.walk(tree) if isinstance(n, ast.ClassDef))
|
|
result = extract_imports_for_class(tree, class_node, module_source)
|
|
# Should only have one import line even though Optional is used twice
|
|
assert result.count("from typing import Optional") == 1
|
|
|
|
|
|
def test_get_imported_class_definitions_multiple_decorators(tmp_path: Path) -> None:
|
|
"""Test that classes with multiple decorators are extracted correctly."""
|
|
package_dir = tmp_path / "mypackage"
|
|
package_dir.mkdir()
|
|
(package_dir / "__init__.py").write_text("", encoding="utf-8")
|
|
|
|
models_code = """from dataclasses import dataclass
|
|
from functools import total_ordering
|
|
|
|
@total_ordering
|
|
@dataclass
|
|
class OrderedConfig:
|
|
name: str
|
|
priority: int
|
|
|
|
def __lt__(self, other):
|
|
return self.priority < other.priority
|
|
"""
|
|
models_path = package_dir / "models.py"
|
|
models_path.write_text(models_code, encoding="utf-8")
|
|
|
|
code = """from mypackage.models import OrderedConfig
|
|
|
|
def sort_configs(configs: list[OrderedConfig]) -> list[OrderedConfig]:
|
|
return sorted(configs)
|
|
"""
|
|
code_path = package_dir / "main.py"
|
|
code_path.write_text(code, encoding="utf-8")
|
|
|
|
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
|
|
|
|
result = get_imported_class_definitions(context, tmp_path)
|
|
|
|
assert len(result.code_strings) == 1
|
|
extracted_code = result.code_strings[0].code
|
|
|
|
# Both decorators should be included
|
|
assert "@total_ordering" in extracted_code, "Should include @total_ordering decorator"
|
|
assert "@dataclass" in extracted_code, "Should include @dataclass decorator"
|
|
assert "class OrderedConfig" in extracted_code
|
|
|
|
|
|
def test_get_imported_class_definitions_extracts_multilevel_inheritance(tmp_path: Path) -> None:
|
|
"""Test that base classes are recursively extracted for multi-level inheritance.
|
|
|
|
This is critical for understanding dataclass constructor signatures, as fields
|
|
from parent classes become required positional arguments in child classes.
|
|
"""
|
|
# Create a package structure
|
|
package_dir = tmp_path / "mypackage"
|
|
package_dir.mkdir()
|
|
(package_dir / "__init__.py").write_text("", encoding="utf-8")
|
|
|
|
# Create a module with multi-level inheritance like skyvern's LLM models:
|
|
# GrandParent -> Parent -> Child
|
|
models_code = '''from dataclasses import dataclass, field
|
|
from typing import Optional, Literal
|
|
|
|
@dataclass(frozen=True)
|
|
class GrandParentConfig:
|
|
"""Base config with common fields."""
|
|
model_name: str
|
|
required_env_vars: list[str]
|
|
|
|
@dataclass(frozen=True)
|
|
class ParentConfig(GrandParentConfig):
|
|
"""Intermediate config adding vision support."""
|
|
supports_vision: bool
|
|
add_assistant_prefix: bool
|
|
|
|
@dataclass(frozen=True)
|
|
class ChildConfig(ParentConfig):
|
|
"""Full config with optional parameters."""
|
|
litellm_params: Optional[dict] = field(default=None)
|
|
max_tokens: int | None = None
|
|
temperature: float | None = 0.7
|
|
|
|
@dataclass(frozen=True)
|
|
class RouterConfig(ParentConfig):
|
|
"""Router config branching from ParentConfig."""
|
|
model_list: list
|
|
main_model_group: str
|
|
routing_strategy: Literal["simple", "least-busy"] = "simple"
|
|
'''
|
|
models_path = package_dir / "models.py"
|
|
models_path.write_text(models_code, encoding="utf-8")
|
|
|
|
# Create code that imports only the child classes (not the base classes)
|
|
code = """from mypackage.models import ChildConfig, RouterConfig
|
|
|
|
class ConfigRegistry:
|
|
def get_child_config(self) -> ChildConfig:
|
|
pass
|
|
|
|
def get_router_config(self) -> RouterConfig:
|
|
pass
|
|
"""
|
|
code_path = package_dir / "registry.py"
|
|
code_path.write_text(code, encoding="utf-8")
|
|
|
|
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
|
|
|
|
# Call get_imported_class_definitions
|
|
result = get_imported_class_definitions(context, tmp_path)
|
|
|
|
# Should extract 4 classes: GrandParentConfig, ParentConfig, ChildConfig, RouterConfig
|
|
# (all classes needed to understand the full inheritance hierarchy)
|
|
assert len(result.code_strings) == 4, (
|
|
f"Should extract 4 classes (GrandParent, Parent, Child, Router), got {len(result.code_strings)}"
|
|
)
|
|
|
|
# Combine extracted code
|
|
all_extracted_code = "\n".join(cs.code for cs in result.code_strings)
|
|
|
|
# Verify all classes are extracted
|
|
assert "class GrandParentConfig" in all_extracted_code, "Should extract GrandParentConfig base class"
|
|
assert "class ParentConfig(GrandParentConfig)" in all_extracted_code, "Should extract ParentConfig"
|
|
assert "class ChildConfig(ParentConfig)" in all_extracted_code, "Should extract ChildConfig"
|
|
assert "class RouterConfig(ParentConfig)" in all_extracted_code, "Should extract RouterConfig"
|
|
|
|
# Verify classes are ordered correctly (base classes before derived)
|
|
grandparent_idx = all_extracted_code.find("class GrandParentConfig")
|
|
parent_idx = all_extracted_code.find("class ParentConfig(")
|
|
child_idx = all_extracted_code.find("class ChildConfig(")
|
|
router_idx = all_extracted_code.find("class RouterConfig(")
|
|
|
|
assert grandparent_idx < parent_idx, "GrandParentConfig should appear before ParentConfig"
|
|
assert parent_idx < child_idx, "ParentConfig should appear before ChildConfig"
|
|
assert parent_idx < router_idx, "ParentConfig should appear before RouterConfig"
|
|
|
|
# Verify the critical fields are visible for constructor understanding
|
|
assert "model_name: str" in all_extracted_code, "Should include model_name field from GrandParent"
|
|
assert "required_env_vars: list[str]" in all_extracted_code, "Should include required_env_vars field"
|
|
assert "supports_vision: bool" in all_extracted_code, "Should include supports_vision field from Parent"
|
|
assert "litellm_params:" in all_extracted_code, "Should include litellm_params field from Child"
|
|
assert "model_list: list" in all_extracted_code, "Should include model_list field from Router"
|