codeflash/tests/test_get_helper_code.py
Kevin Turcios eceac13fc3 Merge remote-tracking branch 'origin/main' into omni-java
# Conflicts:
#	.claude/rules/architecture.md
#	.claude/rules/code-style.md
#	.github/workflows/claude.yml
#	.github/workflows/duplicate-code-detector.yml
#	codeflash/api/aiservice.py
#	codeflash/cli_cmds/console.py
#	codeflash/cli_cmds/logging_config.py
#	codeflash/code_utils/deduplicate_code.py
#	codeflash/discovery/discover_unit_tests.py
#	codeflash/languages/base.py
#	codeflash/languages/code_replacer.py
#	codeflash/languages/javascript/mocha_runner.py
#	codeflash/languages/javascript/support.py
#	codeflash/languages/python/support.py
#	codeflash/optimization/function_optimizer.py
#	codeflash/verification/parse_test_output.py
#	codeflash/verification/verification_utils.py
#	codeflash/verification/verifier.py
#	packages/codeflash/package-lock.json
#	packages/codeflash/package.json
#	tests/languages/javascript/test_support_dispatch.py
#	tests/test_codeflash_capture.py
#	tests/test_languages/test_javascript_test_runner.py
#	tests/test_multi_file_code_replacement.py
2026-03-04 01:52:32 -05:00

445 lines
15 KiB
Python

import tempfile
from argparse import Namespace
from pathlib import Path
import pytest
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.either import is_successful
from codeflash.languages.python.function_optimizer import PythonFunctionOptimizer
from codeflash.models.models import FunctionParent, get_code_block_splitter
from codeflash.optimization.optimizer import Optimizer
from codeflash.verification.verification_utils import TestConfig
class HelperClass:
    """Tiny fixture class whose single method is called by ``OptimizeMe``."""

    def helper_method(self, a, b, c):
        """Return ``a + b + c``, combining the operands left to right."""
        partial = a + b
        return partial + c
def OptimizeMe(a, b, c):
    """Pass ``a``, ``b`` and ``c`` to ``HelperClass.helper_method`` and return its result."""
    helper = HelperClass()
    return helper.helper_method(a, b, c)
@pytest.mark.skip
def test_get_outside_method_helper() -> None:
    """Request an optimization context for ``OptimizeMe`` defined in this very file.

    The test is currently skipped. When run, it only checks that the context
    lookup reports success; nothing is asserted about the context's contents.
    """
    file_path = Path(__file__).resolve()
    opt = Optimizer(
        Namespace(
            project_root=str(file_path.parent.resolve()),
            disable_telemetry=True,
            tests_root="tests",
            test_framework="pytest",
            pytest_cmd="pytest",
            experiment_id=None,
        )
    )
    function_to_optimize = FunctionToOptimize(
        function_name="OptimizeMe", file_path=file_path, parents=[], starting_line=None, ending_line=None
    )
    # Explicit encoding keeps the read independent of the platform default.
    original_code = file_path.read_text(encoding="utf-8")
    ctx_result = opt.get_code_optimization_context(function_to_optimize, opt.args.project_root, original_code)
    if not is_successful(ctx_result):
        pytest.fail("get_code_optimization_context did not return a successful result")
    # Unwrap the successful result; its contents are not inspected by this test.
    ctx_result.unwrap()
