2024-04-09 12:36:43 +00:00
|
|
|
from __future__ import annotations
|
2025-07-31 13:52:11 +00:00
|
|
|
import re
|
2025-06-06 19:30:30 +00:00
|
|
|
import libcst as cst
|
2025-06-14 00:27:45 +00:00
|
|
|
from codeflash.code_utils.code_replacer import AutouseFixtureModifier, PytestMarkAdder, AddRequestArgument
|
2024-07-09 23:32:23 +00:00
|
|
|
import dataclasses
|
2023-12-30 02:37:49 +00:00
|
|
|
import os
|
2024-07-09 23:32:23 +00:00
|
|
|
from collections import defaultdict
|
2024-06-09 12:30:06 +00:00
|
|
|
from pathlib import Path
|
|
|
|
|
|
2025-01-08 22:56:53 +00:00
|
|
|
from codeflash.code_utils.code_extractor import delete___future___aliased_imports, find_preexisting_objects
|
2024-07-10 03:38:36 +00:00
|
|
|
from codeflash.code_utils.code_replacer import (
|
|
|
|
|
is_zero_diff,
|
|
|
|
|
replace_functions_and_add_imports,
|
|
|
|
|
replace_functions_in_file,
|
|
|
|
|
)
|
2024-10-29 23:39:47 +00:00
|
|
|
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
2025-08-06 00:33:46 +00:00
|
|
|
from codeflash.models.models import CodeOptimizationContext, CodeStringsMarkdown, FunctionParent
|
2025-02-13 08:10:53 +00:00
|
|
|
from codeflash.optimization.function_optimizer import FunctionOptimizer
|
|
|
|
|
from codeflash.verification.verification_utils import TestConfig
|
2023-12-30 02:37:49 +00:00
|
|
|
|
# Provide a dummy API key so the codeflash machinery exercised by these tests
# does not fail its key check; the value is never sent anywhere real.
os.environ["CODEFLASH_API_KEY"] = "cf-test-key"
|
|
|
|
|
|
2023-12-30 02:37:49 +00:00
|
|
|
|
2024-07-09 23:32:23 +00:00
|
|
|
@dataclasses.dataclass
class JediDefinition:
    # Minimal stand-in for a jedi definition object; only exposes the
    # ``type`` string (field name kept for interface compatibility even
    # though it shadows the ``type`` builtin).
    type: str
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclasses.dataclass
class FakeFunctionSource:
    """Lightweight test double mirroring the shape of a FunctionSource record."""

    # Absolute path of the module that defines the function.
    file_path: Path
    # Name qualified within the module (e.g. "Class.method").
    qualified_name: str
    # Fully qualified name — presumably module path + qualified name; confirm against real FunctionSource.
    fully_qualified_name: str
    # Bare function name without any parents.
    only_function_name: str
    # Source text of the function.
    source_code: str
    # Fake jedi definition carrying the definition ``type`` string.
    jedi_definition: JediDefinition
|
|
|
|
|
|
|
|
|
|
|
2025-04-30 01:34:40 +00:00
|
|
|
class Args:
    # Minimal stand-in for the CLI argument namespace assigned to
    # ``func_optimizer.args``: import sorting and code formatting are
    # disabled so tests can compare the rewritten file verbatim.
    disable_imports_sorting = True
    formatter_cmds = ["disabled"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_code_replacement_global_statements():
    """New module-level statements in the optimized code (``inconsequential_var``)
    must survive replace_function_and_helpers_with_optimized_code and appear in
    the rewritten file on disk.
    """
    project_root = Path(__file__).parent.parent.resolve()
    # Work on a throwaway copy so the original bubble_sort.py fixture is untouched.
    code_path = (project_root / "code_to_optimize/bubble_sort_optimized.py").resolve()
    optimized_code = f"""```python:{code_path.relative_to(project_root)}
import numpy as np

inconsequential_var = '123'
def sorter(arr):
    return arr.sort()
```
"""
    original_code_str = (Path(__file__).parent.resolve() / "../code_to_optimize/bubble_sort.py").read_text(
        encoding="utf-8"
    )
    code_path.write_text(original_code_str, encoding="utf-8")
    # BUGFIX: tests_root used to be a hard-coded absolute path from a developer
    # machine ("/Users/codeflash/Downloads/codeflash-dev/..."); derive it from
    # the repository layout instead so the test is portable.
    tests_root = (project_root / "code_to_optimize/tests/pytest").resolve()
    project_root_path = (Path(__file__).parent / "..").resolve()
    func = FunctionToOptimize(function_name="sorter", parents=[], file_path=code_path)
    test_config = TestConfig(
        tests_root=tests_root,
        tests_project_rootdir=project_root_path,
        project_root_path=project_root_path,
        test_framework="pytest",
        pytest_cmd="pytest",
    )
    func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
    code_context: CodeOptimizationContext = func_optimizer.get_code_optimization_context().unwrap()
    # Snapshot helper sources so the replacement call can restore/diff them.
    original_helper_code: dict[Path, str] = {}
    helper_function_paths = {hf.file_path for hf in code_context.helper_functions}
    for helper_function_path in helper_function_paths:
        with helper_function_path.open(encoding="utf-8") as f:
            original_helper_code[helper_function_path] = f.read()
    func_optimizer.args = Args()
    func_optimizer.replace_function_and_helpers_with_optimized_code(
        code_context=code_context,
        optimized_code=CodeStringsMarkdown.parse_markdown_code(optimized_code),
        original_helper_code=original_helper_code,
    )
    final_output = code_path.read_text(encoding="utf-8")
    # The new global assignment from the optimized snippet must be preserved.
    assert "inconsequential_var = '123'" in final_output
    code_path.unlink(missing_ok=True)  # clean up the temporary file
|
|
|
|
|
|
|
|
|
|
|
2024-03-13 09:43:25 +00:00
|
|
|
def test_test_libcst_code_replacement() -> None:
    """Replacing ``NewClass.new_function``: the method body (and its stale
    ``@staticmethod`` decorator) is swapped for the optimized version, the new
    ``new_function2`` method is added to the class, and the brand-new top-level
    ``totally_new_function`` is appended before the module's trailing statements.
    """
    optim_code = """import libcst as cst
from typing import Optional

def totally_new_function(value):
    return value

class NewClass:
    def __init__(self, name):
        self.name = name
    def new_function(self, value):
        return self.name
    def new_function2(value):
        return value
"""
    original_code = """class NewClass:
    def __init__(self, name):
        self.name = name
    @staticmethod
    def new_function(self, value):
        return "I am still old"

print("Hello world")
"""
    expected = """class NewClass:
    def __init__(self, name):
        self.name = name
    def new_function(self, value):
        return self.name
    def new_function2(value):
        return value

def totally_new_function(value):
    return value

print("Hello world")
"""
    function_name: str = "NewClass.new_function"
    preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] = find_preexisting_objects(original_code)
    # BUGFIX: removed a leftover debug print of ``preexisting_objects`` that
    # cluttered test output and served no assertion.
    new_code: str = replace_functions_and_add_imports(
        source_code=original_code,
        function_names=[function_name],
        optimized_code=optim_code,
        module_abspath=Path(__file__).resolve(),
        preexisting_objects=preexisting_objects,
        project_root_path=Path(__file__).resolve().parent.resolve(),
    )
    assert new_code == expected
|
|
|
|
|
|
|
|
|
|
|
2024-03-13 09:43:25 +00:00
|
|
|
def test_test_libcst_code_replacement2() -> None:
    """Replacing a method whose optimized body calls ``other_function``, a name
    the original module already imports: the method body is swapped, and the
    optimized code's new top-level definitions (including its own
    ``other_function``) are appended before the trailing print."""
    optim_code = """import libcst as cst
from typing import Optional

def totally_new_function(value):
    return value

def other_function(st):
    return(st * 2)

class NewClass:
    def __init__(self, name):
        self.name = name
    def new_function(self, value):
        return other_function(self.name)
    def new_function2(value):
        return value
"""
    original_code = """from OtherModule import other_function

class NewClass:
    def __init__(self, name):
        self.name = name
    def new_function(self, value):
        return other_function("I am still old")

print("Hello world")
"""
    expected = """from OtherModule import other_function

class NewClass:
    def __init__(self, name):
        self.name = name
    def new_function(self, value):
        return other_function(self.name)
    def new_function2(value):
        return value

def totally_new_function(value):
    return value

def other_function(st):
    return(st * 2)

print("Hello world")
"""
    function_name: str = "NewClass.new_function"
    preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] = find_preexisting_objects(original_code)
    new_code: str = replace_functions_and_add_imports(
        source_code=original_code,
        function_names=[function_name],
        optimized_code=optim_code,
        module_abspath=Path(__file__).resolve(),
        preexisting_objects=preexisting_objects,
        project_root_path=Path(__file__).resolve().parent.resolve(),
    )
    assert new_code == expected
|
|
|
|
|
|
|
|
|
|
|
2024-03-13 09:43:25 +00:00
|
|
|
def test_test_libcst_code_replacement3() -> None:
    """Replacing the top-level ``other_function`` only: its body is updated,
    ``totally_new_function`` is appended, and the expected output no longer
    carries the ``import libcst as cst`` line that the original module had."""
    optim_code = """import libcst as cst
from typing import Optional

def totally_new_function(value):
    return value

def other_function(st):
    return(st * 2)

class NewClass:
    def __init__(self, name):
        self.name = name
    def new_function(self, value: cst.Name):
        return other_function(self.name)
    def new_function2(value):
        return value
"""
    original_code = """import libcst as cst
from typing import Mandatory

print("Au revoir")

def yet_another_function(values):
    return len(values)

def other_function(st):
    return(st + st)

print("Salut monde")
"""
    expected = """from typing import Mandatory

print("Au revoir")

def yet_another_function(values):
    return len(values)

def other_function(st):
    return(st * 2)

def totally_new_function(value):
    return value

print("Salut monde")
"""
    function_names: list[str] = ["other_function"]
    preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] = find_preexisting_objects(original_code)
    new_code: str = replace_functions_and_add_imports(
        source_code=original_code,
        function_names=function_names,
        optimized_code=optim_code,
        module_abspath=Path(__file__).resolve(),
        preexisting_objects=preexisting_objects,
        project_root_path=Path(__file__).resolve().parent.resolve(),
    )
    assert new_code == expected
|
|
|
|
|
|
|
|
|
|
|
2024-03-13 09:43:25 +00:00
|
|
|
def test_test_libcst_code_replacement4() -> None:
    """Replacing two top-level functions at once: both bodies are updated, the
    original (unannotated) ``yet_another_function`` signature is kept even
    though the optimized version annotates it, and ``totally_new_function`` is
    appended."""
    optim_code = """import libcst as cst
from typing import Optional

def totally_new_function(value):
    return value

def yet_another_function(values: Optional[str]):
    return len(values) + 2

def other_function(st):
    return(st * 2)

class NewClass:
    def __init__(self, name):
        self.name = name
    def new_function(self, value):
        return other_function(self.name)
    def new_function2(value):
        return value
"""
    original_code = """import libcst as cst
from typing import Mandatory

print("Au revoir")

def yet_another_function(values):
    return len(values)

def other_function(st):
    return(st + st)

print("Salut monde")
"""
    expected = """from typing import Mandatory

print("Au revoir")

def yet_another_function(values):
    return len(values) + 2

def other_function(st):
    return(st * 2)

def totally_new_function(value):
    return value

print("Salut monde")
"""
    function_names: list[str] = ["yet_another_function", "other_function"]
    preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] = find_preexisting_objects(original_code)
    new_code: str = replace_functions_and_add_imports(
        source_code=original_code,
        function_names=function_names,
        optimized_code=optim_code,
        module_abspath=Path(__file__).resolve(),
        preexisting_objects=preexisting_objects,
        project_root_path=Path(__file__).resolve().parent.resolve(),
    )
    assert new_code == expected
|
2024-02-08 23:52:49 +00:00
|
|
|
|
|
|
|
|
|
2024-03-13 09:43:25 +00:00
|
|
|
def test_test_libcst_code_replacement5() -> None:
    """Replacing ``sorter_deps``: the ``@lru_cache(17)`` decorator from the
    optimized code is applied, the original unannotated signature is kept, the
    original imports are preserved, and the new helper functions are appended."""
    optim_code = """@lru_cache(17)
def sorter_deps(arr: list[int]) -> list[int]:
    supersort(badsort(arr))
    return arr

def badsort(ploc):
    donothing(ploc)

def supersort(doink):
    for i in range(len(doink)):
        fix(doink, i)
"""
    original_code = """from code_to_optimize.bubble_sort_dep1_helper import dep1_comparer
from code_to_optimize.bubble_sort_dep2_swap import dep2_swap

def sorter_deps(arr):
    for i in range(len(arr)):
        for j in range(len(arr) - 1):
            if dep1_comparer(arr, j):
                dep2_swap(arr, j)
    return arr
"""
    expected = """from code_to_optimize.bubble_sort_dep1_helper import dep1_comparer
from code_to_optimize.bubble_sort_dep2_swap import dep2_swap

@lru_cache(17)
def sorter_deps(arr):
    supersort(badsort(arr))
    return arr

def badsort(ploc):
    donothing(ploc)

def supersort(doink):
    for i in range(len(doink)):
        fix(doink, i)
"""
    function_names: list[str] = ["sorter_deps"]
    preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] = find_preexisting_objects(original_code)
    new_code: str = replace_functions_and_add_imports(
        source_code=original_code,
        function_names=function_names,
        optimized_code=optim_code,
        module_abspath=Path(__file__).resolve(),
        preexisting_objects=preexisting_objects,
        project_root_path=Path(__file__).resolve().parent.resolve(),
    )
    assert new_code == expected
|
2024-02-11 06:50:27 +00:00
|
|
|
|
|
|
|
|
|
2024-03-13 09:43:25 +00:00
|
|
|
def test_test_libcst_code_replacement6() -> None:
    """The same optimized snippet drives replacements in two different modules:
    ``other_function`` in the main module and ``blob`` in the helper module.
    Each file only picks up the body for its own function; ``preexisting_objects``
    is the union across both files."""
    optim_code = """import libcst as cst
from typing import Optional

def other_function(st):
    return(st * blob(st))

def blob(st):
    return(st * 2)
"""
    original_code_main = """import libcst as cst
from typing import Mandatory
from helper import blob

print("Au revoir")

def yet_another_function(values):
    return len(values)

def other_function(st):
    return(st + blob(st))

print("Salut monde")
"""
    original_code_helper = """import numpy as np

print("Cool")

def blob(values):
    return len(values)

def blab(st):
    return(st + st)

print("Not cool")
"""
    expected_main = """from typing import Mandatory
from helper import blob

print("Au revoir")

def yet_another_function(values):
    return len(values)

def other_function(st):
    return(st * blob(st))

print("Salut monde")
"""
    # NOTE: blob's replacement takes the optimized body verbatim, so the body
    # references ``st`` even though the kept signature names the parameter
    # ``values`` — this fixture documents that behavior.
    expected_helper = """import numpy as np

print("Cool")

def blob(values):
    return(st * 2)

def blab(st):
    return(st + st)

print("Not cool")
"""
    preexisting_objects = find_preexisting_objects(original_code_main) | find_preexisting_objects(original_code_helper)
    new_main_code: str = replace_functions_and_add_imports(
        source_code=original_code_main,
        function_names=["other_function"],
        optimized_code=optim_code,
        module_abspath=Path(__file__).resolve(),
        preexisting_objects=preexisting_objects,
        project_root_path=Path(__file__).resolve().parent.resolve(),
    )
    assert new_main_code == expected_main

    new_helper_code: str = replace_functions_and_add_imports(
        source_code=original_code_helper,
        function_names=["blob"],
        optimized_code=optim_code,
        module_abspath=Path(__file__).resolve(),
        preexisting_objects=preexisting_objects,
        project_root_path=Path(__file__).resolve().parent.resolve(),
    )
    assert new_helper_code == expected_helper
|
2024-02-14 02:35:02 +00:00
|
|
|
|
|
|
|
|
|
2024-03-13 09:43:25 +00:00
|
|
|
def test_test_libcst_code_replacement7() -> None:
    """Replacing a ``@staticmethod`` inside one of several decorated classes:
    only ``CacheSimilarityEvalConfig.from_config`` changes; the sibling
    ``CacheInitConfig`` and ``CacheConfig`` classes are reproduced untouched."""
    optim_code = """@register_deserializable
class CacheSimilarityEvalConfig(BaseConfig):

    def __init__(
        self,
        strategy: Optional[str] = "distance",
        max_distance: Optional[float] = 1.0,
        positive: Optional[bool] = False,
    ):
        self.strategy = strategy
        self.max_distance = max_distance
        self.positive = positive

    @staticmethod
    def from_config(config: Optional[dict[str, Any]]):
        if config is None:
            return CacheSimilarityEvalConfig()

        strategy = config.get("strategy", "distance")
        max_distance = config.get("max_distance", 1.0)
        positive = config.get("positive", False)

        return CacheSimilarityEvalConfig(strategy, max_distance, positive)
"""
    original_code = """from typing import Any, Optional

from embedchain.config.base_config import BaseConfig
from embedchain.helpers.json_serializable import register_deserializable


@register_deserializable
class CacheSimilarityEvalConfig(BaseConfig):

    def __init__(
        self,
        strategy: Optional[str] = "distance",
        max_distance: Optional[float] = 1.0,
        positive: Optional[bool] = False,
    ):
        self.strategy = strategy
        self.max_distance = max_distance
        self.positive = positive

    @staticmethod
    def from_config(config: Optional[dict[str, Any]]):
        if config is None:
            return CacheSimilarityEvalConfig()
        else:
            return CacheSimilarityEvalConfig(
                strategy=config.get("strategy", "distance"),
                max_distance=config.get("max_distance", 1.0),
                positive=config.get("positive", False),
            )


@register_deserializable
class CacheInitConfig(BaseConfig):

    def __init__(
        self,
        similarity_threshold: Optional[float] = 0.8,
        auto_flush: Optional[int] = 20,
    ):
        if similarity_threshold < 0 or similarity_threshold > 1:
            raise ValueError(f"similarity_threshold {similarity_threshold} should be between 0 and 1")

        self.similarity_threshold = similarity_threshold
        self.auto_flush = auto_flush

    @staticmethod
    def from_config(config: Optional[dict[str, Any]]):
        if config is None:
            return CacheInitConfig()
        else:
            return CacheInitConfig(
                similarity_threshold=config.get("similarity_threshold", 0.8),
                auto_flush=config.get("auto_flush", 20),
            )


@register_deserializable
class CacheConfig(BaseConfig):

    def __init__(
        self,
        similarity_eval_config: Optional[CacheSimilarityEvalConfig] = CacheSimilarityEvalConfig(),
        init_config: Optional[CacheInitConfig] = CacheInitConfig(),
    ):
        self.similarity_eval_config = similarity_eval_config
        self.init_config = init_config

    @staticmethod
    def from_config(config: Optional[dict[str, Any]]):
        if config is None:
            return CacheConfig()
        else:
            return CacheConfig(
                similarity_eval_config=CacheSimilarityEvalConfig.from_config(config.get("similarity_evaluation", {})),
                init_config=CacheInitConfig.from_config(config.get("init_config", {})),
            )
"""
    expected = """from typing import Any, Optional

from embedchain.config.base_config import BaseConfig
from embedchain.helpers.json_serializable import register_deserializable


@register_deserializable
class CacheSimilarityEvalConfig(BaseConfig):

    def __init__(
        self,
        strategy: Optional[str] = "distance",
        max_distance: Optional[float] = 1.0,
        positive: Optional[bool] = False,
    ):
        self.strategy = strategy
        self.max_distance = max_distance
        self.positive = positive

    @staticmethod
    def from_config(config: Optional[dict[str, Any]]):
        if config is None:
            return CacheSimilarityEvalConfig()

        strategy = config.get("strategy", "distance")
        max_distance = config.get("max_distance", 1.0)
        positive = config.get("positive", False)

        return CacheSimilarityEvalConfig(strategy, max_distance, positive)


@register_deserializable
class CacheInitConfig(BaseConfig):

    def __init__(
        self,
        similarity_threshold: Optional[float] = 0.8,
        auto_flush: Optional[int] = 20,
    ):
        if similarity_threshold < 0 or similarity_threshold > 1:
            raise ValueError(f"similarity_threshold {similarity_threshold} should be between 0 and 1")

        self.similarity_threshold = similarity_threshold
        self.auto_flush = auto_flush

    @staticmethod
    def from_config(config: Optional[dict[str, Any]]):
        if config is None:
            return CacheInitConfig()
        else:
            return CacheInitConfig(
                similarity_threshold=config.get("similarity_threshold", 0.8),
                auto_flush=config.get("auto_flush", 20),
            )


@register_deserializable
class CacheConfig(BaseConfig):

    def __init__(
        self,
        similarity_eval_config: Optional[CacheSimilarityEvalConfig] = CacheSimilarityEvalConfig(),
        init_config: Optional[CacheInitConfig] = CacheInitConfig(),
    ):
        self.similarity_eval_config = similarity_eval_config
        self.init_config = init_config

    @staticmethod
    def from_config(config: Optional[dict[str, Any]]):
        if config is None:
            return CacheConfig()
        else:
            return CacheConfig(
                similarity_eval_config=CacheSimilarityEvalConfig.from_config(config.get("similarity_evaluation", {})),
                init_config=CacheInitConfig.from_config(config.get("init_config", {})),
            )
"""
    function_names: list[str] = ["CacheSimilarityEvalConfig.from_config"]
    preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] = find_preexisting_objects(original_code)

    new_code: str = replace_functions_and_add_imports(
        source_code=original_code,
        function_names=function_names,
        optimized_code=optim_code,
        module_abspath=Path(__file__).resolve(),
        preexisting_objects=preexisting_objects,
        project_root_path=Path(__file__).resolve().parent.resolve(),
    )
    assert new_code == expected
|
2024-03-13 09:43:25 +00:00
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_test_libcst_code_replacement8() -> None:
    """Replacing ``_EmbeddingDistanceChainMixin._hamming_distance``: the new
    ``@staticmethod`` decorator and optimized body are installed while the
    nested ``Config`` class is left alone. Fixtures use ''' quoting because
    they contain triple-double-quoted docstrings."""
    optim_code = '''class _EmbeddingDistanceChainMixin(Chain):
    @staticmethod
    def _hamming_distance(a: np.ndarray, b: np.ndarray) -> np.floating:
        """Compute the Hamming distance between two vectors.

        Args:
            a (np.ndarray): The first vector.
            b (np.ndarray): The second vector.

        Returns:
            np.floating: The Hamming distance.
        """
        return np.sum(a != b) / a.size
'''
    original_code = '''class _EmbeddingDistanceChainMixin(Chain):

    class Config:
        """Permit embeddings to go unvalidated."""

        arbitrary_types_allowed: bool = True


    def _hamming_distance(a: np.ndarray, b: np.ndarray) -> np.floating:
        """Compute the Hamming distance between two vectors.

        Args:
            a (np.ndarray): The first vector.
            b (np.ndarray): The second vector.

        Returns:
            np.floating: The Hamming distance.
        """
        return np.mean(a != b)
'''
    expected = '''class _EmbeddingDistanceChainMixin(Chain):

    class Config:
        """Permit embeddings to go unvalidated."""

        arbitrary_types_allowed: bool = True


    @staticmethod
    def _hamming_distance(a: np.ndarray, b: np.ndarray) -> np.floating:
        """Compute the Hamming distance between two vectors.

        Args:
            a (np.ndarray): The first vector.
            b (np.ndarray): The second vector.

        Returns:
            np.floating: The Hamming distance.
        """
        return np.sum(a != b) / a.size
'''
    function_names: list[str] = ["_EmbeddingDistanceChainMixin._hamming_distance"]
    preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] = find_preexisting_objects(original_code)
    new_code: str = replace_functions_and_add_imports(
        source_code=original_code,
        function_names=function_names,
        optimized_code=optim_code,
        module_abspath=Path(__file__).resolve(),
        preexisting_objects=preexisting_objects,
        project_root_path=Path(__file__).resolve().parent.resolve(),
    )
    assert new_code == expected
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_test_libcst_code_replacement9() -> None:
    """Replacing a dunder method (``NewClass.__init__``): only ``__init__``'s
    body changes, ``__call__`` keeps its original body, while the new
    ``new_function2`` method, ``totally_new_function``, and the imports they
    need are added to the module."""
    optim_code = """import libcst as cst
from typing import Optional

def totally_new_function(value: Optional[str]):
    return value

class NewClass:
    def __init__(self, name):
        self.name = str(name)
    def __call__(self, value):
        return self.name
    def new_function2(value):
        return cst.ensure_type(value, str)
"""
    original_code = """class NewClass:
    def __init__(self, name):
        self.name = name
    def __call__(self, value):
        return "I am still old"

print("Hello world")
"""
    expected = """import libcst as cst
from typing import Optional

class NewClass:
    def __init__(self, name):
        self.name = str(name)
    def __call__(self, value):
        return "I am still old"
    def new_function2(value):
        return cst.ensure_type(value, str)

def totally_new_function(value: Optional[str]):
    return value

print("Hello world")
"""
    function_name: str = "NewClass.__init__"
    preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] = find_preexisting_objects(original_code)
    new_code: str = replace_functions_and_add_imports(
        source_code=original_code,
        function_names=[function_name],
        optimized_code=optim_code,
        module_abspath=Path(__file__).resolve(),
        preexisting_objects=preexisting_objects,
        project_root_path=Path(__file__).resolve().parent.resolve(),
    )
    assert new_code == expected
|
2024-06-09 12:30:06 +00:00
|
|
|
|
|
|
|
|
|
|
|
|
|
# Module-level fixture class used by test_code_replacement10: the context
# extractor is expected to pick up helper_method (called from
# MainClass.main_method) but not innocent_bystander.
# NOTE: keep this body comment-free — its extracted source is string-compared
# against the get_code_output fixture in test_code_replacement10.
class HelperClass:
    def __init__(self, name):
        self.name = name

    def innocent_bystander(self):
        pass

    def helper_method(self):
        return self.name
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Companion fixture for test_code_replacement10: main_method is the
# optimization target and instantiates HelperClass, so both classes end up in
# the extracted context. Keep comment-free (see note on HelperClass usage in
# test_code_replacement10).
class MainClass:
    def __init__(self, name):
        self.name = name

    def main_method(self):
        return HelperClass(self.name).helper_method()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_code_replacement10() -> None:
|
2025-05-01 01:14:00 +00:00
|
|
|
get_code_output = """from __future__ import annotations
|
|
|
|
|
|
|
|
|
|
class HelperClass:
|
|
|
|
|
def __init__(self, name):
|
|
|
|
|
self.name = name
|
|
|
|
|
|
|
|
|
|
def helper_method(self):
|
|
|
|
|
return self.name
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MainClass:
|
|
|
|
|
def __init__(self, name):
|
|
|
|
|
self.name = name
|
|
|
|
|
|
|
|
|
|
def main_method(self):
|
2025-06-06 05:40:09 +00:00
|
|
|
return HelperClass(self.name).helper_method()
|
|
|
|
|
"""
|
2024-06-09 12:30:06 +00:00
|
|
|
file_path = Path(__file__).resolve()
|
2024-06-17 02:17:45 +00:00
|
|
|
func_top_optimize = FunctionToOptimize(
|
2024-10-25 22:45:44 +00:00
|
|
|
function_name="main_method", file_path=file_path, parents=[FunctionParent("MainClass", "ClassDef")]
|
2024-06-17 02:17:45 +00:00
|
|
|
)
|
2025-02-13 08:10:53 +00:00
|
|
|
test_config = TestConfig(
|
|
|
|
|
tests_root=file_path.parent,
|
|
|
|
|
tests_project_rootdir=file_path.parent,
|
|
|
|
|
project_root_path=file_path.parent,
|
|
|
|
|
test_framework="pytest",
|
|
|
|
|
pytest_cmd="pytest",
|
|
|
|
|
)
|
|
|
|
|
func_optimizer = FunctionOptimizer(function_to_optimize=func_top_optimize, test_cfg=test_config)
|
|
|
|
|
code_context = func_optimizer.get_code_optimization_context().unwrap()
|
2025-09-25 00:31:05 +00:00
|
|
|
assert code_context.testgen_context.rstrip() == get_code_output.rstrip()
|
2024-06-18 01:27:13 +00:00
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_code_replacement11() -> None:
    # Only "Fu.foo" is named for replacement: foo must receive the optimized body
    # (real_bar() + 1), while real_bar — which the optimized source degrades to
    # `pass` — must keep its original `return 0` body.
    optim_code = '''class Fu():
    def foo(self) -> dict[str, str]:
        payload: dict[str, str] = {"bar": self.bar(), "real_bar": str(self.real_bar() + 1)}
        return payload

    def real_bar(self) -> int:
        """No abstract nonsense"""
        pass
'''
    original_code = '''class Fu():
    def foo(self) -> dict[str, str]:
        payload: dict[str, str] = {"bar": self.bar(), "real_bar": str(self.real_bar())}
        return payload

    def real_bar(self) -> int:
        """No abstract nonsense"""
        return 0
'''
    expected_code = '''class Fu():
    def foo(self) -> dict[str, str]:
        payload: dict[str, str] = {"bar": self.bar(), "real_bar": str(self.real_bar() + 1)}
        return payload

    def real_bar(self) -> int:
        """No abstract nonsense"""
        return 0
'''

    function_name: str = "Fu.foo"
    parents = (FunctionParent("Fu", "ClassDef"),)
    # Both methods pre-exist in the original, so neither may be added as a new definition.
    preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] = {("foo", parents), ("real_bar", parents)}
    new_code: str = replace_functions_in_file(
        source_code=original_code,
        original_function_names=[function_name],
        optimized_code=optim_code,
        preexisting_objects=preexisting_objects,
    )
    assert new_code == expected_code
|
2024-06-23 02:39:15 +00:00
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_code_replacement12() -> None:
    """Replace only the explicitly named method (Fu.real_bar).

    The optimized source changes both methods of ``Fu``, but only ``"Fu.real_bar"``
    appears in ``original_function_names``; ``Fu.foo`` must keep its original
    (un-optimized) body while ``real_bar`` takes the optimized ``pass`` body.
    """
    optim_code = '''class Fu():
    def foo(self) -> dict[str, str]:
        payload: dict[str, str] = {"bar": self.bar(), "real_bar": str(self.real_bar() + 1)}
        return payload

    def real_bar(self) -> int:
        """No abstract nonsense"""
        pass
'''
    original_code = '''class Fu():
    def foo(self) -> dict[str, str]:
        payload: dict[str, str] = {"bar": self.bar(), "real_bar": str(self.real_bar())}
        return payload

    def real_bar(self) -> int:
        """No abstract nonsense"""
        return 0
'''
    expected_code = '''class Fu():
    def foo(self) -> dict[str, str]:
        payload: dict[str, str] = {"bar": self.bar(), "real_bar": str(self.real_bar())}
        return payload

    def real_bar(self) -> int:
        """No abstract nonsense"""
        pass
'''

    # Was `= []` — a list that contradicted the `set[...]` annotation. An empty set
    # is interchangeable with an empty list for truthiness and membership tests.
    preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] = set()
    new_code: str = replace_functions_in_file(
        source_code=original_code,
        original_function_names=["Fu.real_bar"],
        optimized_code=optim_code,
        preexisting_objects=preexisting_objects,
    )
    assert new_code == expected_code
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_test_libcst_code_replacement13() -> None:
    """Dunder methods must never be replaced.

    The optimized source modifies ``__init__`` and ``__call__``, but neither of the
    requested function names exists in the target module, and dunders are protected
    from replacement — so the output must be byte-identical to ``original_code``.
    """
    # Test if the dunder method is not modified
    optim_code = """class NewClass:
    def __init__(self, name):
        self.name = name
        self.new_attribute = "Sorry i modified a dunder method"
    def new_function(self, value):
        return other_function(self.name)
    def new_function2(value):
        return value
    def __call__(self, value):
        return self.new_attribute
"""

    original_code = """class NewClass:
    def __init__(self, name):
        self.name = name
        self.new_attribute = "Sorry i modified a dunder method"
    def new_function(self, value):
        return other_function(self.name)
    def new_function2(value):
        return value
    def __call__(self, value):
        return self.name
"""

    function_names: list[str] = ["yet_another_function", "other_function"]
    # Was `= []` — a list that contradicted the `set[...]` annotation (test11 uses a
    # real set). Empty set and empty list behave identically for the callee.
    preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] = set()
    new_code: str = replace_functions_and_add_imports(
        source_code=original_code,
        function_names=function_names,
        optimized_code=optim_code,
        module_abspath=Path(__file__).resolve(),
        preexisting_objects=preexisting_objects,
        project_root_path=Path(__file__).resolve().parent.resolve(),
    )
    assert new_code == original_code