def test_flavio_typed_code_helper() -> None:
    """Check the context gathered for ``_PersistentCache.__call__`` in a typed module.

    The module source below is written to a temporary file, the optimizer is
    pointed at ``__call__`` (parent class ``_PersistentCache``), and the test
    asserts that:

    * the first helper found is ``AbstractCacheBackend.get_cache_or_call``;
    * the flattened test-generation context equals the expected text, which
      keeps ``__init__`` and ``__call__`` of ``_PersistentCache`` but omits
      ``cache_clear`` and ``no_cache_call``.
    """
    # NOTE(review): the embedded source in both literals below appears without
    # indentation or blank lines in this view -- confirm against the upstream
    # file before relying on the exact text. It is reproduced here unchanged.
    code = '''
_P = ParamSpec("_P")
_KEY_T = TypeVar("_KEY_T")
_STORE_T = TypeVar("_STORE_T")
class AbstractCacheBackend(CacheBackend, Protocol[_KEY_T, _STORE_T]):
"""Interface for cache backends used by the persistent cache decorator."""
def __init__(self) -> None: ...
def hash_key(
self,
*,
func: Callable[_P, Any],
args: tuple[Any, ...],
kwargs: dict[str, Any],
) -> tuple[str, _KEY_T]: ...
def encode(self, *, data: Any) -> _STORE_T: # noqa: ANN401
...
def decode(self, *, data: _STORE_T) -> Any: # noqa: ANN401
...
def get(self, *, key: tuple[str, _KEY_T]) -> tuple[datetime.datetime, _STORE_T] | None: ...
def delete(self, *, key: tuple[str, _KEY_T]) -> None: ...
def put(self, *, key: tuple[str, _KEY_T], data: _STORE_T) -> None: ...
def get_cache_or_call(
self,
*,
func: Callable[_P, Any],
args: tuple[Any, ...],
kwargs: dict[str, Any],
lifespan: datetime.timedelta,
) -> Any: # noqa: ANN401
"""
Retrieve the cached results for a function call.
Args:
----
func (Callable[..., _R]): The function to retrieve cached results for.
args (tuple[Any, ...]): The positional arguments passed to the function.
kwargs (dict[str, Any]): The keyword arguments passed to the function.
lifespan (datetime.timedelta): The maximum age of the cached results.
Returns:
-------
_R: The cached results, if available.
"""
if os.environ.get("NO_CACHE"):
return func(*args, **kwargs)
try:
key = self.hash_key(func=func, args=args, kwargs=kwargs)
except: # noqa: E722
# If we can't create a cache key, we should just call the function.
logging.warning("Failed to hash cache key for function: %s", func)
return func(*args, **kwargs)
result_pair = self.get(key=key)
if result_pair is not None:
cached_time, result = result_pair
if not os.environ.get("RE_CACHE") and (
datetime.datetime.now() < (cached_time + lifespan) # noqa: DTZ005
):
try:
return self.decode(data=result)
except CacheBackendDecodeError as e:
logging.warning("Failed to decode cache data: %s", e)
# If decoding fails we will treat this as a cache miss.
# This might happens if underlying class definition of the data changes.
self.delete(key=key)
result = func(*args, **kwargs)
try:
self.put(key=key, data=self.encode(data=result))
except CacheBackendEncodeError as e:
logging.warning("Failed to encode cache data: %s", e)
# If encoding fails, we should still return the result.
return result
_P = ParamSpec("_P")
_R = TypeVar("_R")
_CacheBackendT = TypeVar("_CacheBackendT", bound=CacheBackend)
class _PersistentCache(Generic[_P, _R, _CacheBackendT]):
"""
A decorator class that provides persistent caching functionality for a function.
Args:
----
func (Callable[_P, _R]): The function to be decorated.
duration (datetime.timedelta): The duration for which the cached results should be considered valid.
backend (_backend): The backend storage for the cached results.
Attributes:
----------
__wrapped__ (Callable[_P, _R]): The wrapped function.
__duration__ (datetime.timedelta): The duration for which the cached results should be considered valid.
__backend__ (_backend): The backend storage for the cached results.
""" # noqa: E501
__wrapped__: Callable[_P, _R]
__duration__: datetime.timedelta
__backend__: _CacheBackendT
def __init__(
self,
func: Callable[_P, _R],
duration: datetime.timedelta,
) -> None:
self.__wrapped__ = func
self.__duration__ = duration
self.__backend__ = AbstractCacheBackend()
functools.update_wrapper(self, func)
def cache_clear(self) -> None:
"""Clears the cache for the wrapped function."""
self.__backend__.del_func_cache(func=self.__wrapped__)
def no_cache_call(self, *args: _P.args, **kwargs: _P.kwargs) -> _R:
"""
Calls the wrapped function without using the cache.
Args:
----
*args (_P.args): Positional arguments for the wrapped function.
**kwargs (_P.kwargs): Keyword arguments for the wrapped function.
Returns:
-------
_R: The result of the wrapped function.
"""
return self.__wrapped__(*args, **kwargs)
def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R:
"""
Calls the wrapped function, either using the cache or bypassing it based on environment variables.
Args:
----
*args (_P.args): Positional arguments for the wrapped function.
**kwargs (_P.kwargs): Keyword arguments for the wrapped function.
Returns:
-------
_R: The result of the wrapped function.
""" # noqa: E501
if "NO_CACHE" in os.environ:
return self.__wrapped__(*args, **kwargs)
os.makedirs(DEFAULT_CACHE_LOCATION, exist_ok=True)
return self.__backend__.get_cache_or_call(
func=self.__wrapped__,
args=args,
kwargs=kwargs,
lifespan=self.__duration__,
)
'''
    with tempfile.TemporaryDirectory() as tempdir:
        tempdir_path = Path(tempdir)
        # Materialize the module on disk so the optimizer can read it by path.
        file_path = (tempdir_path / "typed_code_helper.py").resolve()
        file_path.write_text(code, encoding="utf-8")
        # The temporary directory doubles as the project root.
        project_root_path = tempdir_path.resolve()
        function_to_optimize = FunctionToOptimize(
            function_name="__call__",
            file_path=file_path,
            parents=[FunctionParent(name="_PersistentCache", type="ClassDef")],
            starting_line=None,
            ending_line=None,
        )
        test_config = TestConfig(
            tests_root="tests",
            tests_project_rootdir=Path.cwd(),
            project_root_path=project_root_path,
            test_framework="pytest",
            pytest_cmd="pytest",
        )
        func_optimizer = PythonFunctionOptimizer(function_to_optimize=function_to_optimize, test_cfg=test_config)
        ctx_result = func_optimizer.get_code_optimization_context()
        if not is_successful(ctx_result):
            pytest.fail("get_code_optimization_context did not return a successful result")
        code_context = ctx_result.unwrap()
        assert code_context.helper_functions[0].qualified_name == "AbstractCacheBackend.get_cache_or_call"
        # The expected text starts with a "# file:" header holding the module
        # path relative to the project root.
        assert (
            code_context.testgen_context.flat
            == f'''# file: {file_path.relative_to(project_root_path)}
_P = ParamSpec("_P")
_KEY_T = TypeVar("_KEY_T")
_STORE_T = TypeVar("_STORE_T")
class AbstractCacheBackend(CacheBackend, Protocol[_KEY_T, _STORE_T]):
"""Interface for cache backends used by the persistent cache decorator."""
def __init__(self) -> None: ...
def hash_key(
self,
*,
func: Callable[_P, Any],
args: tuple[Any, ...],
kwargs: dict[str, Any],
) -> tuple[str, _KEY_T]: ...
def encode(self, *, data: Any) -> _STORE_T: # noqa: ANN401
...
def decode(self, *, data: _STORE_T) -> Any: # noqa: ANN401
...
def get(self, *, key: tuple[str, _KEY_T]) -> tuple[datetime.datetime, _STORE_T] | None: ...
def delete(self, *, key: tuple[str, _KEY_T]) -> None: ...
def put(self, *, key: tuple[str, _KEY_T], data: _STORE_T) -> None: ...
def get_cache_or_call(
self,
*,
func: Callable[_P, Any],
args: tuple[Any, ...],
kwargs: dict[str, Any],
lifespan: datetime.timedelta,
) -> Any: # noqa: ANN401
"""
Retrieve the cached results for a function call.
Args:
----
func (Callable[..., _R]): The function to retrieve cached results for.
args (tuple[Any, ...]): The positional arguments passed to the function.
kwargs (dict[str, Any]): The keyword arguments passed to the function.
lifespan (datetime.timedelta): The maximum age of the cached results.
Returns:
-------
_R: The cached results, if available.
"""
if os.environ.get("NO_CACHE"):
return func(*args, **kwargs)
try:
key = self.hash_key(func=func, args=args, kwargs=kwargs)
except: # noqa: E722
# If we can't create a cache key, we should just call the function.
logging.warning("Failed to hash cache key for function: %s", func)
return func(*args, **kwargs)
result_pair = self.get(key=key)
if result_pair is not None:
cached_time, result = result_pair
if not os.environ.get("RE_CACHE") and (
datetime.datetime.now() < (cached_time + lifespan) # noqa: DTZ005
):
try:
return self.decode(data=result)
except CacheBackendDecodeError as e:
logging.warning("Failed to decode cache data: %s", e)
# If decoding fails we will treat this as a cache miss.
# This might happens if underlying class definition of the data changes.
self.delete(key=key)
result = func(*args, **kwargs)
try:
self.put(key=key, data=self.encode(data=result))
except CacheBackendEncodeError as e:
logging.warning("Failed to encode cache data: %s", e)
# If encoding fails, we should still return the result.
return result
_P = ParamSpec("_P")
_R = TypeVar("_R")
_CacheBackendT = TypeVar("_CacheBackendT", bound=CacheBackend)
class _PersistentCache(Generic[_P, _R, _CacheBackendT]):
"""
A decorator class that provides persistent caching functionality for a function.
Args:
----
func (Callable[_P, _R]): The function to be decorated.
duration (datetime.timedelta): The duration for which the cached results should be considered valid.
backend (_backend): The backend storage for the cached results.
Attributes:
----------
__wrapped__ (Callable[_P, _R]): The wrapped function.
__duration__ (datetime.timedelta): The duration for which the cached results should be considered valid.
__backend__ (_backend): The backend storage for the cached results.
""" # noqa: E501
__wrapped__: Callable[_P, _R]
__duration__: datetime.timedelta
__backend__: _CacheBackendT
def __init__(
self,
func: Callable[_P, _R],
duration: datetime.timedelta,
) -> None:
self.__wrapped__ = func
self.__duration__ = duration
self.__backend__ = AbstractCacheBackend()
functools.update_wrapper(self, func)
def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R:
"""
Calls the wrapped function, either using the cache or bypassing it based on environment variables.
Args:
----
*args (_P.args): Positional arguments for the wrapped function.
**kwargs (_P.kwargs): Keyword arguments for the wrapped function.
Returns:
-------
_R: The result of the wrapped function.
""" # noqa: E501
if "NO_CACHE" in os.environ:
return self.__wrapped__(*args, **kwargs)
os.makedirs(DEFAULT_CACHE_LOCATION, exist_ok=True)
return self.__backend__.get_cache_or_call(
func=self.__wrapped__,
args=args,
kwargs=kwargs,
lifespan=self.__duration__,
)
'''
        )