|
2024-07-09 23:32:23 +00:00
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_different_class_code_replacement():
    # Only TestResults.get_test_pass_fail_report_by_type is named for replacement.
    # The optimized source also rewrites to_name, __eq__ and reorders imports; all
    # of that must be ignored — every other member survives byte-for-byte.
    # NOTE(review): reconstructed from a blame-mangled dump; blank-line placement
    # inside the fixture strings is inferred from the dump's line spacing — confirm
    # against the repository copy before relying on exact string equality.
    original_code = """from __future__ import annotations
import sys
from codeflash.verification.comparator import comparator
from enum import Enum
from pydantic import BaseModel
from typing import Iterator

class TestType(Enum):
    EXISTING_UNIT_TEST = 1
    INSPIRED_REGRESSION = 2
    GENERATED_REGRESSION = 3
    REPLAY_TEST = 4

    def to_name(self) -> str:
        names = {
            TestType.EXISTING_UNIT_TEST: "⚙️ Existing Unit Tests",
            TestType.INSPIRED_REGRESSION: "🎨 Inspired Regression Tests",
            TestType.GENERATED_REGRESSION: "🌀 Generated Regression Tests",
            TestType.REPLAY_TEST: "⏪ Replay Tests",
        }
        return names[self]

class TestResults(BaseModel):
    def __iter__(self) -> Iterator[FunctionTestInvocation]:
        return iter(self.test_results)
    def __len__(self) -> int:
        return len(self.test_results)
    def __getitem__(self, index: int) -> FunctionTestInvocation:
        return self.test_results[index]
    def __setitem__(self, index: int, value: FunctionTestInvocation) -> None:
        self.test_results[index] = value
    def __delitem__(self, index: int) -> None:
        del self.test_results[index]
    def __contains__(self, value: FunctionTestInvocation) -> bool:
        return value in self.test_results
    def __bool__(self) -> bool:
        return bool(self.test_results)
    def __eq__(self, other: object) -> bool:
        # Unordered comparison
        if type(self) != type(other):
            return False
        if len(self) != len(other):
            return False
        original_recursion_limit = sys.getrecursionlimit()
        for test_result in self:
            other_test_result = other.get_by_id(test_result.id)
            if other_test_result is None:
                return False

            if original_recursion_limit < 5000:
                sys.setrecursionlimit(5000)
            if (
                test_result.file_name != other_test_result.file_name
                or test_result.did_pass != other_test_result.did_pass
                or test_result.runtime != other_test_result.runtime
                or test_result.test_framework != other_test_result.test_framework
                or test_result.test_type != other_test_result.test_type
                or not comparator(
                    test_result.return_value,
                    other_test_result.return_value,
                )
            ):
                sys.setrecursionlimit(original_recursion_limit)
                return False
        sys.setrecursionlimit(original_recursion_limit)
        return True
    def get_test_pass_fail_report_by_type(self) -> dict[TestType, dict[str, int]]:
        report = {}
        for test_type in TestType:
            report[test_type] = {"passed": 0, "failed": 0}
        for test_result in self.test_results:
            if test_result.test_type != TestType.EXISTING_UNIT_TEST or test_result.id.function_getting_tested:
                if test_result.did_pass:
                    report[test_result.test_type]["passed"] += 1
                else:
                    report[test_result.test_type]["failed"] += 1
        return report"""
    optim_code = """from __future__ import annotations

import sys
from enum import Enum
from typing import Iterator

from codeflash.verification.comparator import comparator
from pydantic import BaseModel


class TestType(Enum):
    EXISTING_UNIT_TEST = 1
    INSPIRED_REGRESSION = 2
    GENERATED_REGRESSION = 3
    REPLAY_TEST = 4

    def to_name(self) -> str:
        if self == TestType.EXISTING_UNIT_TEST:
            return "⚙️ Existing Unit Tests"
        elif self == TestType.INSPIRED_REGRESSION:
            return "🎨 Inspired Regression Tests"
        elif self == TestType.GENERATED_REGRESSION:
            return "🌀 Generated Regression Tests"
        elif self == TestType.REPLAY_TEST:
            return "⏪ Replay Tests"

class TestResults(BaseModel):
    def __iter__(self) -> Iterator[FunctionTestInvocation]:
        return iter(self.test_results)

    def __len__(self) -> int:
        return len(self.test_results)

    def __getitem__(self, index: int) -> FunctionTestInvocation:
        return self.test_results[index]

    def __setitem__(self, index: int, value: FunctionTestInvocation) -> None:
        self.test_results[index] = value

    def __delitem__(self, index: int) -> None:
        del self.test_results[index]

    def __contains__(self, value: FunctionTestInvocation) -> bool:
        return value in self.test_results

    def __bool__(self) -> bool:
        return bool(self.test_results)

    def __eq__(self, other: object) -> bool:
        # Unordered comparison
        if not isinstance(other, TestResults) or len(self) != len(other):
            return False

        # Increase recursion limit only if necessary
        original_recursion_limit = sys.getrecursionlimit()
        if original_recursion_limit < 5000:
            sys.setrecursionlimit(5000)

        for test_result in self:
            other_test_result = other.get_by_id(test_result.id)
            if other_test_result is None or not (
                test_result.file_name == other_test_result.file_name and
                test_result.did_pass == other_test_result.did_pass and
                test_result.runtime == other_test_result.runtime and
                test_result.test_framework == other_test_result.test_framework and
                test_result.test_type == other_test_result.test_type and
                comparator(test_result.return_value, other_test_result.return_value)
            ):
                sys.setrecursionlimit(original_recursion_limit)
                return False

        sys.setrecursionlimit(original_recursion_limit)
        return True

    def get_test_pass_fail_report_by_type(self) -> dict[TestType, dict[str, int]]:
        report = {test_type: {"passed": 0, "failed": 0} for test_type in TestType}
        for test_result in self.test_results:
            if test_result.test_type != TestType.EXISTING_UNIT_TEST or test_result.id.function_getting_tested:
                key = "passed" if test_result.did_pass else "failed"
                report[test_result.test_type][key] += 1
        return report"""

    preexisting_objects = find_preexisting_objects(original_code)

    helper_functions = [
        FakeFunctionSource(
            file_path=Path(
                "/Users/saurabh/Library/CloudStorage/Dropbox/codeflash/cli/codeflash/verification/test_results.py"
            ),
            qualified_name="TestType",
            fully_qualified_name="codeflash.verification.test_results.TestType",
            only_function_name="TestType",
            source_code="",
            jedi_definition=JediDefinition(type="class"),
        )
    ]

    new_code: str = replace_functions_and_add_imports(
        source_code=original_code,
        function_names=["TestResults.get_test_pass_fail_report_by_type"],
        optimized_code=optim_code,
        module_abspath=Path(__file__).resolve(),
        preexisting_objects=preexisting_objects,
        project_root_path=Path(__file__).parent.resolve(),
    )

    # The only helper above is a class, so this loop never fires here; it mirrors
    # the optimizer's real flow of replacing helpers module-by-module.
    helper_functions_by_module_abspath = defaultdict(set)
    for helper_function in helper_functions:
        if helper_function.jedi_definition.type != "class":
            helper_functions_by_module_abspath[helper_function.file_path].add(helper_function.qualified_name)
    for module_abspath, qualified_names in helper_functions_by_module_abspath.items():
        new_code: str = replace_functions_and_add_imports(
            source_code=new_code,
            function_names=list(qualified_names),
            optimized_code=optim_code,
            module_abspath=module_abspath,
            preexisting_objects=preexisting_objects,
            project_root_path=Path(__file__).parent.resolve(),
        )

    assert (
        new_code
        == """from __future__ import annotations
import sys
from codeflash.verification.comparator import comparator
from enum import Enum
from pydantic import BaseModel
from typing import Iterator

class TestType(Enum):
    EXISTING_UNIT_TEST = 1
    INSPIRED_REGRESSION = 2
    GENERATED_REGRESSION = 3
    REPLAY_TEST = 4

    def to_name(self) -> str:
        names = {
            TestType.EXISTING_UNIT_TEST: "⚙️ Existing Unit Tests",
            TestType.INSPIRED_REGRESSION: "🎨 Inspired Regression Tests",
            TestType.GENERATED_REGRESSION: "🌀 Generated Regression Tests",
            TestType.REPLAY_TEST: "⏪ Replay Tests",
        }
        return names[self]

class TestResults(BaseModel):
    def __iter__(self) -> Iterator[FunctionTestInvocation]:
        return iter(self.test_results)
    def __len__(self) -> int:
        return len(self.test_results)
    def __getitem__(self, index: int) -> FunctionTestInvocation:
        return self.test_results[index]
    def __setitem__(self, index: int, value: FunctionTestInvocation) -> None:
        self.test_results[index] = value
    def __delitem__(self, index: int) -> None:
        del self.test_results[index]
    def __contains__(self, value: FunctionTestInvocation) -> bool:
        return value in self.test_results
    def __bool__(self) -> bool:
        return bool(self.test_results)
    def __eq__(self, other: object) -> bool:
        # Unordered comparison
        if type(self) != type(other):
            return False
        if len(self) != len(other):
            return False
        original_recursion_limit = sys.getrecursionlimit()
        for test_result in self:
            other_test_result = other.get_by_id(test_result.id)
            if other_test_result is None:
                return False

            if original_recursion_limit < 5000:
                sys.setrecursionlimit(5000)
            if (
                test_result.file_name != other_test_result.file_name
                or test_result.did_pass != other_test_result.did_pass
                or test_result.runtime != other_test_result.runtime
                or test_result.test_framework != other_test_result.test_framework
                or test_result.test_type != other_test_result.test_type
                or not comparator(
                    test_result.return_value,
                    other_test_result.return_value,
                )
            ):
                sys.setrecursionlimit(original_recursion_limit)
                return False
        sys.setrecursionlimit(original_recursion_limit)
        return True
    def get_test_pass_fail_report_by_type(self) -> dict[TestType, dict[str, int]]:
        report = {test_type: {"passed": 0, "failed": 0} for test_type in TestType}
        for test_result in self.test_results:
            if test_result.test_type != TestType.EXISTING_UNIT_TEST or test_result.id.function_getting_tested:
                key = "passed" if test_result.did_pass else "failed"
                report[test_result.test_type][key] += 1
        return report"""
    )
|
2024-07-09 23:32:23 +00:00
|
|
|
|
|
|
|
|
|
2024-07-11 12:28:56 +00:00
|
|
|
def test_code_replacement_type_annotation() -> None:
    # First pass replaces only cosine_similarity_top_k; a second pass (driven by the
    # helper-function list) replaces the non-class helper cosine_similarity as well.
    # The Matrix dataclass is a class-type helper and must keep its original
    # annotations (capital List) in both outputs.
    # NOTE(review): reconstructed from a blame-mangled dump; blank-line placement
    # inside the fixture strings is inferred — confirm against the repository copy.
    original_code = '''import numpy as np
from pydantic.dataclasses import dataclass
from typing import List, Optional, Tuple, Union
@dataclass(config=dict(arbitrary_types_allowed=True))
class Matrix:
    data: Union[List[List[float]], List[np.ndarray], np.ndarray]
def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray:
    """Row-wise cosine similarity between two equal-width matrices."""
    if len(X.data) == 0 or len(Y.data) == 0:
        return np.array([])
    X = np.array(X.data)
    Y = np.array(Y.data)
    if X.shape[1] != Y.shape[1]:
        raise ValueError(
            f"Number of columns in X and Y must be the same. X has shape {X.shape} "
            f"and Y has shape {Y.shape}.",
        )
    X_norm = np.linalg.norm(X, axis=1)
    Y_norm = np.linalg.norm(Y, axis=1)
    similarity = np.dot(X, Y.T) / np.outer(X_norm, Y_norm)
    similarity[np.isnan(similarity) | np.isinf(similarity)] = 0.0
    return similarity
def cosine_similarity_top_k(
    X: Matrix,
    Y: Matrix,
    top_k: Optional[int] = 5,
    score_threshold: Optional[float] = None,
) -> Tuple[List[Tuple[int, int]], List[float]]:
    """Row-wise cosine similarity with optional top-k and score threshold filtering.
    Args:
    ----
    X: Matrix.
    Y: Matrix, same width as X.
    top_k: Max number of results to return.
    score_threshold: Minimum cosine similarity of results.
    Returns:
    -------
    Tuple of two lists. First contains two-tuples of indices (X_idx, Y_idx),
    second contains corresponding cosine similarities.
    """
    if len(X.data) == 0 or len(Y.data) == 0:
        return [], []
    score_array = cosine_similarity(X, Y)
    sorted_idxs = score_array.flatten().argsort()[::-1]
    top_k = top_k or len(sorted_idxs)
    top_idxs = sorted_idxs[:top_k]
    score_threshold = score_threshold or -1.0
    top_idxs = top_idxs[score_array.flatten()[top_idxs] > score_threshold]
    ret_idxs = [(x // score_array.shape[1], x % score_array.shape[1]) for x in top_idxs]
    scores = score_array.flatten()[top_idxs].tolist()
    return ret_idxs, scores
'''
    optim_code = '''from typing import List, Optional, Tuple, Union
import numpy as np
from pydantic.dataclasses import dataclass
@dataclass(config=dict(arbitrary_types_allowed=True))
class Matrix:
    data: Union[list[list[float]], List[np.ndarray], np.ndarray]
def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray:
    """Row-wise cosine similarity between two equal-width matrices."""
    if len(X.data) == 0 or len(Y.data) == 0:
        return np.array([])

    X_np, Y_np = np.asarray(X.data), np.asarray(Y.data)
    if X_np.shape[1] != Y_np.shape[1]:
        raise ValueError(f"Number of columns in X and Y must be the same. X has shape {X_np.shape} and Y has shape {Y_np.shape}.")
    X_norm = np.linalg.norm(X_np, axis=1, keepdims=True)
    Y_norm = np.linalg.norm(Y_np, axis=1, keepdims=True)

    norm_product = X_norm * Y_norm.T
    norm_product[norm_product == 0] = np.inf  # Prevent division by zero
    dot_product = np.dot(X_np, Y_np.T)
    similarity = dot_product / norm_product

    # Any NaN or Inf values are set to 0.0
    np.nan_to_num(similarity, copy=False)

    return similarity
def cosine_similarity_top_k(
    X: Matrix,
    Y: Matrix,
    top_k: Optional[int] = 5,
    score_threshold: Optional[float] = None,
) -> Tuple[List[Tuple[int, int]], List[float]]:
    """Row-wise cosine similarity with optional top-k and score threshold filtering."""
    if len(X.data) == 0 or len(Y.data) == 0:
        return [], []

    score_array = cosine_similarity(X, Y)

    sorted_idxs = np.argpartition(-score_array.flatten(), range(top_k or len(score_array.flatten())))[:(top_k or len(score_array.flatten()))]
    sorted_idxs = sorted_idxs[score_array.flatten()[sorted_idxs] > (score_threshold if score_threshold is not None else -1)]

    ret_idxs = [(x // score_array.shape[1], x % score_array.shape[1]) for x in sorted_idxs]
    scores = score_array.flatten()[sorted_idxs].tolist()

    return ret_idxs, scores
'''
    preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] = find_preexisting_objects(original_code)

    helper_functions = [
        FakeFunctionSource(
            file_path=(Path(__file__).parent / "code_to_optimize" / "math_utils.py").resolve(),
            qualified_name="Matrix",
            fully_qualified_name="code_to_optimize.math_utils.Matrix",
            only_function_name="Matrix",
            source_code="",
            jedi_definition=JediDefinition(type="class"),
        ),
        FakeFunctionSource(
            file_path=(Path(__file__).parent / "code_to_optimize" / "math_utils.py").resolve(),
            qualified_name="cosine_similarity",
            fully_qualified_name="code_to_optimize.math_utils.cosine_similarity",
            only_function_name="cosine_similarity",
            source_code="",
            jedi_definition=JediDefinition(type="function"),
        ),
    ]

    new_code: str = replace_functions_and_add_imports(
        source_code=original_code,
        function_names=["cosine_similarity_top_k"],
        optimized_code=optim_code,
        module_abspath=(Path(__file__).parent / "code_to_optimize").resolve(),
        preexisting_objects=preexisting_objects,
        project_root_path=Path(__file__).parent.parent.resolve(),
    )
    # After the first pass only the entry point is optimized; cosine_similarity and
    # Matrix are still the originals.
    assert (
        new_code
        == '''import numpy as np
from pydantic.dataclasses import dataclass
from typing import List, Optional, Tuple, Union
@dataclass(config=dict(arbitrary_types_allowed=True))
class Matrix:
    data: Union[List[List[float]], List[np.ndarray], np.ndarray]
def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray:
    """Row-wise cosine similarity between two equal-width matrices."""
    if len(X.data) == 0 or len(Y.data) == 0:
        return np.array([])
    X = np.array(X.data)
    Y = np.array(Y.data)
    if X.shape[1] != Y.shape[1]:
        raise ValueError(
            f"Number of columns in X and Y must be the same. X has shape {X.shape} "
            f"and Y has shape {Y.shape}.",
        )
    X_norm = np.linalg.norm(X, axis=1)
    Y_norm = np.linalg.norm(Y, axis=1)
    similarity = np.dot(X, Y.T) / np.outer(X_norm, Y_norm)
    similarity[np.isnan(similarity) | np.isinf(similarity)] = 0.0
    return similarity
def cosine_similarity_top_k(
    X: Matrix,
    Y: Matrix,
    top_k: Optional[int] = 5,
    score_threshold: Optional[float] = None,
) -> Tuple[List[Tuple[int, int]], List[float]]:
    """Row-wise cosine similarity with optional top-k and score threshold filtering."""
    if len(X.data) == 0 or len(Y.data) == 0:
        return [], []

    score_array = cosine_similarity(X, Y)

    sorted_idxs = np.argpartition(-score_array.flatten(), range(top_k or len(score_array.flatten())))[:(top_k or len(score_array.flatten()))]
    sorted_idxs = sorted_idxs[score_array.flatten()[sorted_idxs] > (score_threshold if score_threshold is not None else -1)]

    ret_idxs = [(x // score_array.shape[1], x % score_array.shape[1]) for x in sorted_idxs]
    scores = score_array.flatten()[sorted_idxs].tolist()

    return ret_idxs, scores
'''
    )
    # Second pass: replace the non-class helpers (only cosine_similarity qualifies).
    helper_functions_by_module_abspath = defaultdict(set)
    for helper_function in helper_functions:
        if helper_function.jedi_definition.type != "class":
            helper_functions_by_module_abspath[helper_function.file_path].add(helper_function.qualified_name)
    for module_abspath, qualified_names in helper_functions_by_module_abspath.items():
        new_helper_code: str = replace_functions_and_add_imports(
            source_code=new_code,
            function_names=list(qualified_names),
            optimized_code=optim_code,
            module_abspath=module_abspath,
            preexisting_objects=preexisting_objects,
            project_root_path=Path(__file__).parent.parent.resolve(),
        )

    assert (
        new_helper_code
        == '''import numpy as np
from pydantic.dataclasses import dataclass
from typing import List, Optional, Tuple, Union
@dataclass(config=dict(arbitrary_types_allowed=True))
class Matrix:
    data: Union[List[List[float]], List[np.ndarray], np.ndarray]
def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray:
    """Row-wise cosine similarity between two equal-width matrices."""
    if len(X.data) == 0 or len(Y.data) == 0:
        return np.array([])

    X_np, Y_np = np.asarray(X.data), np.asarray(Y.data)
    if X_np.shape[1] != Y_np.shape[1]:
        raise ValueError(f"Number of columns in X and Y must be the same. X has shape {X_np.shape} and Y has shape {Y_np.shape}.")
    X_norm = np.linalg.norm(X_np, axis=1, keepdims=True)
    Y_norm = np.linalg.norm(Y_np, axis=1, keepdims=True)

    norm_product = X_norm * Y_norm.T
    norm_product[norm_product == 0] = np.inf  # Prevent division by zero
    dot_product = np.dot(X_np, Y_np.T)
    similarity = dot_product / norm_product

    # Any NaN or Inf values are set to 0.0
    np.nan_to_num(similarity, copy=False)

    return similarity
def cosine_similarity_top_k(
    X: Matrix,
    Y: Matrix,
    top_k: Optional[int] = 5,
    score_threshold: Optional[float] = None,
) -> Tuple[List[Tuple[int, int]], List[float]]:
    """Row-wise cosine similarity with optional top-k and score threshold filtering."""
    if len(X.data) == 0 or len(Y.data) == 0:
        return [], []

    score_array = cosine_similarity(X, Y)

    sorted_idxs = np.argpartition(-score_array.flatten(), range(top_k or len(score_array.flatten())))[:(top_k or len(score_array.flatten()))]
    sorted_idxs = sorted_idxs[score_array.flatten()[sorted_idxs] > (score_threshold if score_threshold is not None else -1)]

    ret_idxs = [(x // score_array.shape[1], x % score_array.shape[1]) for x in sorted_idxs]
    scores = score_array.flatten()[sorted_idxs].tolist()

    return ret_idxs, scores
'''
    )
|
2024-07-11 12:28:56 +00:00
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_future_aliased_imports_removal() -> None:
    """delete___future___aliased_imports removes `from __future__ import X as Y`
    statements and nothing else: plain `from __future__` imports, aliased
    imports from other modules, and module docstrings all survive.
    """
    # (module source, expected result) pairs, asserted in order.
    cases: list[tuple[str, str]] = [
        # A lone aliased __future__ import is deleted.
        (
            """from __future__ import annotations as _annotations
print("Hello monde")
""",
            """print("Hello monde")
""",
        ),
        # An un-aliased __future__ import is left alone.
        (
            """from __future__ import annotations
print("Hello monde")
""",
            """from __future__ import annotations
print("Hello monde")
""",
        ),
        # Aliased form deleted even when the plain form is also present.
        (
            """from __future__ import annotations as _annotations
from __future__ import annotations
from past import autopasta as dood
print("Hello monde")
""",
            """from __future__ import annotations
from past import autopasta as dood
print("Hello monde")
""",
        ),
        # Order of the plain and aliased forms does not matter.
        (
            """from __future__ import annotations
from __future__ import annotations as _annotations
from past import autopasta as dood
print("Hello monde")
""",
            """from __future__ import annotations
from past import autopasta as dood
print("Hello monde")
""",
        ),
        # `future` (no dunder) is a different module: untouched.
        (
            """from future import annotations as _annotations
from past import autopasta as dood
print("Hello monde")
""",
            """from future import annotations as _annotations
from past import autopasta as dood
print("Hello monde")
""",
        ),
        # A module docstring before the aliased import is preserved.
        (
            '''"""Private logic for creating models."""

from __future__ import annotations as _annotations
''',
            '''"""Private logic for creating models."""
''',
        ),
    ]
    for module_code, expected in cases:
        assert delete___future___aliased_imports(module_code) == expected
|
|
|
def test_0_diff_code_replacement():
    """is_zero_diff must report "no meaningful change" for cosmetic or
    import/docstring-only differences, and a real change for behavioral ones.

    NOTE(review): blank-line counts inside the literals were reconstructed
    from a mangled source view — verify against the original file.
    """
    # Baseline: a __future__ import, a numpy import, and one tiny function.
    original_code = """from __future__ import annotations

import numpy as np
def functionA():
    return np.array([1, 2, 3])
"""
    # Same code modulo whitespace/trailing newline -> zero diff.
    optim_code_a = """from __future__ import annotations
import numpy as np
def functionA():
    return np.array([1, 2, 3])"""

    assert is_zero_diff(original_code, optim_code_a)

    # Dropping the __future__ import alone is still a zero diff.
    optim_code_b = """
import numpy as np
def functionA():
    return np.array([1, 2, 3])"""

    assert is_zero_diff(original_code, optim_code_b)

    # Dropping all imports: the function body is what is compared.
    optim_code_c = """
def functionA():
    return np.array([1, 2, 3])"""

    assert is_zero_diff(original_code, optim_code_c)

    # Actual behavioral change ([1, 2, 3] -> [1, 2, 3, 4]) must NOT be zero.
    optim_code_d = """from __future__ import annotations

import numpy as np
def functionA():
    return np.array([1, 2, 3, 4])
"""
    assert not is_zero_diff(original_code, optim_code_d)

    # Added module/function docstrings and a function-local import are also
    # treated as zero diff.
    optim_code_e = '''"""
Zis a Docstring?
"""
from __future__ import annotations

import ast
def functionA():
    """
    Und Zis?
    """
    import numpy as np
    return np.array([1, 2, 3])
'''
    assert is_zero_diff(original_code, optim_code_e)
|
|
|
def test_nested_class() -> None:
    """Methods of a class nested inside another class are ignored by
    replace_functions_and_add_imports even when explicitly listed as targets,
    while listed methods of the enclosing top-level class are replaced and
    imports used by the optimized code (libcst) are added.

    NOTE(review): indentation and blank-line counts inside the literals were
    reconstructed from a mangled source view — verify against the original.
    """
    # Optimized candidate: new bodies for NewClass methods and for the nested
    # class's method, plus two imports (typing.Optional is unused on purpose).
    optim_code = """import libcst as cst
from typing import Optional

class NewClass:
    def __init__(self, name):
        self.name = str(name)
    def __call__(self, value):
        return self.name
    def new_function2(value):
        return cst.ensure_type(value, int)

    class NestedClass:
        def nested_function(self):
            return "I am nested and modified"
"""
    original_code = """class NewClass:
    def __init__(self, name):
        self.name = name
    def __call__(self, value):
        return "I am still old"
    def new_function2(value):
        return cst.ensure_type(value, str)

    class NestedClass:
        def nested_function(self):
            return "I am nested"

print("Hello world")
"""
    # Observed result: __init__ and new_function2 take the optimized bodies;
    # __call__ and the nested method keep the originals; only the libcst
    # import (actually used by the new code) is added at the top.
    expected = """import libcst as cst

class NewClass:
    def __init__(self, name):
        self.name = str(name)
    def __call__(self, value):
        return "I am still old"
    def new_function2(value):
        return cst.ensure_type(value, int)

    class NestedClass:
        def nested_function(self):
            return "I am nested"

print("Hello world")
"""
    function_names: list[str] = [
        "NewClass.new_function2",
        "NestedClass.nested_function",
    ]  # Nested classes should be ignored, even if provided as target
    preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] = find_preexisting_objects(original_code)
    new_code: str = replace_functions_and_add_imports(
        source_code=original_code,
        function_names=function_names,
        optimized_code=optim_code,
        module_abspath=Path(__file__).resolve(),
        preexisting_objects=preexisting_objects,
        project_root_path=Path(__file__).resolve().parent.resolve(),
    )
    assert new_code == expected
|
|
|
def test_modify_back_to_original() -> None:
    """Replacing every listed method with an identical implementation must be a
    round-trip no-op: the rewritten module equals the original exactly.

    NOTE(review): blank-line counts inside the literals were reconstructed
    from a mangled source view — verify against the original file.
    """
    # The candidate is character-for-character the same as the original.
    optim_code = """class NewClass:
    def __init__(self, name):
        self.name = name
    def __call__(self, value):
        return "I am still old"
    def new_function2(value):
        return cst.ensure_type(value, str)

print("Hello world")
"""
    original_code = """class NewClass:
    def __init__(self, name):
        self.name = name
    def __call__(self, value):
        return "I am still old"
    def new_function2(value):
        return cst.ensure_type(value, str)

print("Hello world")
"""
    # All three methods are targeted, so the whole class body is "replaced".
    function_names: list[str] = ["NewClass.__init__", "NewClass.__call__", "NewClass.new_function2"]
    preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] = find_preexisting_objects(original_code)
    new_code: str = replace_functions_and_add_imports(
        source_code=original_code,
        function_names=function_names,
        optimized_code=optim_code,
        module_abspath=Path(__file__).resolve(),
        preexisting_objects=preexisting_objects,
        project_root_path=Path(__file__).resolve().parent.resolve(),
    )
    # Identical input and candidate -> byte-identical output.
    assert new_code == original_code
|
|
|
def test_global_reassignment() -> None:
    """End-to-end check of replace_function_and_helpers_with_optimized_code on
    modules containing module-level (global) assignments.

    Five scenarios, each writing ``original_code`` to a scratch module,
    applying a markdown-fenced optimized candidate, and comparing the
    rewritten file against ``expected_code``:
      1. candidate adds an import and reassigns ``a`` -> the reassignment
         replaces the original ``a=1`` in place
      2. a global at the END of the module keeps its position when reassigned
      3. candidate assigns ``a`` twice -> the last value (``a=3``) wins
      4. a global at the TOP of the module keeps its position when reassigned
      5. assignments inside ``if``/``else`` blocks are not treated as
         reassignable globals; only the plain ``a = 6`` is carried over

    The original test duplicated the optimizer boilerplate five times and
    leaked the scratch file on failure; both are fixed by the nested helper.

    NOTE(review): blank-line counts inside the literals were reconstructed
    from a mangled source view — verify against the original file.
    """
    root_dir = Path(__file__).parent.parent.resolve()
    code_path = (root_dir / "code_to_optimize/global_var_original.py").resolve()
    # NOTE(review): machine-specific path retained from the original test.
    tests_root = Path("/Users/codeflash/Downloads/codeflash-dev/codeflash/code_to_optimize/tests/pytest/")
    project_root_path = (Path(__file__).parent / "..").resolve()
    rel_path = code_path.relative_to(root_dir)

    def _replace_and_check(original_code: str, optimized_code: str, expected_code: str) -> None:
        # One scenario: write the module, run the optimizer's replacement pass
        # on `some_fn`, and compare the rewritten file against expectations.
        code_path.write_text(original_code, encoding="utf-8")
        try:
            func = FunctionToOptimize(function_name="some_fn", parents=[], file_path=code_path)
            test_config = TestConfig(
                tests_root=tests_root,
                tests_project_rootdir=project_root_path,
                project_root_path=project_root_path,
                test_framework="pytest",
                pytest_cmd="pytest",
            )
            func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
            code_context: CodeOptimizationContext = func_optimizer.get_code_optimization_context().unwrap()
            # Snapshot helper-file contents, as the optimizer API requires.
            original_helper_code: dict[Path, str] = {}
            for helper_function_path in {hf.file_path for hf in code_context.helper_functions}:
                with helper_function_path.open(encoding="utf8") as f:
                    original_helper_code[helper_function_path] = f.read()
            func_optimizer.args = Args()
            func_optimizer.replace_function_and_helpers_with_optimized_code(
                code_context=code_context,
                optimized_code=CodeStringsMarkdown.parse_markdown_code(optimized_code),
                original_helper_code=original_helper_code,
            )
            new_code = code_path.read_text(encoding="utf-8")
        finally:
            # Always remove the scratch module, even if the optimizer raises,
            # so later scenarios (and other tests) never see stale state.
            code_path.unlink(missing_ok=True)
        assert new_code.rstrip() == expected_code.rstrip()

    # Scenario 1: global `a=1` at the top is reassigned to 2; new import added.
    _replace_and_check(
        original_code="""a=1
print("Hello world")
def some_fn():
    print("did noting")
class NewClass:
    def __init__(self, name):
        self.name = name
    def __call__(self, value):
        return "I am still old"
    def new_function2(value):
        return cst.ensure_type(value, str)
""",
        optimized_code=f"""```python:{rel_path}
import numpy as np

def some_fn():
    a=np.zeros(10)
    print("did something")
class NewClass:
    def __init__(self, name):
        self.name = name
    def __call__(self, value):
        return "I am still old"
    def new_function2(value):
        return cst.ensure_type(value, str)
a=2
print("Hello world")
```
""",
        expected_code="""import numpy as np

a=2
print("Hello world")
def some_fn():
    a=np.zeros(10)
    print("did something")
class NewClass:
    def __init__(self, name):
        self.name = name
    def __call__(self, value):
        return "I am still old"
    def new_function2(value):
        return cst.ensure_type(value, str)
    def __init__(self, name):
        self.name = name
    def __call__(self, value):
        return "I am still old"
    def new_function2(value):
        return cst.ensure_type(value, str)""",
    )

    # Scenario 2: the global lives at the END of the module; its reassignment
    # stays at that original position.
    _replace_and_check(
        original_code="""print("Hello world")
def some_fn():
    print("did noting")
class NewClass:
    def __init__(self, name):
        self.name = name
    def __call__(self, value):
        return "I am still old"
    def new_function2(value):
        return cst.ensure_type(value, str)
a=1
""",
        optimized_code=f"""```python:{rel_path}
a=2
import numpy as np
def some_fn():
    a=np.zeros(10)
    print("did something")
class NewClass:
    def __init__(self, name):
        self.name = name
    def __call__(self, value):
        return "I am still old"
    def new_function2(value):
        return cst.ensure_type(value, str)
print("Hello world")
```
""",
        expected_code="""import numpy as np

print("Hello world")
def some_fn():
    a=np.zeros(10)
    print("did something")
class NewClass:
    def __init__(self, name):
        self.name = name
    def __call__(self, value):
        return "I am still old"
    def new_function2(value):
        return cst.ensure_type(value, str)
    def __init__(self, name):
        self.name = name
    def __call__(self, value):
        return "I am still old"
    def new_function2(value):
        return cst.ensure_type(value, str)
a=2
""",
    )

    # Scenario 3: the candidate assigns `a` twice (a=2 then a=3); the final
    # value wins in the rewritten module.
    _replace_and_check(
        original_code="""a=1
print("Hello world")
def some_fn():
    print("did noting")
class NewClass:
    def __init__(self, name):
        self.name = name
    def __call__(self, value):
        return "I am still old"
    def new_function2(value):
        return cst.ensure_type(value, str)
""",
        optimized_code=f"""```python:{rel_path}
import numpy as np
a=2
def some_fn():
    a=np.zeros(10)
    print("did something")
class NewClass:
    def __init__(self, name):
        self.name = name
    def __call__(self, value):
        return "I am still old"
    def new_function2(value):
        return cst.ensure_type(value, str)
a=3
print("Hello world")
```
""",
        expected_code="""import numpy as np

a=3
print("Hello world")
def some_fn():
    a=np.zeros(10)
    print("did something")
class NewClass:
    def __init__(self, name):
        self.name = name
    def __call__(self, value):
        return "I am still old"
    def new_function2(value):
        return cst.ensure_type(value, str)
    def __init__(self, name):
        self.name = name
    def __call__(self, value):
        return "I am still old"
    def new_function2(value):
        return cst.ensure_type(value, str)
""",
    )

    # Scenario 4: like scenario 2, but the global is at the TOP of the module;
    # the reassignment again keeps the original position.
    _replace_and_check(
        original_code="""a=1
print("Hello world")
def some_fn():
    print("did noting")
class NewClass:
    def __init__(self, name):
        self.name = name
    def __call__(self, value):
        return "I am still old"
    def new_function2(value):
        return cst.ensure_type(value, str)
""",
        optimized_code=f"""```python:{rel_path}
a=2
import numpy as np
def some_fn():
    a=np.zeros(10)
    print("did something")
class NewClass:
    def __init__(self, name):
        self.name = name
    def __call__(self, value):
        return "I am still old"
    def new_function2(value):
        return cst.ensure_type(value, str)
print("Hello world")
```
""",
        expected_code="""import numpy as np

a=2
print("Hello world")
def some_fn():
    a=np.zeros(10)
    print("did something")
class NewClass:
    def __init__(self, name):
        self.name = name
    def __call__(self, value):
        return "I am still old"
    def new_function2(value):
        return cst.ensure_type(value, str)
    def __init__(self, name):
        self.name = name
    def __call__(self, value):
        return "I am still old"
    def new_function2(value):
        return cst.ensure_type(value, str)
""",
    )

    # Scenario 5: assignments inside if/else branches are not reassignable
    # globals — the original `if 2<3:` block survives untouched and only the
    # candidate's plain `a = 6` is brought over.
    _replace_and_check(
        original_code="""if 2<3:
    a=4
else:
    a=5
print("Hello world")
def some_fn():
    print("did noting")
class NewClass:
    def __init__(self, name):
        self.name = name
    def __call__(self, value):
        return "I am still old"
    def new_function2(value):
        return cst.ensure_type(value, str)
""",
        optimized_code=f"""```python:{rel_path}
import numpy as np
if 1<2:
    a=2
else:
    a=3
a = 6
def some_fn():
    a=np.zeros(10)
    print("did something")
class NewClass:
    def __init__(self, name):
        self.name = name
    def __call__(self, value):
        return "I am still old"
    def new_function2(value):
        return cst.ensure_type(value, str)
print("Hello world")
```
""",
        expected_code="""import numpy as np

a = 6

if 2<3:
    a=4
else:
    a=5
print("Hello world")
def some_fn():
    a=np.zeros(10)
    print("did something")
class NewClass:
    def __init__(self, name):
        self.name = name
    def __call__(self, value):
        return "I am still old"
    def new_function2(value):
        return cst.ensure_type(value, str)
    def __init__(self, name):
        self.name = name
    def __call__(self, value):
        return "I am still old"
    def new_function2(value):
        return cst.ensure_type(value, str)
""",
    )