def test_bubble_sort_deps() -> None:
    """Check the context gathered for ``sorter_deps`` and its two imported helpers.

    The target lives in ``code_to_optimize/bubble_sort_deps.py``, located two
    path components above this file. The test asserts that the flattened
    test-generation context contains one block per source file (each introduced
    by ``get_code_block_splitter``) and that exactly two helper functions are
    reported, ``dep1_comparer`` first and ``dep2_swap`` second.
    """
    file_path = (Path(__file__) / ".." / ".." / "code_to_optimize" / "bubble_sort_deps.py").resolve()
    function_to_optimize = FunctionToOptimize(
        function_name="sorter_deps", file_path=file_path, parents=[], starting_line=None, ending_line=None
    )
    # The project root is the parent of the code_to_optimize directory.
    project_root = file_path.parent.parent.resolve()
    test_config = TestConfig(
        tests_root=str(file_path.parent / "tests"),
        tests_project_rootdir=file_path.parent.resolve(),
        project_root_path=project_root,
        test_framework="pytest",
        pytest_cmd="pytest",
    )
    func_optimizer = PythonFunctionOptimizer(function_to_optimize=function_to_optimize, test_cfg=test_config)
    ctx_result = func_optimizer.get_code_optimization_context()
    if not is_successful(ctx_result):
        pytest.fail("get_code_optimization_context did not return a successful result")
    code_context = ctx_result.unwrap()
    # NOTE(review): the expected source below appears without indentation or
    # blank lines in this view -- confirm against the upstream file before
    # relying on the exact text. It is reproduced here unchanged.
    assert (
        code_context.testgen_context.flat
        == f"""{get_code_block_splitter(Path("code_to_optimize/bubble_sort_dep1_helper.py"))}
def dep1_comparer(arr, j: int) -> bool:
return arr[j] > arr[j + 1]
{get_code_block_splitter(Path("code_to_optimize/bubble_sort_dep2_swap.py"))}
def dep2_swap(arr, j):
temp = arr[j]
arr[j] = arr[j + 1]
arr[j + 1] = temp
{get_code_block_splitter(Path("code_to_optimize/bubble_sort_deps.py"))}
from code_to_optimize.bubble_sort_dep1_helper import dep1_comparer
from code_to_optimize.bubble_sort_dep2_swap import dep2_swap
def sorter_deps(arr):
for i in range(len(arr)):
for j in range(len(arr) - 1):
if dep1_comparer(arr, j):
dep2_swap(arr, j)
return arr
"""
    )
    assert len(code_context.helper_functions) == 2
    assert (
        code_context.helper_functions[0].fully_qualified_name
        == "code_to_optimize.bubble_sort_dep1_helper.dep1_comparer"
    )
    assert code_context.helper_functions[1].fully_qualified_name == "code_to_optimize.bubble_sort_dep2_swap.dep2_swap"