|
|
|
|
|
class TestAutouseFixtureModifier:
|
|
|
|
|
"""Test cases for AutouseFixtureModifier class."""
|
|
|
|
|
|
|
|
|
|
    def test_modifies_autouse_fixture_with_pytest_decorator(self):
        """Test that autouse fixture with @pytest.fixture is modified correctly.

        The modifier must wrap the fixture body in a marker check: tests
        marked ``codeflash_no_autouse`` get a bare ``yield`` (fixture
        effectively disabled), everything else runs the original
        setup/teardown unchanged.
        """
        source_code = '''
import pytest

@pytest.fixture(autouse=True)
def my_fixture(request):
    print("setup")
    yield
    print("teardown")
'''
        # The original body moves under `else`; the marker branch yields bare.
        expected_code = '''
import pytest

@pytest.fixture(autouse=True)
def my_fixture(request):
    if request.node.get_closest_marker("codeflash_no_autouse"):
        yield
    else:
        print("setup")
        yield
        print("teardown")
'''
        module = cst.parse_module(source_code)
        modifier = AutouseFixtureModifier()
        modified_module = module.visit(modifier)

        # Parse expected to normalize formatting
        expected_module = cst.parse_module(expected_code)
        assert modified_module.code.strip() == expected_module.code.strip()
|
|
|
|
    def test_modifies_autouse_fixture_with_fixture_decorator(self):
        """Test that autouse fixture with @fixture is modified correctly.

        Same rewrite as the @pytest.fixture case, but using the bare
        ``@fixture`` alias (``from pytest import fixture``); also covers a
        fixture that yields a value and runs explicit teardown code.
        """
        source_code = '''
from pytest import fixture

@fixture(autouse=True)
def my_fixture(request):
    setup_code()
    yield "value"
    cleanup_code()
'''
        # Marker branch yields bare (no value); original body under `else`.
        expected_code = '''
from pytest import fixture

@fixture(autouse=True)
def my_fixture(request):
    if request.node.get_closest_marker("codeflash_no_autouse"):
        yield
    else:
        setup_code()
        yield "value"
        cleanup_code()
'''
        module = cst.parse_module(source_code)
        modifier = AutouseFixtureModifier()
        modified_module = module.visit(modifier)

        # Check that the if statement was added
        assert modified_module.code.strip() == expected_code.strip()
|
|
|
|
def test_ignores_non_autouse_fixture(self):
|
|
|
|
|
"""Test that non-autouse fixtures are not modified."""
|
|
|
|
|
source_code = '''
|
|
|
|
|
import pytest
|
|
|
|
|
|
|
|
|
|
@pytest.fixture
|
|
|
|
|
def my_fixture(request):
|
|
|
|
|
return "test_value"
|
|
|
|
|
|
|
|
|
|
@pytest.fixture(scope="session")
|
|
|
|
|
def session_fixture():
|
|
|
|
|
return "session_value"
|
|
|
|
|
'''
|
|
|
|
|
module = cst.parse_module(source_code)
|
|
|
|
|
modifier = AutouseFixtureModifier()
|
|
|
|
|
modified_module = module.visit(modifier)
|
|
|
|
|
|
|
|
|
|
# Code should remain unchanged
|
|
|
|
|
assert modified_module.code == source_code
|
|
|
|
|
|
|
|
|
|
def test_ignores_regular_functions(self):
|
|
|
|
|
"""Test that regular functions are not modified."""
|
|
|
|
|
source_code = '''
|
|
|
|
|
def regular_function():
|
|
|
|
|
return "not a fixture"
|
|
|
|
|
|
|
|
|
|
@some_other_decorator
|
|
|
|
|
def decorated_function():
|
|
|
|
|
return "also not a fixture"
|
|
|
|
|
'''
|
|
|
|
|
module = cst.parse_module(source_code)
|
|
|
|
|
modifier = AutouseFixtureModifier()
|
|
|
|
|
modified_module = module.visit(modifier)
|
|
|
|
|
|
|
|
|
|
# Code should remain unchanged
|
|
|
|
|
assert modified_module.code == source_code
|
|
|
|
|
|
|
|
|
|
def test_handles_multiple_autouse_fixtures(self):
|
|
|
|
|
"""Test that multiple autouse fixtures in the same file are all modified."""
|
|
|
|
|
source_code = '''
|
|
|
|
|
import pytest
|
|
|
|
|
|
|
|
|
|
@pytest.fixture(autouse=True)
|
|
|
|
|
def fixture_one(request):
|
|
|
|
|
yield "one"
|
|
|
|
|
|
|
|
|
|
@pytest.fixture(autouse=True)
|
|
|
|
|
def fixture_two(request):
|
|
|
|
|
yield "two"
|
2025-06-06 20:11:07 +00:00
|
|
|
'''
|
|
|
|
|
expected_code = '''
|
|
|
|
|
import pytest
|
|
|
|
|
|
|
|
|
|
@pytest.fixture(autouse=True)
|
|
|
|
|
def fixture_one(request):
|
|
|
|
|
if request.node.get_closest_marker("codeflash_no_autouse"):
|
|
|
|
|
yield
|
|
|
|
|
else:
|
|
|
|
|
yield "one"
|
|
|
|
|
|
|
|
|
|
@pytest.fixture(autouse=True)
|
|
|
|
|
def fixture_two(request):
|
|
|
|
|
if request.node.get_closest_marker("codeflash_no_autouse"):
|
|
|
|
|
yield
|
|
|
|
|
else:
|
|
|
|
|
yield "two"
|
2025-06-06 19:30:30 +00:00
|
|
|
'''
|
|
|
|
|
module = cst.parse_module(source_code)
|
|
|
|
|
modifier = AutouseFixtureModifier()
|
|
|
|
|
modified_module = module.visit(modifier)
|
|
|
|
|
|
|
|
|
|
# Both fixtures should be modified
|
|
|
|
|
code = modified_module.code
|
2025-06-06 20:11:07 +00:00
|
|
|
assert code==expected_code
|
2025-06-06 19:30:30 +00:00
|
|
|
|
|
|
|
|
def test_preserves_fixture_with_complex_body(self):
|
|
|
|
|
"""Test that fixtures with complex bodies are handled correctly."""
|
|
|
|
|
source_code = '''
|
|
|
|
|
import pytest
|
|
|
|
|
|
|
|
|
|
@pytest.fixture(autouse=True)
|
|
|
|
|
def complex_fixture(request):
|
2025-06-06 20:41:34 +00:00
|
|
|
try:
|
|
|
|
|
setup_database()
|
|
|
|
|
configure_logging()
|
|
|
|
|
yield get_test_client()
|
|
|
|
|
finally:
|
|
|
|
|
cleanup_database()
|
|
|
|
|
reset_logging()
|
2025-06-06 20:11:07 +00:00
|
|
|
'''
|
|
|
|
|
expected_code = '''
|
|
|
|
|
import pytest
|
|
|
|
|
|
|
|
|
|
@pytest.fixture(autouse=True)
|
|
|
|
|
def complex_fixture(request):
|
|
|
|
|
if request.node.get_closest_marker("codeflash_no_autouse"):
|
|
|
|
|
yield
|
|
|
|
|
else:
|
|
|
|
|
try:
|
|
|
|
|
setup_database()
|
|
|
|
|
configure_logging()
|
|
|
|
|
yield get_test_client()
|
|
|
|
|
finally:
|
|
|
|
|
cleanup_database()
|
|
|
|
|
reset_logging()
|
2025-06-06 19:30:30 +00:00
|
|
|
'''
|
|
|
|
|
module = cst.parse_module(source_code)
|
|
|
|
|
modifier = AutouseFixtureModifier()
|
|
|
|
|
modified_module = module.visit(modifier)
|
|
|
|
|
|
|
|
|
|
code = modified_module.code
|
2025-06-06 20:41:34 +00:00
|
|
|
assert code.rstrip()==expected_code.rstrip()
|
2025-06-06 19:30:30 +00:00
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestPytestMarkAdder:
|
|
|
|
|
"""Test cases for PytestMarkAdder class."""
|
|
|
|
|
|
|
|
|
|
def test_adds_pytest_import_when_missing(self):
|
|
|
|
|
"""Test that pytest import is added when not present."""
|
|
|
|
|
source_code = '''
|
2025-06-06 20:11:07 +00:00
|
|
|
def test_something():
|
|
|
|
|
assert True
|
|
|
|
|
'''
|
|
|
|
|
expected_code = '''
|
|
|
|
|
import pytest
|
|
|
|
|
@pytest.mark.codeflash_no_autouse
|
2025-06-06 19:30:30 +00:00
|
|
|
def test_something():
|
|
|
|
|
assert True
|
|
|
|
|
'''
|
|
|
|
|
module = cst.parse_module(source_code)
|
|
|
|
|
mark_adder = PytestMarkAdder("codeflash_no_autouse")
|
|
|
|
|
modified_module = module.visit(mark_adder)
|
|
|
|
|
|
|
|
|
|
code = modified_module.code
|
2025-06-06 20:11:07 +00:00
|
|
|
assert code==expected_code
|
2025-06-06 19:30:30 +00:00
|
|
|
|
|
|
|
|
def test_skips_pytest_import_when_present(self):
|
|
|
|
|
"""Test that pytest import is not duplicated when already present."""
|
|
|
|
|
source_code = '''
|
|
|
|
|
import pytest
|
|
|
|
|
|
2025-06-06 20:11:07 +00:00
|
|
|
def test_something():
|
|
|
|
|
assert True
|
|
|
|
|
'''
|
|
|
|
|
expected_code = '''
|
|
|
|
|
import pytest
|
|
|
|
|
|
|
|
|
|
@pytest.mark.codeflash_no_autouse
|
2025-06-06 19:30:30 +00:00
|
|
|
def test_something():
|
|
|
|
|
assert True
|
|
|
|
|
'''
|
|
|
|
|
module = cst.parse_module(source_code)
|
|
|
|
|
mark_adder = PytestMarkAdder("codeflash_no_autouse")
|
|
|
|
|
modified_module = module.visit(mark_adder)
|
|
|
|
|
|
|
|
|
|
code = modified_module.code
|
|
|
|
|
# Should only have one import pytest line
|
2025-06-06 20:11:07 +00:00
|
|
|
assert code==expected_code
|
2025-06-06 19:30:30 +00:00
|
|
|
|
|
|
|
|
def test_handles_from_pytest_import(self):
|
|
|
|
|
"""Test that existing 'from pytest import ...' is recognized."""
|
|
|
|
|
source_code = '''
|
|
|
|
|
from pytest import fixture
|
|
|
|
|
|
|
|
|
|
def test_something():
|
|
|
|
|
assert True
|
|
|
|
|
'''
|
2025-06-06 20:11:07 +00:00
|
|
|
expected_code = '''
|
|
|
|
|
import pytest
|
|
|
|
|
from pytest import fixture
|
|
|
|
|
|
|
|
|
|
@pytest.mark.codeflash_no_autouse
|
|
|
|
|
def test_something():
|
|
|
|
|
assert True
|
|
|
|
|
'''
|
2025-06-06 19:30:30 +00:00
|
|
|
module = cst.parse_module(source_code)
|
|
|
|
|
mark_adder = PytestMarkAdder("codeflash_no_autouse")
|
|
|
|
|
modified_module = module.visit(mark_adder)
|
|
|
|
|
|
|
|
|
|
code = modified_module.code
|
|
|
|
|
# Should not add import pytest since pytest is already imported
|
2025-06-06 20:11:07 +00:00
|
|
|
assert code.strip()==expected_code.strip()
|
2025-06-06 19:30:30 +00:00
|
|
|
|
|
|
|
|
def test_adds_mark_to_all_functions(self):
|
|
|
|
|
"""Test that marks are added to all functions in the module."""
|
|
|
|
|
source_code = '''
|
|
|
|
|
import pytest
|
|
|
|
|
|
|
|
|
|
def test_first():
|
|
|
|
|
assert True
|
|
|
|
|
|
|
|
|
|
def test_second():
|
|
|
|
|
assert False
|
|
|
|
|
|
2025-06-06 20:11:07 +00:00
|
|
|
def helper_function():
|
|
|
|
|
return "not a test"
|
|
|
|
|
'''
|
|
|
|
|
expected_code = '''
|
|
|
|
|
import pytest
|
|
|
|
|
|
|
|
|
|
@pytest.mark.codeflash_no_autouse
|
|
|
|
|
def test_first():
|
|
|
|
|
assert True
|
|
|
|
|
|
|
|
|
|
@pytest.mark.codeflash_no_autouse
|
|
|
|
|
def test_second():
|
|
|
|
|
assert False
|
|
|
|
|
|
|
|
|
|
@pytest.mark.codeflash_no_autouse
|
2025-06-06 19:30:30 +00:00
|
|
|
def helper_function():
|
|
|
|
|
return "not a test"
|
|
|
|
|
'''
|
|
|
|
|
module = cst.parse_module(source_code)
|
|
|
|
|
mark_adder = PytestMarkAdder("codeflash_no_autouse")
|
|
|
|
|
modified_module = module.visit(mark_adder)
|
|
|
|
|
|
|
|
|
|
code = modified_module.code
|
|
|
|
|
# All functions should get the mark
|
2025-06-06 20:11:07 +00:00
|
|
|
assert code==expected_code
|
2025-06-06 19:30:30 +00:00
|
|
|
|
|
|
|
|
def test_skips_existing_mark(self):
|
|
|
|
|
"""Test that existing marks are not duplicated."""
|
|
|
|
|
source_code = '''
|
|
|
|
|
import pytest
|
|
|
|
|
|
|
|
|
|
@pytest.mark.codeflash_no_autouse
|
|
|
|
|
def test_already_marked():
|
|
|
|
|
assert True
|
|
|
|
|
|
2025-06-06 20:11:07 +00:00
|
|
|
def test_needs_mark():
|
|
|
|
|
assert True
|
|
|
|
|
'''
|
|
|
|
|
expected_code = '''
|
|
|
|
|
import pytest
|
|
|
|
|
|
|
|
|
|
@pytest.mark.codeflash_no_autouse
|
|
|
|
|
def test_already_marked():
|
|
|
|
|
assert True
|
|
|
|
|
|
|
|
|
|
@pytest.mark.codeflash_no_autouse
|
2025-06-06 19:30:30 +00:00
|
|
|
def test_needs_mark():
|
|
|
|
|
assert True
|
|
|
|
|
'''
|
|
|
|
|
module = cst.parse_module(source_code)
|
|
|
|
|
mark_adder = PytestMarkAdder("codeflash_no_autouse")
|
|
|
|
|
modified_module = module.visit(mark_adder)
|
|
|
|
|
|
|
|
|
|
code = modified_module.code
|
|
|
|
|
# Should have exactly 2 marks total (one existing, one added)
|
2025-06-06 20:11:07 +00:00
|
|
|
assert code==expected_code
|
2025-06-06 19:30:30 +00:00
|
|
|
|
|
|
|
|
def test_handles_different_mark_names(self):
|
|
|
|
|
"""Test that different mark names work correctly."""
|
|
|
|
|
source_code = '''
|
|
|
|
|
import pytest
|
|
|
|
|
|
2025-06-06 20:11:07 +00:00
|
|
|
def test_something():
|
|
|
|
|
assert True
|
|
|
|
|
'''
|
|
|
|
|
expected_code = '''
|
|
|
|
|
import pytest
|
|
|
|
|
|
|
|
|
|
@pytest.mark.slow
|
2025-06-06 19:30:30 +00:00
|
|
|
def test_something():
|
|
|
|
|
assert True
|
|
|
|
|
'''
|
|
|
|
|
module = cst.parse_module(source_code)
|
|
|
|
|
mark_adder = PytestMarkAdder("slow")
|
|
|
|
|
modified_module = module.visit(mark_adder)
|
|
|
|
|
|
|
|
|
|
code = modified_module.code
|
2025-06-06 20:11:07 +00:00
|
|
|
assert code==expected_code
|
2025-06-06 19:30:30 +00:00
|
|
|
|
|
|
|
|
def test_preserves_existing_decorators(self):
|
|
|
|
|
"""Test that existing decorators are preserved."""
|
|
|
|
|
source_code = '''
|
|
|
|
|
import pytest
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("value", [1, 2, 3])
|
|
|
|
|
@pytest.fixture
|
2025-06-06 20:11:07 +00:00
|
|
|
def test_with_decorators():
|
|
|
|
|
assert True
|
|
|
|
|
'''
|
|
|
|
|
expected_code = '''
|
|
|
|
|
import pytest
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("value", [1, 2, 3])
|
|
|
|
|
@pytest.fixture
|
|
|
|
|
@pytest.mark.codeflash_no_autouse
|
2025-06-06 19:30:30 +00:00
|
|
|
def test_with_decorators():
|
|
|
|
|
assert True
|
|
|
|
|
'''
|
|
|
|
|
module = cst.parse_module(source_code)
|
|
|
|
|
mark_adder = PytestMarkAdder("codeflash_no_autouse")
|
|
|
|
|
modified_module = module.visit(mark_adder)
|
|
|
|
|
|
|
|
|
|
code = modified_module.code
|
2025-06-06 20:11:07 +00:00
|
|
|
assert code==expected_code
|
2025-06-06 19:30:30 +00:00
|
|
|
|
|
|
|
|
def test_handles_call_style_existing_marks(self):
|
|
|
|
|
"""Test recognition of existing marks in call style (with parentheses)."""
|
|
|
|
|
source_code = '''
|
|
|
|
|
import pytest
|
|
|
|
|
|
|
|
|
|
@pytest.mark.codeflash_no_autouse()
|
|
|
|
|
def test_with_call_mark():
|
|
|
|
|
assert True
|
|
|
|
|
|
2025-06-06 20:11:07 +00:00
|
|
|
def test_needs_mark():
|
|
|
|
|
assert True
|
|
|
|
|
'''
|
|
|
|
|
expected_code = '''
|
|
|
|
|
import pytest
|
|
|
|
|
|
|
|
|
|
@pytest.mark.codeflash_no_autouse()
|
|
|
|
|
def test_with_call_mark():
|
|
|
|
|
assert True
|
|
|
|
|
|
|
|
|
|
@pytest.mark.codeflash_no_autouse
|
2025-06-06 19:30:30 +00:00
|
|
|
def test_needs_mark():
|
|
|
|
|
assert True
|
|
|
|
|
'''
|
|
|
|
|
module = cst.parse_module(source_code)
|
|
|
|
|
mark_adder = PytestMarkAdder("codeflash_no_autouse")
|
|
|
|
|
modified_module = module.visit(mark_adder)
|
|
|
|
|
|
|
|
|
|
code = modified_module.code
|
|
|
|
|
# Should recognize the existing call-style mark and not duplicate
|
2025-06-06 20:11:07 +00:00
|
|
|
assert code==expected_code
|
2025-06-06 19:30:30 +00:00
|
|
|
|
|
|
|
|
def test_empty_module(self):
|
|
|
|
|
"""Test handling of empty module."""
|
|
|
|
|
source_code = ''
|
|
|
|
|
module = cst.parse_module(source_code)
|
|
|
|
|
mark_adder = PytestMarkAdder("codeflash_no_autouse")
|
|
|
|
|
modified_module = module.visit(mark_adder)
|
|
|
|
|
|
|
|
|
|
# Should just add the import
|
|
|
|
|
code = modified_module.code
|
2025-06-06 20:11:07 +00:00
|
|
|
assert code =='import pytest'
|
2025-06-06 19:30:30 +00:00
|
|
|
|
|
|
|
|
def test_module_with_only_imports(self):
|
|
|
|
|
"""Test handling of module with only imports."""
|
|
|
|
|
source_code = '''
|
|
|
|
|
import os
|
|
|
|
|
import sys
|
|
|
|
|
from pathlib import Path
|
2025-06-06 20:11:07 +00:00
|
|
|
'''
|
|
|
|
|
expected_code = '''
|
|
|
|
|
import pytest
|
|
|
|
|
import os
|
|
|
|
|
import sys
|
|
|
|
|
from pathlib import Path
|
2025-06-06 19:30:30 +00:00
|
|
|
'''
|
|
|
|
|
module = cst.parse_module(source_code)
|
|
|
|
|
mark_adder = PytestMarkAdder("codeflash_no_autouse")
|
|
|
|
|
modified_module = module.visit(mark_adder)
|
|
|
|
|
|
|
|
|
|
code = modified_module.code
|
2025-06-06 20:11:07 +00:00
|
|
|
assert code==expected_code
|
2025-06-06 19:30:30 +00:00
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestIntegration:
|
2025-06-14 00:27:45 +00:00
|
|
|
"""Integration tests for all transformers working together."""
|
2025-06-06 19:30:30 +00:00
|
|
|
|
2025-06-14 00:27:45 +00:00
|
|
|
def test_all_transformers_together(self):
|
|
|
|
|
"""Test that all three transformers can work on the same code."""
|
|
|
|
|
source_code = '''
|
|
|
|
|
import pytest
|
|
|
|
|
|
|
|
|
|
@pytest.fixture(autouse=True)
|
|
|
|
|
def my_fixture():
|
|
|
|
|
yield "value"
|
|
|
|
|
|
|
|
|
|
def test_something():
|
|
|
|
|
assert True
|
|
|
|
|
'''
|
|
|
|
|
expected_code = '''
|
|
|
|
|
import pytest
|
|
|
|
|
|
|
|
|
|
@pytest.fixture(autouse=True)
|
|
|
|
|
@pytest.mark.codeflash_no_autouse
|
|
|
|
|
def my_fixture(request):
|
|
|
|
|
if request.node.get_closest_marker("codeflash_no_autouse"):
|
|
|
|
|
yield
|
|
|
|
|
else:
|
|
|
|
|
yield "value"
|
|
|
|
|
|
|
|
|
|
@pytest.mark.codeflash_no_autouse
|
|
|
|
|
def test_something():
|
|
|
|
|
assert True
|
|
|
|
|
'''
|
|
|
|
|
# First apply AddRequestArgument
|
|
|
|
|
module = cst.parse_module(source_code)
|
|
|
|
|
request_adder = AddRequestArgument()
|
|
|
|
|
modified_module = module.visit(request_adder)
|
|
|
|
|
|
|
|
|
|
# Then apply AutouseFixtureModifier
|
|
|
|
|
autouse_modifier = AutouseFixtureModifier()
|
|
|
|
|
modified_module = modified_module.visit(autouse_modifier)
|
|
|
|
|
|
|
|
|
|
# Finally apply PytestMarkAdder
|
|
|
|
|
mark_adder = PytestMarkAdder("codeflash_no_autouse")
|
|
|
|
|
final_module = modified_module.visit(mark_adder)
|
|
|
|
|
|
|
|
|
|
# Compare complete strings
|
|
|
|
|
assert final_module.code == expected_code
|
|
|
|
|
|
|
|
|
|
def test_transformers_with_existing_request_parameter(self):
|
|
|
|
|
"""Test transformers when request parameter already exists."""
|
2025-06-06 19:30:30 +00:00
|
|
|
source_code = '''
|
|
|
|
|
import pytest
|
|
|
|
|
|
|
|
|
|
@pytest.fixture(autouse=True)
|
|
|
|
|
def my_fixture(request):
|
2025-06-14 00:27:45 +00:00
|
|
|
setup_code()
|
2025-06-06 19:30:30 +00:00
|
|
|
yield "value"
|
2025-06-14 00:27:45 +00:00
|
|
|
cleanup_code()
|
2025-06-06 19:30:30 +00:00
|
|
|
|
2025-06-06 20:11:07 +00:00
|
|
|
def test_something():
|
|
|
|
|
assert True
|
|
|
|
|
'''
|
|
|
|
|
expected_code = '''
|
|
|
|
|
import pytest
|
|
|
|
|
|
|
|
|
|
@pytest.fixture(autouse=True)
|
|
|
|
|
@pytest.mark.codeflash_no_autouse
|
|
|
|
|
def my_fixture(request):
|
|
|
|
|
if request.node.get_closest_marker("codeflash_no_autouse"):
|
|
|
|
|
yield
|
|
|
|
|
else:
|
2025-06-14 00:27:45 +00:00
|
|
|
setup_code()
|
2025-06-06 20:11:07 +00:00
|
|
|
yield "value"
|
2025-06-14 00:27:45 +00:00
|
|
|
cleanup_code()
|
2025-06-06 20:11:07 +00:00
|
|
|
|
|
|
|
|
@pytest.mark.codeflash_no_autouse
|
2025-06-06 19:30:30 +00:00
|
|
|
def test_something():
|
|
|
|
|
assert True
|
|
|
|
|
'''
|
2025-06-14 00:27:45 +00:00
|
|
|
# Apply all transformers in sequence
|
2025-06-06 19:30:30 +00:00
|
|
|
module = cst.parse_module(source_code)
|
2025-06-14 00:27:45 +00:00
|
|
|
request_adder = AddRequestArgument()
|
|
|
|
|
modified_module = module.visit(request_adder)
|
|
|
|
|
|
2025-06-06 19:30:30 +00:00
|
|
|
autouse_modifier = AutouseFixtureModifier()
|
2025-06-14 00:27:45 +00:00
|
|
|
modified_module = modified_module.visit(autouse_modifier)
|
2025-06-06 19:30:30 +00:00
|
|
|
|
|
|
|
|
mark_adder = PytestMarkAdder("codeflash_no_autouse")
|
|
|
|
|
final_module = modified_module.visit(mark_adder)
|
|
|
|
|
|
2025-06-14 00:27:45 +00:00
|
|
|
# Compare complete strings
|
|
|
|
|
assert final_module.code == expected_code
|
|
|
|
|
|
|
|
|
|
def test_transformers_with_self_parameter(self):
|
|
|
|
|
"""Test transformers when fixture has self parameter."""
|
|
|
|
|
source_code = '''
|
|
|
|
|
import pytest
|
|
|
|
|
|
|
|
|
|
@pytest.fixture(autouse=True)
|
|
|
|
|
def my_fixture(self):
|
|
|
|
|
yield "value"
|
|
|
|
|
|
|
|
|
|
def test_something():
|
|
|
|
|
assert True
|
|
|
|
|
'''
|
|
|
|
|
expected_code = '''
|
|
|
|
|
import pytest
|
|
|
|
|
|
|
|
|
|
@pytest.fixture(autouse=True)
|
|
|
|
|
@pytest.mark.codeflash_no_autouse
|
|
|
|
|
def my_fixture(self, request):
|
|
|
|
|
if request.node.get_closest_marker("codeflash_no_autouse"):
|
|
|
|
|
yield
|
|
|
|
|
else:
|
|
|
|
|
yield "value"
|
|
|
|
|
|
|
|
|
|
@pytest.mark.codeflash_no_autouse
|
|
|
|
|
def test_something():
|
|
|
|
|
assert True
|
|
|
|
|
'''
|
|
|
|
|
# Apply all transformers in sequence
|
|
|
|
|
module = cst.parse_module(source_code)
|
|
|
|
|
request_adder = AddRequestArgument()
|
|
|
|
|
modified_module = module.visit(request_adder)
|
|
|
|
|
|
|
|
|
|
autouse_modifier = AutouseFixtureModifier()
|
|
|
|
|
modified_module = modified_module.visit(autouse_modifier)
|
|
|
|
|
|
|
|
|
|
mark_adder = PytestMarkAdder("codeflash_no_autouse")
|
|
|
|
|
final_module = modified_module.visit(mark_adder)
|
|
|
|
|
|
|
|
|
|
# Compare complete strings
|
|
|
|
|
assert final_module.code == expected_code
|
|
|
|
|
|
|
|
|
|
def test_transformers_with_multiple_fixtures(self):
|
|
|
|
|
"""Test transformers with multiple autouse fixtures."""
|
|
|
|
|
source_code = '''
|
|
|
|
|
import pytest
|
|
|
|
|
|
|
|
|
|
@pytest.fixture(autouse=True)
|
|
|
|
|
def fixture_one():
|
|
|
|
|
yield "one"
|
|
|
|
|
|
|
|
|
|
@pytest.fixture(autouse=True)
|
|
|
|
|
def fixture_two(self, param):
|
|
|
|
|
yield "two"
|
|
|
|
|
|
|
|
|
|
@pytest.fixture
|
|
|
|
|
def regular_fixture():
|
|
|
|
|
return "regular"
|
|
|
|
|
|
|
|
|
|
def test_something():
|
|
|
|
|
assert True
|
|
|
|
|
'''
|
|
|
|
|
expected_code = '''
|
|
|
|
|
import pytest
|
|
|
|
|
|
|
|
|
|
@pytest.fixture(autouse=True)
|
|
|
|
|
@pytest.mark.codeflash_no_autouse
|
|
|
|
|
def fixture_one(request):
|
|
|
|
|
if request.node.get_closest_marker("codeflash_no_autouse"):
|
|
|
|
|
yield
|
|
|
|
|
else:
|
|
|
|
|
yield "one"
|
|
|
|
|
|
|
|
|
|
@pytest.fixture(autouse=True)
|
|
|
|
|
@pytest.mark.codeflash_no_autouse
|
|
|
|
|
def fixture_two(self, request, param):
|
|
|
|
|
if request.node.get_closest_marker("codeflash_no_autouse"):
|
|
|
|
|
yield
|
|
|
|
|
else:
|
|
|
|
|
yield "two"
|
|
|
|
|
|
|
|
|
|
@pytest.fixture
|
|
|
|
|
@pytest.mark.codeflash_no_autouse
|
|
|
|
|
def regular_fixture():
|
|
|
|
|
return "regular"
|
|
|
|
|
|
|
|
|
|
@pytest.mark.codeflash_no_autouse
|
|
|
|
|
def test_something():
|
|
|
|
|
assert True
|
|
|
|
|
'''
|
|
|
|
|
# Apply all transformers in sequence
|
|
|
|
|
module = cst.parse_module(source_code)
|
|
|
|
|
request_adder = AddRequestArgument()
|
|
|
|
|
modified_module = module.visit(request_adder)
|
|
|
|
|
|
|
|
|
|
autouse_modifier = AutouseFixtureModifier()
|
|
|
|
|
modified_module = modified_module.visit(autouse_modifier)
|
|
|
|
|
|
|
|
|
|
mark_adder = PytestMarkAdder("codeflash_no_autouse")
|
|
|
|
|
final_module = modified_module.visit(mark_adder)
|
|
|
|
|
|
|
|
|
|
# Compare complete strings
|
|
|
|
|
assert final_module.code == expected_code
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestAddRequestArgument:
|
|
|
|
|
"""Test cases for AddRequestArgument transformer."""
|
|
|
|
|
|
|
|
|
|
def test_adds_request_to_autouse_fixture_no_existing_args(self):
|
|
|
|
|
"""Test adding request argument to autouse fixture with no existing arguments."""
|
|
|
|
|
source_code = '''
|
|
|
|
|
@fixture(autouse=True)
|
|
|
|
|
def my_fixture():
|
|
|
|
|
pass
|
|
|
|
|
'''
|
|
|
|
|
expected = '''
|
|
|
|
|
@fixture(autouse=True)
|
|
|
|
|
def my_fixture(request):
|
|
|
|
|
pass
|
|
|
|
|
'''
|
|
|
|
|
|
|
|
|
|
module = cst.parse_module(source_code)
|
|
|
|
|
transformer = AddRequestArgument()
|
|
|
|
|
modified_module = module.visit(transformer)
|
|
|
|
|
|
|
|
|
|
assert modified_module.code.strip() == expected.strip()
|
|
|
|
|
|
|
|
|
|
def test_adds_request_to_pytest_fixture_autouse(self):
|
|
|
|
|
"""Test adding request argument to pytest.fixture with autouse=True."""
|
|
|
|
|
source_code = '''
|
|
|
|
|
@pytest.fixture(autouse=True)
|
|
|
|
|
def my_fixture():
|
|
|
|
|
pass
|
|
|
|
|
'''
|
|
|
|
|
expected = '''
|
|
|
|
|
@pytest.fixture(autouse=True)
|
|
|
|
|
def my_fixture(request):
|
|
|
|
|
pass
|
|
|
|
|
'''
|
|
|
|
|
|
|
|
|
|
module = cst.parse_module(source_code)
|
|
|
|
|
transformer = AddRequestArgument()
|
|
|
|
|
modified_module = module.visit(transformer)
|
|
|
|
|
|
|
|
|
|
assert modified_module.code.strip() == expected.strip()
|
|
|
|
|
|
|
|
|
|
def test_adds_request_after_self_parameter(self):
|
|
|
|
|
"""Test adding request argument after self parameter."""
|
|
|
|
|
source_code = '''
|
|
|
|
|
@fixture(autouse=True)
|
|
|
|
|
def my_fixture(self):
|
|
|
|
|
pass
|
|
|
|
|
'''
|
|
|
|
|
expected = '''
|
|
|
|
|
@fixture(autouse=True)
|
|
|
|
|
def my_fixture(self, request):
|
|
|
|
|
pass
|
|
|
|
|
'''
|
|
|
|
|
|
|
|
|
|
module = cst.parse_module(source_code)
|
|
|
|
|
transformer = AddRequestArgument()
|
|
|
|
|
modified_module = module.visit(transformer)
|
|
|
|
|
|
|
|
|
|
assert modified_module.code.strip() == expected.strip()
|
|
|
|
|
|
|
|
|
|
def test_adds_request_after_cls_parameter(self):
|
|
|
|
|
"""Test adding request argument after cls parameter."""
|
|
|
|
|
source_code = '''
|
|
|
|
|
@fixture(autouse=True)
|
|
|
|
|
def my_fixture(cls):
|
|
|
|
|
pass
|
|
|
|
|
'''
|
|
|
|
|
expected = '''
|
|
|
|
|
@fixture(autouse=True)
|
|
|
|
|
def my_fixture(cls, request):
|
|
|
|
|
pass
|
|
|
|
|
'''
|
|
|
|
|
|
|
|
|
|
module = cst.parse_module(source_code)
|
|
|
|
|
transformer = AddRequestArgument()
|
|
|
|
|
modified_module = module.visit(transformer)
|
|
|
|
|
|
|
|
|
|
assert modified_module.code.strip() == expected.strip()
|
|
|
|
|
|
|
|
|
|
def test_adds_request_before_other_parameters(self):
|
|
|
|
|
"""Test adding request argument before other parameters (not self/cls)."""
|
|
|
|
|
source_code = '''
|
|
|
|
|
@fixture(autouse=True)
|
|
|
|
|
def my_fixture(param1, param2):
|
|
|
|
|
pass
|
|
|
|
|
'''
|
|
|
|
|
expected = '''
|
|
|
|
|
@fixture(autouse=True)
|
|
|
|
|
def my_fixture(request, param1, param2):
|
|
|
|
|
pass
|
|
|
|
|
'''
|
|
|
|
|
|
|
|
|
|
module = cst.parse_module(source_code)
|
|
|
|
|
transformer = AddRequestArgument()
|
|
|
|
|
modified_module = module.visit(transformer)
|
|
|
|
|
|
|
|
|
|
assert modified_module.code.strip() == expected.strip()
|
|
|
|
|
|
|
|
|
|
def test_adds_request_after_self_with_other_parameters(self):
|
|
|
|
|
"""Test adding request argument after self with other parameters."""
|
|
|
|
|
source_code = '''
|
|
|
|
|
@fixture(autouse=True)
|
|
|
|
|
def my_fixture(self, param1, param2):
|
|
|
|
|
pass
|
|
|
|
|
'''
|
|
|
|
|
expected = '''
|
|
|
|
|
@fixture(autouse=True)
|
|
|
|
|
def my_fixture(self, request, param1, param2):
|
|
|
|
|
pass
|
|
|
|
|
'''
|
|
|
|
|
|
|
|
|
|
module = cst.parse_module(source_code)
|
|
|
|
|
transformer = AddRequestArgument()
|
|
|
|
|
modified_module = module.visit(transformer)
|
|
|
|
|
|
|
|
|
|
assert modified_module.code.strip() == expected.strip()
|
|
|
|
|
|
|
|
|
|
def test_skips_when_request_already_present(self):
|
|
|
|
|
"""Test that request argument is not added when already present."""
|
|
|
|
|
source_code = '''
|
|
|
|
|
@fixture(autouse=True)
|
|
|
|
|
def my_fixture(request):
|
|
|
|
|
pass
|
|
|
|
|
'''
|
|
|
|
|
expected = '''
|
|
|
|
|
@fixture(autouse=True)
|
|
|
|
|
def my_fixture(request):
|
|
|
|
|
pass
|
|
|
|
|
'''
|
|
|
|
|
|
|
|
|
|
module = cst.parse_module(source_code)
|
|
|
|
|
transformer = AddRequestArgument()
|
|
|
|
|
modified_module = module.visit(transformer)
|
|
|
|
|
|
|
|
|
|
assert modified_module.code.strip() == expected.strip()
|
|
|
|
|
|
|
|
|
|
def test_skips_when_request_present_with_other_args(self):
|
|
|
|
|
"""Test that request argument is not added when already present with other args."""
|
|
|
|
|
source_code = '''
|
|
|
|
|
@fixture(autouse=True)
|
|
|
|
|
def my_fixture(self, request, param1):
|
|
|
|
|
pass
|
|
|
|
|
'''
|
|
|
|
|
expected = '''
|
|
|
|
|
@fixture(autouse=True)
|
|
|
|
|
def my_fixture(self, request, param1):
|
|
|
|
|
pass
|
|
|
|
|
'''
|
|
|
|
|
|
|
|
|
|
module = cst.parse_module(source_code)
|
|
|
|
|
transformer = AddRequestArgument()
|
|
|
|
|
modified_module = module.visit(transformer)
|
|
|
|
|
|
|
|
|
|
assert modified_module.code.strip() == expected.strip()
|
|
|
|
|
|
|
|
|
|
def test_ignores_non_autouse_fixture(self):
|
|
|
|
|
"""Test that non-autouse fixtures are not modified."""
|
|
|
|
|
source_code = '''
|
|
|
|
|
@fixture
|
|
|
|
|
def my_fixture():
|
|
|
|
|
pass
|
|
|
|
|
'''
|
|
|
|
|
expected = '''
|
|
|
|
|
@fixture
|
|
|
|
|
def my_fixture():
|
|
|
|
|
pass
|
|
|
|
|
'''
|
|
|
|
|
|
|
|
|
|
module = cst.parse_module(source_code)
|
|
|
|
|
transformer = AddRequestArgument()
|
|
|
|
|
modified_module = module.visit(transformer)
|
|
|
|
|
|
|
|
|
|
assert modified_module.code.strip() == expected.strip()
|
|
|
|
|
|
|
|
|
|
def test_ignores_fixture_with_autouse_false(self):
|
|
|
|
|
"""Test that fixtures with autouse=False are not modified."""
|
|
|
|
|
source_code = '''
|
|
|
|
|
@fixture(autouse=False)
|
|
|
|
|
def my_fixture():
|
|
|
|
|
pass
|
|
|
|
|
'''
|
|
|
|
|
expected = '''
|
|
|
|
|
@fixture(autouse=False)
|
|
|
|
|
def my_fixture():
|
|
|
|
|
pass
|
|
|
|
|
'''
|
|
|
|
|
|
|
|
|
|
module = cst.parse_module(source_code)
|
|
|
|
|
transformer = AddRequestArgument()
|
|
|
|
|
modified_module = module.visit(transformer)
|
|
|
|
|
|
|
|
|
|
assert modified_module.code.strip() == expected.strip()
|
|
|
|
|
|
|
|
|
|
def test_ignores_regular_function(self):
|
|
|
|
|
"""Test that regular functions are not modified."""
|
|
|
|
|
source_code = '''
|
|
|
|
|
def my_function():
|
|
|
|
|
pass
|
|
|
|
|
'''
|
|
|
|
|
expected = '''
|
|
|
|
|
def my_function():
|
|
|
|
|
pass
|
|
|
|
|
'''
|
|
|
|
|
|
|
|
|
|
module = cst.parse_module(source_code)
|
|
|
|
|
transformer = AddRequestArgument()
|
|
|
|
|
modified_module = module.visit(transformer)
|
|
|
|
|
|
|
|
|
|
assert modified_module.code.strip() == expected.strip()
|
|
|
|
|
|
|
|
|
|
def test_handles_multiple_autouse_fixtures(self):
|
|
|
|
|
"""Test handling multiple autouse fixtures in the same module."""
|
|
|
|
|
source_code = '''
|
|
|
|
|
@fixture(autouse=True)
|
|
|
|
|
def fixture1():
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
@pytest.fixture(autouse=True)
|
|
|
|
|
def fixture2(self):
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
@fixture(autouse=True)
|
|
|
|
|
def fixture3(request):
|
|
|
|
|
pass
|
|
|
|
|
'''
|
|
|
|
|
expected = '''
|
|
|
|
|
@fixture(autouse=True)
|
|
|
|
|
def fixture1(request):
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
@pytest.fixture(autouse=True)
|
|
|
|
|
def fixture2(self, request):
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
@fixture(autouse=True)
|
|
|
|
|
def fixture3(request):
|
|
|
|
|
pass
|
|
|
|
|
'''
|
|
|
|
|
|
|
|
|
|
module = cst.parse_module(source_code)
|
|
|
|
|
transformer = AddRequestArgument()
|
|
|
|
|
modified_module = module.visit(transformer)
|
|
|
|
|
|
|
|
|
|
assert modified_module.code.strip() == expected.strip()
|
|
|
|
|
|
|
|
|
|
def test_handles_fixture_with_other_decorators(self):
|
|
|
|
|
"""Test handling fixture with other decorators."""
|
|
|
|
|
source_code = '''
|
|
|
|
|
@some_decorator
|
|
|
|
|
@fixture(autouse=True)
|
|
|
|
|
@another_decorator
|
|
|
|
|
def my_fixture():
|
|
|
|
|
pass
|
|
|
|
|
'''
|
|
|
|
|
expected = '''
|
|
|
|
|
@some_decorator
|
|
|
|
|
@fixture(autouse=True)
|
|
|
|
|
@another_decorator
|
|
|
|
|
def my_fixture(request):
|
|
|
|
|
pass
|
|
|
|
|
'''
|
|
|
|
|
|
|
|
|
|
module = cst.parse_module(source_code)
|
|
|
|
|
transformer = AddRequestArgument()
|
|
|
|
|
modified_module = module.visit(transformer)
|
|
|
|
|
|
|
|
|
|
assert modified_module.code.strip() == expected.strip()
|
|
|
|
|
|
|
|
|
|
def test_preserves_function_body_and_docstring(self):
|
|
|
|
|
"""Test that function body and docstring are preserved."""
|
|
|
|
|
source_code = '''
|
|
|
|
|
@fixture(autouse=True)
|
|
|
|
|
def my_fixture():
|
|
|
|
|
"""This is a docstring."""
|
|
|
|
|
x = 1
|
|
|
|
|
y = 2
|
|
|
|
|
return x + y
|
|
|
|
|
'''
|
|
|
|
|
expected = '''
|
|
|
|
|
@fixture(autouse=True)
|
|
|
|
|
def my_fixture(request):
|
|
|
|
|
"""This is a docstring."""
|
|
|
|
|
x = 1
|
|
|
|
|
y = 2
|
|
|
|
|
return x + y
|
|
|
|
|
'''
|
|
|
|
|
|
|
|
|
|
module = cst.parse_module(source_code)
|
|
|
|
|
transformer = AddRequestArgument()
|
|
|
|
|
modified_module = module.visit(transformer)
|
|
|
|
|
|
|
|
|
|
assert modified_module.code.strip() == expected.strip()
|
|
|
|
|
|
|
|
|
|
def test_handles_fixture_with_additional_arguments(self):
|
|
|
|
|
"""Test handling fixture with additional keyword arguments."""
|
|
|
|
|
source_code = '''
|
|
|
|
|
@fixture(autouse=True, scope="session")
|
|
|
|
|
def my_fixture():
|
|
|
|
|
pass
|
|
|
|
|
'''
|
|
|
|
|
expected = '''
|
|
|
|
|
@fixture(autouse=True, scope="session")
|
|
|
|
|
def my_fixture(request):
|
|
|
|
|
pass
|
|
|
|
|
'''
|
|
|
|
|
|
|
|
|
|
module = cst.parse_module(source_code)
|
|
|
|
|
transformer = AddRequestArgument()
|
|
|
|
|
modified_module = module.visit(transformer)
|
2025-06-06 20:18:58 +00:00
|
|
|
|
2025-06-14 00:27:45 +00:00
|
|
|
assert modified_module.code.strip() == expected.strip()
|
2025-07-31 13:52:11 +00:00
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_type_checking_imports():
|
|
|
|
|
optim_code = """from dataclasses import dataclass
|
|
|
|
|
from pydantic_ai.providers import Provider, infer_provider
|
|
|
|
|
from pydantic_ai_slim.pydantic_ai.models import Model
|
|
|
|
|
from pydantic_ai_slim.pydantic_ai.tools import ToolDefinition
|
|
|
|
|
from typing import Literal
|
|
|
|
|
|
|
|
|
|
#### problamatic imports ####
|
|
|
|
|
from huggingface_hub import AsyncInferenceClient, ChatCompletionInputTool
|
|
|
|
|
import requests
|
|
|
|
|
import aiohttp as aiohttp_
|
|
|
|
|
from math import pi as PI, sin as sine
|
|
|
|
|
|
|
|
|
|
@dataclass(init=False)
|
|
|
|
|
class HuggingFaceModel(Model):
|
|
|
|
|
def __init__(
|
|
|
|
|
self,
|
|
|
|
|
model_name: str,
|
|
|
|
|
*,
|
|
|
|
|
provider: Literal['huggingface'] | Provider[AsyncInferenceClient] = 'huggingface',
|
|
|
|
|
):
|
|
|
|
|
print(requests.__name__)
|
|
|
|
|
print(aiohttp_.__name__)
|
|
|
|
|
print(PI)
|
|
|
|
|
print(sine)
|
|
|
|
|
# Fast branch: avoid repeating provider assignment
|
|
|
|
|
if isinstance(provider, str):
|
|
|
|
|
provider_obj = infer_provider(provider)
|
|
|
|
|
else:
|
|
|
|
|
provider_obj = provider
|
|
|
|
|
self._provider = provider
|
|
|
|
|
self._model_name = model_name
|
|
|
|
|
self.client = provider_obj.client
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def _map_tool_definition(f: ToolDefinition) -> ChatCompletionInputTool:
|
|
|
|
|
# Inline dict creation and single pass for possible strict attribute
|
|
|
|
|
tool_dict = {
|
|
|
|
|
'type': 'function',
|
|
|
|
|
'function': {
|
|
|
|
|
'name': f.name,
|
|
|
|
|
'description': f.description,
|
|
|
|
|
'parameters': f.parameters_json_schema,
|
|
|
|
|
},
|
|
|
|
|
}
|
|
|
|
|
if f.strict is not None:
|
|
|
|
|
tool_dict['function']['strict'] = f.strict
|
|
|
|
|
return ChatCompletionInputTool.parse_obj_as_instance(tool_dict) # type: ignore
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
original_code = """from dataclasses import dataclass
|
|
|
|
|
from pydantic_ai.providers import Provider, infer_provider
|
|
|
|
|
from pydantic_ai_slim.pydantic_ai.models import Model
|
|
|
|
|
from pydantic_ai_slim.pydantic_ai.tools import ToolDefinition
|
|
|
|
|
from typing import Literal
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
import aiohttp as aiohttp_
|
|
|
|
|
from math import pi as PI, sin as sine
|
|
|
|
|
from huggingface_hub import (
|
|
|
|
|
AsyncInferenceClient,
|
|
|
|
|
ChatCompletionInputMessage,
|
|
|
|
|
ChatCompletionInputMessageChunk,
|
|
|
|
|
ChatCompletionInputTool,
|
|
|
|
|
ChatCompletionInputToolCall,
|
|
|
|
|
ChatCompletionInputURL,
|
|
|
|
|
ChatCompletionOutput,
|
|
|
|
|
ChatCompletionOutputMessage,
|
|
|
|
|
ChatCompletionStreamOutput,
|
|
|
|
|
)
|
|
|
|
|
from huggingface_hub.errors import HfHubHTTPError
|
|
|
|
|
|
|
|
|
|
except ImportError as _import_error:
|
|
|
|
|
raise ImportError(
|
|
|
|
|
'Please install `huggingface_hub` to use Hugging Face Inference Providers, '
|
|
|
|
|
'you can use the `huggingface` optional group — `pip install "pydantic-ai-slim[huggingface]"`'
|
|
|
|
|
) from _import_error
|
|
|
|
|
|
|
|
|
|
if True:
|
|
|
|
|
import requests
|
|
|
|
|
|
|
|
|
|
__all__ = (
|
|
|
|
|
'HuggingFaceModel',
|
|
|
|
|
'HuggingFaceModelSettings',
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
@dataclass(init=False)
|
|
|
|
|
class HuggingFaceModel(Model):
|
|
|
|
|
|
|
|
|
|
def __init__(
|
|
|
|
|
self,
|
|
|
|
|
model_name: str,
|
|
|
|
|
*,
|
|
|
|
|
provider: Literal['huggingface'] | Provider[AsyncInferenceClient] = 'huggingface',
|
|
|
|
|
):
|
|
|
|
|
self._model_name = model_name
|
|
|
|
|
self._provider = provider
|
|
|
|
|
if isinstance(provider, str):
|
|
|
|
|
provider = infer_provider(provider)
|
|
|
|
|
self.client = provider.client
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def _map_tool_definition(f: ToolDefinition) -> ChatCompletionInputTool:
|
|
|
|
|
tool_param: ChatCompletionInputTool = ChatCompletionInputTool.parse_obj_as_instance( # type: ignore
|
|
|
|
|
{
|
|
|
|
|
'type': 'function',
|
|
|
|
|
'function': {
|
|
|
|
|
'name': f.name,
|
|
|
|
|
'description': f.description,
|
|
|
|
|
'parameters': f.parameters_json_schema,
|
|
|
|
|
},
|
|
|
|
|
}
|
|
|
|
|
)
|
|
|
|
|
if f.strict is not None:
|
|
|
|
|
tool_param['function']['strict'] = f.strict
|
|
|
|
|
return tool_param
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
function_name: str = "HuggingFaceModel._map_tool_definition"
|
|
|
|
|
preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] = find_preexisting_objects(original_code)
|
|
|
|
|
new_code: str = replace_functions_and_add_imports(
|
|
|
|
|
source_code=original_code,
|
|
|
|
|
function_names=[function_name],
|
|
|
|
|
optimized_code=optim_code,
|
|
|
|
|
module_abspath=Path(__file__).resolve(),
|
|
|
|
|
preexisting_objects=preexisting_objects,
|
|
|
|
|
project_root_path=Path(__file__).resolve().parent.resolve(),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
assert not re.search(r"^import requests\b", new_code, re.MULTILINE) # conditional simple import: import <name>
|
|
|
|
|
assert not re.search(r"^import aiohttp as aiohttp_\b", new_code, re.MULTILINE) # conditional alias import: import <name> as <alias>
|
|
|
|
|
assert not re.search(r"^from math import pi as PI, sin as sine\b", new_code, re.MULTILINE) # conditional multiple aliases imports
|
2025-07-31 13:54:52 +00:00
|
|
|
assert "from huggingface_hub import AsyncInferenceClient, ChatCompletionInputTool" not in new_code # conditional from import
|
2025-08-23 15:34:37 +00:00
|
|
|
|
2025-08-22 14:45:44 +00:00
|
|
|
def test_top_level_global_assignments() -> None:
    """New module-scope assignments introduced by the optimized code (here an
    ``import re`` plus a precompiled ``_INTENTION_CLEANUP_RE`` constant) must be
    inserted into the target module exactly once, after the existing imports,
    when the optimized function is written back.
    """
    root_dir = Path(__file__).parent.parent.resolve()
    # Scratch module on disk; the optimizer rewrites it in place.
    main_file = Path(root_dir / "code_to_optimize/temp_main.py").resolve()

    original_code = '''"""
Module for generating GeneratedWorkflowParameters schema from workflow run input_text actions.
"""

from typing import Any, Dict, List, Tuple

import structlog
from pydantic import BaseModel

from skyvern.forge import app
from skyvern.forge.sdk.prompting import PromptEngine
from skyvern.webeye.actions.actions import ActionType

LOG = structlog.get_logger(__name__)

# Initialize prompt engine
prompt_engine = PromptEngine("skyvern")


def hydrate_input_text_actions_with_field_names(
    actions_by_task: Dict[str, List[Dict[str, Any]]], field_mappings: Dict[str, str]
) -> Dict[str, List[Dict[str, Any]]]:
    """
    Add field_name to input_text actions based on generated mappings.

    Args:
        actions_by_task: Dictionary mapping task IDs to lists of action dictionaries
        field_mappings: Dictionary mapping "task_id:action_id" to field names

    Returns:
        Updated actions_by_task with field_name added to input_text actions
    """
    updated_actions_by_task = {}

    for task_id, actions in actions_by_task.items():
        updated_actions = []

        for action in actions:
            action_copy = action.copy()

            if action.get("action_type") == ActionType.INPUT_TEXT:
                action_id = action.get("action_id", "")
                mapping_key = f"{task_id}:{action_id}"

                if mapping_key in field_mappings:
                    action_copy["field_name"] = field_mappings[mapping_key]
                else:
                    # Fallback field name if mapping not found
                    intention = action.get("intention", "")
                    if intention:
                        # Simple field name generation from intention
                        field_name = intention.lower().replace(" ", "_").replace("?", "").replace("'", "")
                        field_name = "".join(c for c in field_name if c.isalnum() or c == "_")
                        action_copy["field_name"] = field_name or "unknown_field"
                    else:
                        action_copy["field_name"] = "unknown_field"

            updated_actions.append(action_copy)

        updated_actions_by_task[task_id] = updated_actions

    return updated_actions_by_task
'''
    main_file.write_text(original_code, encoding="utf-8")
    optim_code = f'''```python:{main_file.relative_to(root_dir)}
from skyvern.webeye.actions.actions import ActionType
from typing import Any, Dict, List
import re

# Precompiled regex for efficiently generating simple field_name from intention
_INTENTION_CLEANUP_RE = re.compile(r"[^a-zA-Z0-9_]+")

def hydrate_input_text_actions_with_field_names(
    actions_by_task: Dict[str, List[Dict[str, Any]]], field_mappings: Dict[str, str]
) -> Dict[str, List[Dict[str, Any]]]:
    """
    Add field_name to input_text actions based on generated mappings.

    Args:
        actions_by_task: Dictionary mapping task IDs to lists of action dictionaries
        field_mappings: Dictionary mapping "task_id:action_id" to field names

    Returns:
        Updated actions_by_task with field_name added to input_text actions
    """
    updated_actions_by_task = {{}}

    input_text_type = ActionType.INPUT_TEXT # local variable for faster access
    intention_cleanup = _INTENTION_CLEANUP_RE

    for task_id, actions in actions_by_task.items():
        updated_actions = []

        for action in actions:
            action_copy = action.copy()

            if action.get("action_type") == input_text_type:
                action_id = action.get("action_id", "")
                mapping_key = f"{{task_id}}:{{action_id}}"

                if mapping_key in field_mappings:
                    action_copy["field_name"] = field_mappings[mapping_key]
                else:
                    # Fallback field name if mapping not found
                    intention = action.get("intention", "")
                    if intention:
                        # Simple field name generation from intention
                        field_name = intention.lower().replace(" ", "_").replace("?", "").replace("'", "")
                        # Use compiled regex instead of "".join(c for ...)
                        field_name = intention_cleanup.sub("", field_name)
                        action_copy["field_name"] = field_name or "unknown_field"
                    else:
                        action_copy["field_name"] = "unknown_field"

            updated_actions.append(action_copy)

        updated_actions_by_task[task_id] = updated_actions

    return updated_actions_by_task
```
'''
    expected = '''"""
Module for generating GeneratedWorkflowParameters schema from workflow run input_text actions.
"""

from typing import Any, Dict, List, Tuple

import structlog
from pydantic import BaseModel

from skyvern.forge import app
from skyvern.forge.sdk.prompting import PromptEngine
from skyvern.webeye.actions.actions import ActionType
import re

_INTENTION_CLEANUP_RE = re.compile(r"[^a-zA-Z0-9_]+")

LOG = structlog.get_logger(__name__)

# Initialize prompt engine
prompt_engine = PromptEngine("skyvern")


def hydrate_input_text_actions_with_field_names(
    actions_by_task: Dict[str, List[Dict[str, Any]]], field_mappings: Dict[str, str]
) -> Dict[str, List[Dict[str, Any]]]:
    """
    Add field_name to input_text actions based on generated mappings.

    Args:
        actions_by_task: Dictionary mapping task IDs to lists of action dictionaries
        field_mappings: Dictionary mapping "task_id:action_id" to field names

    Returns:
        Updated actions_by_task with field_name added to input_text actions
    """
    updated_actions_by_task = {}

    input_text_type = ActionType.INPUT_TEXT # local variable for faster access
    intention_cleanup = _INTENTION_CLEANUP_RE

    for task_id, actions in actions_by_task.items():
        updated_actions = []

        for action in actions:
            action_copy = action.copy()

            if action.get("action_type") == input_text_type:
                action_id = action.get("action_id", "")
                mapping_key = f"{task_id}:{action_id}"

                if mapping_key in field_mappings:
                    action_copy["field_name"] = field_mappings[mapping_key]
                else:
                    # Fallback field name if mapping not found
                    intention = action.get("intention", "")
                    if intention:
                        # Simple field name generation from intention
                        field_name = intention.lower().replace(" ", "_").replace("?", "").replace("'", "")
                        # Use compiled regex instead of "".join(c for ...)
                        field_name = intention_cleanup.sub("", field_name)
                        action_copy["field_name"] = field_name or "unknown_field"
                    else:
                        action_copy["field_name"] = "unknown_field"

            updated_actions.append(action_copy)

        updated_actions_by_task[task_id] = updated_actions

    return updated_actions_by_task
'''

    func = FunctionToOptimize(function_name="hydrate_input_text_actions_with_field_names", parents=[], file_path=main_file)
    test_config = TestConfig(
        tests_root=root_dir / "tests/pytest",
        tests_project_rootdir=root_dir,
        project_root_path=root_dir,
        test_framework="pytest",
        pytest_cmd="pytest",
    )
    func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
    try:
        code_context: CodeOptimizationContext = func_optimizer.get_code_optimization_context().unwrap()

        # Snapshot helper files so the optimizer can revert them if needed.
        original_helper_code: dict[Path, str] = {}
        helper_function_paths = {hf.file_path for hf in code_context.helper_functions}
        for helper_function_path in helper_function_paths:
            with helper_function_path.open(encoding="utf8") as f:
                helper_code = f.read()
                original_helper_code[helper_function_path] = helper_code

        func_optimizer.args = Args()
        func_optimizer.replace_function_and_helpers_with_optimized_code(
            code_context=code_context, optimized_code=CodeStringsMarkdown.parse_markdown_code(optim_code), original_helper_code=original_helper_code
        )

        new_code = main_file.read_text(encoding="utf-8")
    finally:
        # Always remove the scratch module, even when an optimizer step raises,
        # so a stale temp_main.py cannot leak into other tests.
        main_file.unlink(missing_ok=True)

    assert new_code == expected
|
2025-08-23 15:34:37 +00:00
|
|
|
|
2025-08-23 17:17:13 +00:00
|
|
|
def test_duplicate_global_assignments_when_reverting_helpers():
    """Module-level assignments and comment banners of the original file must
    appear exactly once after the optimized method is written back: the
    expected output is the original module with only the body of
    ``PreChunker._is_in_new_semantic_unit`` replaced, so no global (e.g.
    ``CHUNK_MAX_CHARS_DEFAULT``) gets duplicated.
    """
    root_dir = Path(__file__).parent.parent.resolve()
    # Scratch module on disk; the optimizer rewrites it in place.
    main_file = Path(root_dir / "code_to_optimize/temp_main.py").resolve()

    original_code = '''"""Chunking objects not specific to a particular chunking strategy."""
from __future__ import annotations
import collections
import copy
from typing import Any, Callable, DefaultDict, Iterable, Iterator, cast
import regex
from typing_extensions import Self, TypeAlias
from unstructured.utils import lazyproperty
from unstructured.documents.elements import Element
# ================================================================================================
# MODEL
# ================================================================================================
CHUNK_MAX_CHARS_DEFAULT: int = 500
# ================================================================================================
# PRE-CHUNKER
# ================================================================================================
class PreChunker:
    """Gathers sequential elements into pre-chunks as length constraints allow.
    The pre-chunker's responsibilities are:
    - **Segregate semantic units.** Identify semantic unit boundaries and segregate elements on
      either side of those boundaries into different sections. In this case, the primary indicator
      of a semantic boundary is a `Title` element. A page-break (change in page-number) is also a
      semantic boundary when `multipage_sections` is `False`.
    - **Minimize chunk count for each semantic unit.** Group the elements within a semantic unit
      into sections as big as possible without exceeding the chunk window size.
    - **Minimize chunks that must be split mid-text.** Precompute the text length of each section
      and only produce a section that exceeds the chunk window size when there is a single element
      with text longer than that window.
    A Table element is placed into a section by itself. CheckBox elements are dropped.
    The "by-title" strategy specifies breaking on section boundaries; a `Title` element indicates
    a new "section", hence the "by-title" designation.
    """
    def __init__(self, elements: Iterable[Element], opts: ChunkingOptions):
        self._elements = elements
        self._opts = opts
    @lazyproperty
    def _boundary_predicates(self) -> tuple[BoundaryPredicate, ...]:
        """The semantic-boundary detectors to be applied to break pre-chunks."""
        return self._opts.boundary_predicates
    def _is_in_new_semantic_unit(self, element: Element) -> bool:
        """True when `element` begins a new semantic unit such as a section or page."""
        # -- all detectors need to be called to update state and avoid double counting
        # -- boundaries that happen to coincide, like Table and new section on same element.
        # -- Using `any()` would short-circuit on first True.
        semantic_boundaries = [pred(element) for pred in self._boundary_predicates]
        return any(semantic_boundaries)
'''
    main_file.write_text(original_code, encoding="utf-8")
    optim_code = f'''```python:{main_file.relative_to(root_dir)}
# ================================================================================================
# PRE-CHUNKER
# ================================================================================================
from __future__ import annotations
from typing import Iterable
from unstructured.documents.elements import Element
from unstructured.utils import lazyproperty
class PreChunker:
    def __init__(self, elements: Iterable[Element], opts: ChunkingOptions):
        self._elements = elements
        self._opts = opts
    @lazyproperty
    def _boundary_predicates(self) -> tuple[BoundaryPredicate, ...]:
        """The semantic-boundary detectors to be applied to break pre-chunks."""
        return self._opts.boundary_predicates
    def _is_in_new_semantic_unit(self, element: Element) -> bool:
        """True when `element` begins a new semantic unit such as a section or page."""
        # Use generator expression for lower memory usage and avoid building intermediate list
        for pred in self._boundary_predicates:
            if pred(element):
                return True
        return False
```
'''

    func = FunctionToOptimize(function_name="_is_in_new_semantic_unit", parents=[FunctionParent("PreChunker", "ClassDef")], file_path=main_file)
    test_config = TestConfig(
        tests_root=root_dir / "tests/pytest",
        tests_project_rootdir=root_dir,
        project_root_path=root_dir,
        test_framework="pytest",
        pytest_cmd="pytest",
    )
    func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
    try:
        code_context: CodeOptimizationContext = func_optimizer.get_code_optimization_context().unwrap()

        # Snapshot helper files so the optimizer can revert them if needed.
        original_helper_code: dict[Path, str] = {}
        helper_function_paths = {hf.file_path for hf in code_context.helper_functions}
        for helper_function_path in helper_function_paths:
            with helper_function_path.open(encoding="utf8") as f:
                helper_code = f.read()
                original_helper_code[helper_function_path] = helper_code

        func_optimizer.args = Args()
        func_optimizer.replace_function_and_helpers_with_optimized_code(
            code_context=code_context, optimized_code=CodeStringsMarkdown.parse_markdown_code(optim_code), original_helper_code=original_helper_code
        )

        new_code = main_file.read_text(encoding="utf-8")
    finally:
        # Always remove the scratch module, even when an optimizer step raises,
        # so a stale temp_main.py cannot leak into other tests.
        main_file.unlink(missing_ok=True)

    expected = '''"""Chunking objects not specific to a particular chunking strategy."""
from __future__ import annotations
import collections
import copy
from typing import Any, Callable, DefaultDict, Iterable, Iterator, cast
import regex
from typing_extensions import Self, TypeAlias
from unstructured.utils import lazyproperty
from unstructured.documents.elements import Element
# ================================================================================================
# MODEL
# ================================================================================================
CHUNK_MAX_CHARS_DEFAULT: int = 500
# ================================================================================================
# PRE-CHUNKER
# ================================================================================================
class PreChunker:
    """Gathers sequential elements into pre-chunks as length constraints allow.
    The pre-chunker's responsibilities are:
    - **Segregate semantic units.** Identify semantic unit boundaries and segregate elements on
      either side of those boundaries into different sections. In this case, the primary indicator
      of a semantic boundary is a `Title` element. A page-break (change in page-number) is also a
      semantic boundary when `multipage_sections` is `False`.
    - **Minimize chunk count for each semantic unit.** Group the elements within a semantic unit
      into sections as big as possible without exceeding the chunk window size.
    - **Minimize chunks that must be split mid-text.** Precompute the text length of each section
      and only produce a section that exceeds the chunk window size when there is a single element
      with text longer than that window.
    A Table element is placed into a section by itself. CheckBox elements are dropped.
    The "by-title" strategy specifies breaking on section boundaries; a `Title` element indicates
    a new "section", hence the "by-title" designation.
    """
    def __init__(self, elements: Iterable[Element], opts: ChunkingOptions):
        self._elements = elements
        self._opts = opts
    @lazyproperty
    def _boundary_predicates(self) -> tuple[BoundaryPredicate, ...]:
        """The semantic-boundary detectors to be applied to break pre-chunks."""
        return self._opts.boundary_predicates
    def _is_in_new_semantic_unit(self, element: Element) -> bool:
        """True when `element` begins a new semantic unit such as a section or page."""
        # Use generator expression for lower memory usage and avoid building intermediate list
        for pred in self._boundary_predicates:
            if pred(element):
                return True
        return False
'''
    assert new_code == expected
|