fix bugs with docstring removal
This commit is contained in:
parent
b48ed5c89d
commit
9e14cfe7a0
2 changed files with 356 additions and 18 deletions
|
|
@ -613,7 +613,7 @@ def prune_cst_for_code_hashing( # noqa: PLR0911
|
|||
if isinstance(node, cst.FunctionDef):
|
||||
qualified_name = f"{prefix}.{node.name.value}" if prefix else node.name.value
|
||||
if qualified_name in target_functions:
|
||||
new_body = remove_docstring_from_body(node.body)
|
||||
new_body = remove_docstring_from_body(node.body) if isinstance(node.body, cst.IndentedBlock) else node.body
|
||||
return node.with_changes(body=new_body), True
|
||||
return None, False
|
||||
|
||||
|
|
@ -632,14 +632,13 @@ def prune_cst_for_code_hashing( # noqa: PLR0911
|
|||
if isinstance(stmt, cst.FunctionDef):
|
||||
qualified_name = f"{class_prefix}.{stmt.name.value}"
|
||||
if qualified_name in target_functions:
|
||||
new_body.append(stmt)
|
||||
stmt_with_changes = stmt.with_changes(body=remove_docstring_from_body(stmt.body))
|
||||
new_body.append(stmt_with_changes)
|
||||
found_target = True
|
||||
# If no target functions found, remove the class entirely
|
||||
if not new_body or not found_target:
|
||||
return None, False
|
||||
return node.with_changes(
|
||||
body=remove_docstring_from_body(node.body.with_changes(body=new_body))
|
||||
) if new_body else None, True
|
||||
return node.with_changes(body=cst.IndentedBlock(new_body)) if new_body else None, found_target
|
||||
|
||||
# For other nodes, we preserve them only if they contain target functions in their children.
|
||||
section_names = get_section_names(node)
|
||||
|
|
|
|||
|
|
@ -6,7 +6,6 @@ from collections import defaultdict
|
|||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from codeflash.context.code_context_extractor import get_code_optimization_context
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.models.models import FunctionParent
|
||||
|
|
@ -30,6 +29,7 @@ class HelperClass:
|
|||
def nested_method(self):
|
||||
return self.name
|
||||
|
||||
|
||||
def main_method():
|
||||
return "hello"
|
||||
|
||||
|
|
@ -81,8 +81,9 @@ def test_code_replacement10() -> None:
|
|||
|
||||
code_ctx = get_code_optimization_context(function_to_optimize=func_top_optimize, project_root_path=file_path.parent)
|
||||
qualified_names = {func.qualified_name for func in code_ctx.helper_functions}
|
||||
assert qualified_names == {"HelperClass.helper_method"} # Nested method should not be in here
|
||||
assert qualified_names == {"HelperClass.helper_method"} # Nested method should not be in here
|
||||
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
|
||||
hashing_context = code_ctx.hashing_code_context
|
||||
|
||||
expected_read_write_context = """
|
||||
from __future__ import annotations
|
||||
|
|
@ -106,8 +107,26 @@ class MainClass:
|
|||
expected_read_only_context = """
|
||||
"""
|
||||
|
||||
expected_hashing_context = f"""
|
||||
```python:{file_path.relative_to(file_path.parent)}
|
||||
class HelperClass:
|
||||
|
||||
def helper_method(self):
|
||||
return self.name
|
||||
|
||||
|
||||
class MainClass:
|
||||
|
||||
def main_method(self):
|
||||
self.name = HelperClass.NestedClass("test").nested_method()
|
||||
return HelperClass(self.name).helper_method()
|
||||
```
|
||||
"""
|
||||
|
||||
assert read_write_context.strip() == expected_read_write_context.strip()
|
||||
assert read_only_context.strip() == expected_read_only_context.strip()
|
||||
assert hashing_context.strip() == expected_hashing_context.strip()
|
||||
|
||||
|
||||
def test_class_method_dependencies() -> None:
|
||||
file_path = Path(__file__).resolve()
|
||||
|
|
@ -122,6 +141,8 @@ def test_class_method_dependencies() -> None:
|
|||
|
||||
code_ctx = get_code_optimization_context(function_to_optimize, file_path.parent.resolve())
|
||||
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
|
||||
hashing_context = code_ctx.hashing_code_context
|
||||
|
||||
expected_read_write_context = """
|
||||
from __future__ import annotations
|
||||
from collections import defaultdict
|
||||
|
|
@ -153,8 +174,36 @@ class Graph:
|
|||
|
||||
"""
|
||||
expected_read_only_context = ""
|
||||
|
||||
expected_hashing_context = f"""
|
||||
```python:{file_path.relative_to(file_path.parent.resolve())}
|
||||
class Graph:
|
||||
|
||||
def topologicalSortUtil(self, v, visited, stack):
|
||||
visited[v] = True
|
||||
|
||||
for i in self.graph[v]:
|
||||
if visited[i] == False:
|
||||
self.topologicalSortUtil(i, visited, stack)
|
||||
|
||||
stack.insert(0, v)
|
||||
|
||||
def topologicalSort(self):
|
||||
visited = [False] * self.V
|
||||
stack = []
|
||||
|
||||
for i in range(self.V):
|
||||
if visited[i] == False:
|
||||
self.topologicalSortUtil(i, visited, stack)
|
||||
|
||||
# Print contents of stack
|
||||
return stack
|
||||
```
|
||||
"""
|
||||
|
||||
assert read_write_context.strip() == expected_read_write_context.strip()
|
||||
assert read_only_context.strip() == expected_read_only_context.strip()
|
||||
assert hashing_context.strip() == expected_hashing_context.strip()
|
||||
|
||||
|
||||
def test_bubble_sort_helper() -> None:
|
||||
|
|
@ -176,6 +225,7 @@ def test_bubble_sort_helper() -> None:
|
|||
|
||||
code_ctx = get_code_optimization_context(function_to_optimize, Path(__file__).resolve().parent.parent)
|
||||
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
|
||||
hashing_context = code_ctx.hashing_code_context
|
||||
|
||||
expected_read_write_context = """
|
||||
import math
|
||||
|
|
@ -196,8 +246,24 @@ def sort_from_another_file(arr):
|
|||
"""
|
||||
expected_read_only_context = ""
|
||||
|
||||
expected_hashing_context = """
|
||||
```python:code_to_optimize/code_directories/retriever/bubble_sort_with_math.py
|
||||
def sorter(arr):
|
||||
arr.sort()
|
||||
x = math.sqrt(2)
|
||||
print(x)
|
||||
return arr
|
||||
```
|
||||
```python:code_to_optimize/code_directories/retriever/bubble_sort_imported.py
|
||||
def sort_from_another_file(arr):
|
||||
sorted_arr = sorter(arr)
|
||||
return sorted_arr
|
||||
```
|
||||
"""
|
||||
|
||||
assert read_write_context.strip() == expected_read_write_context.strip()
|
||||
assert read_only_context.strip() == expected_read_only_context.strip()
|
||||
assert hashing_context.strip() == expected_hashing_context.strip()
|
||||
|
||||
|
||||
def test_flavio_typed_code_helper() -> None:
|
||||
|
|
@ -366,7 +432,7 @@ class _PersistentCache(Generic[_P, _R, _CacheBackendT]):
|
|||
lifespan=self.__duration__,
|
||||
)
|
||||
'''
|
||||
with tempfile.NamedTemporaryFile(mode="w") as f:
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".py") as f:
|
||||
f.write(code)
|
||||
f.flush()
|
||||
file_path = Path(f.name).resolve()
|
||||
|
|
@ -391,6 +457,7 @@ class _PersistentCache(Generic[_P, _R, _CacheBackendT]):
|
|||
|
||||
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
|
||||
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
|
||||
hashing_context = code_ctx.hashing_code_context
|
||||
expected_read_write_context = """
|
||||
class AbstractCacheBackend(CacheBackend, Protocol[_KEY_T, _STORE_T]):
|
||||
|
||||
|
|
@ -543,8 +610,67 @@ class _PersistentCache(Generic[_P, _R, _CacheBackendT]):
|
|||
__backend__: _CacheBackendT
|
||||
```
|
||||
'''
|
||||
expected_hashing_context = f"""
|
||||
```python:{file_path.relative_to(opt.args.project_root)}
|
||||
class AbstractCacheBackend(CacheBackend, Protocol[_KEY_T, _STORE_T]):
|
||||
|
||||
def get_cache_or_call(
|
||||
self,
|
||||
*,
|
||||
func: Callable[_P, Any],
|
||||
args: tuple[Any, ...],
|
||||
kwargs: dict[str, Any],
|
||||
lifespan: datetime.timedelta,
|
||||
) -> Any: # noqa: ANN401
|
||||
if os.environ.get("NO_CACHE"):
|
||||
return func(*args, **kwargs)
|
||||
|
||||
try:
|
||||
key = self.hash_key(func=func, args=args, kwargs=kwargs)
|
||||
except: # noqa: E722
|
||||
# If we can't create a cache key, we should just call the function.
|
||||
logging.warning("Failed to hash cache key for function: %s", func)
|
||||
return func(*args, **kwargs)
|
||||
result_pair = self.get(key=key)
|
||||
|
||||
if result_pair is not None:
|
||||
cached_time, result = result_pair
|
||||
if not os.environ.get("RE_CACHE") and (
|
||||
datetime.datetime.now() < (cached_time + lifespan) # noqa: DTZ005
|
||||
):
|
||||
try:
|
||||
return self.decode(data=result)
|
||||
except CacheBackendDecodeError as e:
|
||||
logging.warning("Failed to decode cache data: %s", e)
|
||||
# If decoding fails we will treat this as a cache miss.
|
||||
# This might happens if underlying class definition of the data changes.
|
||||
self.delete(key=key)
|
||||
result = func(*args, **kwargs)
|
||||
try:
|
||||
self.put(key=key, data=self.encode(data=result))
|
||||
except CacheBackendEncodeError as e:
|
||||
logging.warning("Failed to encode cache data: %s", e)
|
||||
# If encoding fails, we should still return the result.
|
||||
return result
|
||||
|
||||
|
||||
class _PersistentCache(Generic[_P, _R, _CacheBackendT]):
|
||||
|
||||
def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R:
|
||||
if "NO_CACHE" in os.environ:
|
||||
return self.__wrapped__(*args, **kwargs)
|
||||
os.makedirs(DEFAULT_CACHE_LOCATION, exist_ok=True)
|
||||
return self.__backend__.get_cache_or_call(
|
||||
func=self.__wrapped__,
|
||||
args=args,
|
||||
kwargs=kwargs,
|
||||
lifespan=self.__duration__,
|
||||
)
|
||||
```
|
||||
"""
|
||||
assert read_write_context.strip() == expected_read_write_context.strip()
|
||||
assert read_only_context.strip() == expected_read_only_context.strip()
|
||||
assert hashing_context.strip() == expected_hashing_context.strip()
|
||||
|
||||
|
||||
def test_example_class() -> None:
|
||||
|
|
@ -592,6 +718,8 @@ class HelperClass:
|
|||
|
||||
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
|
||||
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
|
||||
hashing_context = code_ctx.hashing_code_context
|
||||
|
||||
expected_read_write_context = """
|
||||
class MyClass:
|
||||
def __init__(self):
|
||||
|
|
@ -618,8 +746,21 @@ class HelperClass:
|
|||
return "HelperClass" + str(self.x)
|
||||
```
|
||||
"""
|
||||
expected_hashing_context = f"""
|
||||
```python:{file_path.relative_to(opt.args.project_root)}
|
||||
class MyClass:
|
||||
def target_method(self):
|
||||
y = HelperClass().helper_method()
|
||||
|
||||
class HelperClass:
|
||||
def helper_method(self):
|
||||
return self.x
|
||||
```
|
||||
"""
|
||||
|
||||
assert read_write_context.strip() == expected_read_write_context.strip()
|
||||
assert read_only_context.strip() == expected_read_only_context.strip()
|
||||
assert hashing_context.strip() == expected_hashing_context.strip()
|
||||
|
||||
|
||||
def test_example_class_token_limit_1() -> None:
|
||||
|
|
@ -672,6 +813,7 @@ class HelperClass:
|
|||
|
||||
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
|
||||
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
|
||||
hashing_context = code_ctx.hashing_code_context
|
||||
# In this scenario, the read-only code context is too long, so the read-only docstrings are removed.
|
||||
expected_read_write_context = """
|
||||
class MyClass:
|
||||
|
|
@ -697,9 +839,21 @@ class HelperClass:
|
|||
def __repr__(self):
|
||||
return "HelperClass" + str(self.x)
|
||||
```
|
||||
"""
|
||||
expected_hashing_context = f"""
|
||||
```python:{file_path.relative_to(opt.args.project_root)}
|
||||
class MyClass:
|
||||
def target_method(self):
|
||||
y = HelperClass().helper_method()
|
||||
|
||||
class HelperClass:
|
||||
def helper_method(self):
|
||||
return self.x
|
||||
```
|
||||
"""
|
||||
assert read_write_context.strip() == expected_read_write_context.strip()
|
||||
assert read_only_context.strip() == expected_read_only_context.strip()
|
||||
assert hashing_context.strip() == expected_hashing_context.strip()
|
||||
|
||||
|
||||
def test_example_class_token_limit_2() -> None:
|
||||
|
|
@ -752,6 +906,7 @@ class HelperClass:
|
|||
|
||||
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root, 8000, 100000)
|
||||
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
|
||||
hashing_context = code_ctx.hashing_code_context
|
||||
# In this scenario, the read-only code context is too long even after removing docstrings, hence we remove it completely.
|
||||
expected_read_write_context = """
|
||||
class MyClass:
|
||||
|
|
@ -769,8 +924,20 @@ class HelperClass:
|
|||
return self.x
|
||||
"""
|
||||
expected_read_only_context = ""
|
||||
expected_hashing_context = f"""
|
||||
```python:{file_path.relative_to(opt.args.project_root)}
|
||||
class MyClass:
|
||||
def target_method(self):
|
||||
y = HelperClass().helper_method()
|
||||
|
||||
class HelperClass:
|
||||
def helper_method(self):
|
||||
return self.x
|
||||
```
|
||||
"""
|
||||
assert read_write_context.strip() == expected_read_write_context.strip()
|
||||
assert read_only_context.strip() == expected_read_only_context.strip()
|
||||
assert hashing_context.strip() == expected_hashing_context.strip()
|
||||
|
||||
|
||||
def test_example_class_token_limit_3() -> None:
|
||||
|
|
@ -823,6 +990,7 @@ class HelperClass:
|
|||
with pytest.raises(ValueError, match="Read-writable code has exceeded token limit, cannot proceed"):
|
||||
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
|
||||
|
||||
|
||||
def test_example_class_token_limit_4() -> None:
|
||||
string_filler = " ".join(
|
||||
["This is a long string that will be used to fill up the token limit." for _ in range(1000)]
|
||||
|
|
@ -875,6 +1043,7 @@ class HelperClass:
|
|||
with pytest.raises(ValueError, match="Testgen code context has exceeded token limit, cannot proceed"):
|
||||
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
|
||||
|
||||
|
||||
def test_repo_helper() -> None:
|
||||
project_root = Path(__file__).resolve().parent.parent / "code_to_optimize" / "code_directories" / "retriever"
|
||||
path_to_file = project_root / "main.py"
|
||||
|
|
@ -889,6 +1058,7 @@ def test_repo_helper() -> None:
|
|||
|
||||
code_ctx = get_code_optimization_context(function_to_optimize, project_root)
|
||||
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
|
||||
hashing_context = code_ctx.hashing_code_context
|
||||
expected_read_write_context = """
|
||||
import math
|
||||
import requests
|
||||
|
|
@ -938,9 +1108,38 @@ class DataProcessor:
|
|||
\"\"\"Return a string representation of the DataProcessor.\"\"\"
|
||||
return f"DataProcessor(default_prefix={{self.default_prefix!r}})"
|
||||
```
|
||||
"""
|
||||
expected_hashing_context = f"""
|
||||
```python:{path_to_utils.relative_to(project_root)}
|
||||
class DataProcessor:
|
||||
|
||||
def process_data(self, raw_data: str) -> str:
|
||||
\"\"\"Process raw data by converting it to uppercase.\"\"\"
|
||||
return raw_data.upper()
|
||||
|
||||
def add_prefix(self, data: str, prefix: str = "PREFIX_") -> str:
|
||||
\"\"\"Add a prefix to the processed data.\"\"\"
|
||||
return prefix + data
|
||||
```
|
||||
```python:{path_to_file.relative_to(project_root)}
|
||||
def fetch_and_process_data():
|
||||
# Use the global variable for the request
|
||||
response = requests.get(API_URL)
|
||||
response.raise_for_status()
|
||||
|
||||
raw_data = response.text
|
||||
|
||||
# Use code from another file (utils.py)
|
||||
processor = DataProcessor()
|
||||
processed = processor.process_data(raw_data)
|
||||
processed = processor.add_prefix(processed)
|
||||
|
||||
return processed
|
||||
```
|
||||
"""
|
||||
assert read_write_context.strip() == expected_read_write_context.strip()
|
||||
assert read_only_context.strip() == expected_read_only_context.strip()
|
||||
assert hashing_context.strip() == expected_hashing_context.strip()
|
||||
|
||||
|
||||
def test_repo_helper_of_helper() -> None:
|
||||
|
|
@ -958,6 +1157,7 @@ def test_repo_helper_of_helper() -> None:
|
|||
|
||||
code_ctx = get_code_optimization_context(function_to_optimize, project_root)
|
||||
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
|
||||
hashing_context = code_ctx.hashing_code_context
|
||||
expected_read_write_context = """
|
||||
import math
|
||||
from transform_utils import DataTransformer
|
||||
|
|
@ -1014,10 +1214,38 @@ class DataTransformer:
|
|||
self.data = data
|
||||
return self.data
|
||||
```
|
||||
"""
|
||||
expected_hashing_context = f"""
|
||||
```python:{path_to_utils.relative_to(project_root)}
|
||||
class DataProcessor:
|
||||
|
||||
def process_data(self, raw_data: str) -> str:
|
||||
\"\"\"Process raw data by converting it to uppercase.\"\"\"
|
||||
return raw_data.upper()
|
||||
|
||||
def transform_data(self, data: str) -> str:
|
||||
\"\"\"Transform the processed data\"\"\"
|
||||
return DataTransformer().transform(data)
|
||||
```
|
||||
```python:{path_to_file.relative_to(project_root)}
|
||||
def fetch_and_transform_data():
|
||||
# Use the global variable for the request
|
||||
response = requests.get(API_URL)
|
||||
|
||||
raw_data = response.text
|
||||
|
||||
# Use code from another file (utils.py)
|
||||
processor = DataProcessor()
|
||||
processed = processor.process_data(raw_data)
|
||||
transformed = processor.transform_data(processed)
|
||||
|
||||
return transformed
|
||||
```
|
||||
"""
|
||||
|
||||
assert read_write_context.strip() == expected_read_write_context.strip()
|
||||
assert read_only_context.strip() == expected_read_only_context.strip()
|
||||
assert hashing_context.strip() == expected_hashing_context.strip()
|
||||
|
||||
|
||||
def test_repo_helper_of_helper_same_class() -> None:
|
||||
|
|
@ -1034,6 +1262,7 @@ def test_repo_helper_of_helper_same_class() -> None:
|
|||
|
||||
code_ctx = get_code_optimization_context(function_to_optimize, project_root)
|
||||
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
|
||||
hashing_context = code_ctx.hashing_code_context
|
||||
expected_read_write_context = """
|
||||
import math
|
||||
from transform_utils import DataTransformer
|
||||
|
|
@ -1078,10 +1307,20 @@ class DataProcessor:
|
|||
return f"DataProcessor(default_prefix={{self.default_prefix!r}})"
|
||||
```
|
||||
|
||||
"""
|
||||
expected_hashing_context = f"""
|
||||
```python:{path_to_utils.relative_to(project_root)}
|
||||
class DataProcessor:
|
||||
|
||||
def transform_data_own_method(self, data: str) -> str:
|
||||
\"\"\"Transform the processed data using own method\"\"\"
|
||||
return DataTransformer().transform_using_own_method(data)
|
||||
```
|
||||
"""
|
||||
|
||||
assert read_write_context.strip() == expected_read_write_context.strip()
|
||||
assert read_only_context.strip() == expected_read_only_context.strip()
|
||||
assert hashing_context.strip() == expected_hashing_context.strip()
|
||||
|
||||
|
||||
def test_repo_helper_of_helper_same_file() -> None:
|
||||
|
|
@ -1098,6 +1337,7 @@ def test_repo_helper_of_helper_same_file() -> None:
|
|||
|
||||
code_ctx = get_code_optimization_context(function_to_optimize, project_root)
|
||||
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
|
||||
hashing_context = code_ctx.hashing_code_context
|
||||
expected_read_write_context = """
|
||||
import math
|
||||
from transform_utils import DataTransformer
|
||||
|
|
@ -1137,10 +1377,20 @@ class DataProcessor:
|
|||
\"\"\"Return a string representation of the DataProcessor.\"\"\"
|
||||
return f"DataProcessor(default_prefix={{self.default_prefix!r}})"
|
||||
```
|
||||
"""
|
||||
expected_hashing_context = f"""
|
||||
```python:{path_to_utils.relative_to(project_root)}
|
||||
class DataProcessor:
|
||||
|
||||
def transform_data_same_file_function(self, data: str) -> str:
|
||||
\"\"\"Transform the processed data using a function from the same file\"\"\"
|
||||
return DataTransformer().transform_using_same_file_function(data)
|
||||
```
|
||||
"""
|
||||
|
||||
assert read_write_context.strip() == expected_read_write_context.strip()
|
||||
assert read_only_context.strip() == expected_read_only_context.strip()
|
||||
assert hashing_context.strip() == expected_hashing_context.strip()
|
||||
|
||||
|
||||
def test_repo_helper_all_same_file() -> None:
|
||||
|
|
@ -1156,6 +1406,7 @@ def test_repo_helper_all_same_file() -> None:
|
|||
|
||||
code_ctx = get_code_optimization_context(function_to_optimize, project_root)
|
||||
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
|
||||
hashing_context = code_ctx.hashing_code_context
|
||||
expected_read_write_context = """
|
||||
class DataTransformer:
|
||||
def __init__(self):
|
||||
|
|
@ -1181,10 +1432,27 @@ class DataTransformer:
|
|||
return self.data
|
||||
```
|
||||
|
||||
"""
|
||||
expected_hashing_context = f"""
|
||||
```python:{path_to_transform_utils.relative_to(project_root)}
|
||||
class DataTransformer:
|
||||
|
||||
def transform_using_own_method(self, data):
|
||||
return self.transform(data)
|
||||
|
||||
def transform_data_all_same_file(self, data):
|
||||
new_data = update_data(data)
|
||||
return self.transform_using_own_method(new_data)
|
||||
|
||||
|
||||
def update_data(data):
|
||||
return data + " updated"
|
||||
```
|
||||
"""
|
||||
|
||||
assert read_write_context.strip() == expected_read_write_context.strip()
|
||||
assert read_only_context.strip() == expected_read_only_context.strip()
|
||||
assert hashing_context.strip() == expected_hashing_context.strip()
|
||||
|
||||
|
||||
def test_repo_helper_circular_dependency() -> None:
|
||||
|
|
@ -1201,6 +1469,7 @@ def test_repo_helper_circular_dependency() -> None:
|
|||
|
||||
code_ctx = get_code_optimization_context(function_to_optimize, project_root)
|
||||
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
|
||||
hashing_context = code_ctx.hashing_code_context
|
||||
expected_read_write_context = """
|
||||
import math
|
||||
from transform_utils import DataTransformer
|
||||
|
|
@ -1240,10 +1509,26 @@ class DataProcessor:
|
|||
return f"DataProcessor(default_prefix={{self.default_prefix!r}})"
|
||||
```
|
||||
|
||||
"""
|
||||
expected_hashing_context = f"""
|
||||
```python:utils.py
|
||||
class DataProcessor:
|
||||
|
||||
def circular_dependency(self, data: str) -> str:
|
||||
return DataTransformer().circular_dependency(data)
|
||||
```
|
||||
```python:{path_to_transform_utils.relative_to(project_root)}
|
||||
class DataTransformer:
|
||||
|
||||
def circular_dependency(self, data):
|
||||
return DataProcessor().circular_dependency(data)
|
||||
```
|
||||
"""
|
||||
|
||||
assert read_write_context.strip() == expected_read_write_context.strip()
|
||||
assert read_only_context.strip() == expected_read_only_context.strip()
|
||||
assert hashing_context.strip() == expected_hashing_context.strip()
|
||||
|
||||
|
||||
def test_indirect_init_helper() -> None:
|
||||
code = """
|
||||
|
|
@ -1282,6 +1567,7 @@ def outside_method():
|
|||
|
||||
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
|
||||
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
|
||||
hashing_context = code_ctx.hashing_code_context
|
||||
expected_read_write_context = """
|
||||
class MyClass:
|
||||
def __init__(self):
|
||||
|
|
@ -1295,9 +1581,18 @@ class MyClass:
|
|||
def outside_method():
|
||||
return 1
|
||||
```
|
||||
"""
|
||||
expected_hashing_context = f"""
|
||||
```python:{file_path.relative_to(opt.args.project_root)}
|
||||
class MyClass:
|
||||
def target_method(self):
|
||||
return self.x + self.y
|
||||
```
|
||||
"""
|
||||
assert read_write_context.strip() == expected_read_write_context.strip()
|
||||
assert read_only_context.strip() == expected_read_only_context.strip()
|
||||
assert hashing_context.strip() == expected_hashing_context.strip()
|
||||
|
||||
|
||||
def test_direct_module_import() -> None:
|
||||
project_root = Path(__file__).resolve().parent.parent / "code_to_optimize" / "code_directories" / "retriever"
|
||||
|
|
@ -1311,9 +1606,9 @@ def test_direct_module_import() -> None:
|
|||
ending_line=None,
|
||||
)
|
||||
|
||||
|
||||
code_ctx = get_code_optimization_context(function_to_optimize, project_root)
|
||||
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
|
||||
hashing_context = code_ctx.hashing_code_context
|
||||
|
||||
expected_read_only_context = """
|
||||
```python:utils.py
|
||||
|
|
@ -1336,6 +1631,26 @@ class DataProcessor:
|
|||
\"\"\"Transform the processed data\"\"\"
|
||||
return DataTransformer().transform(data)
|
||||
```"""
|
||||
expected_hashing_context = """
|
||||
```python:main.py
|
||||
def fetch_and_transform_data():
|
||||
# Use the global variable for the request
|
||||
response = requests.get(API_URL)
|
||||
|
||||
raw_data = response.text
|
||||
|
||||
# Use code from another file (utils.py)
|
||||
processor = DataProcessor()
|
||||
processed = processor.process_data(raw_data)
|
||||
transformed = processor.transform_data(processed)
|
||||
|
||||
return transformed
|
||||
```
|
||||
```python:import_test.py
|
||||
def function_to_optimize():
|
||||
return code_to_optimize.code_directories.retriever.main.fetch_and_transform_data()
|
||||
```
|
||||
"""
|
||||
expected_read_write_context = """
|
||||
import requests
|
||||
from globals import API_URL
|
||||
|
|
@ -1362,9 +1677,11 @@ def function_to_optimize():
|
|||
"""
|
||||
assert read_write_context.strip() == expected_read_write_context.strip()
|
||||
assert read_only_context.strip() == expected_read_only_context.strip()
|
||||
assert hashing_context.strip() == expected_hashing_context.strip()
|
||||
|
||||
|
||||
def test_module_import_optimization() -> None:
|
||||
main_code = '''
|
||||
main_code = """
|
||||
import utility_module
|
||||
|
||||
class Calculator:
|
||||
|
|
@ -1391,9 +1708,9 @@ class Calculator:
|
|||
return self.subtract(x, y)
|
||||
else:
|
||||
return None
|
||||
'''
|
||||
"""
|
||||
|
||||
utility_module_code = '''
|
||||
utility_module_code = """
|
||||
import sys
|
||||
import platform
|
||||
import logging
|
||||
|
|
@ -1466,7 +1783,7 @@ def get_system_details():
|
|||
"default_precision": DEFAULT_PRECISION,
|
||||
"python_version": sys.version
|
||||
}
|
||||
'''
|
||||
"""
|
||||
|
||||
# Create a temporary directory for the test
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
|
|
@ -1515,6 +1832,7 @@ def get_system_details():
|
|||
# Get the code optimization context
|
||||
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
|
||||
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
|
||||
hashing_context = code_ctx.hashing_code_context
|
||||
# The expected contexts
|
||||
expected_read_write_context = """
|
||||
import utility_module
|
||||
|
|
@ -1579,13 +1897,34 @@ def select_precision(precision, fallback_precision):
|
|||
else:
|
||||
return DEFAULT_PRECISION
|
||||
```
|
||||
"""
|
||||
expected_hashing_context = """
|
||||
```python:main_module.py
|
||||
class Calculator:
|
||||
|
||||
def add(self, a, b):
|
||||
return a + b
|
||||
|
||||
def subtract(self, a, b):
|
||||
return a - b
|
||||
|
||||
def calculate(self, operation, x, y):
|
||||
if operation == "add":
|
||||
return self.add(x, y)
|
||||
elif operation == "subtract":
|
||||
return self.subtract(x, y)
|
||||
else:
|
||||
return None
|
||||
```
|
||||
"""
|
||||
# Verify the contexts match the expected values
|
||||
assert read_write_context.strip() == expected_read_write_context.strip()
|
||||
assert read_only_context.strip() == expected_read_only_context.strip()
|
||||
assert hashing_context.strip() == expected_hashing_context.strip()
|
||||
|
||||
|
||||
def test_module_import_init_fto() -> None:
|
||||
main_code = '''
|
||||
main_code = """
|
||||
import utility_module
|
||||
|
||||
class Calculator:
|
||||
|
|
@ -1612,9 +1951,9 @@ class Calculator:
|
|||
return self.subtract(x, y)
|
||||
else:
|
||||
return None
|
||||
'''
|
||||
"""
|
||||
|
||||
utility_module_code = '''
|
||||
utility_module_code = """
|
||||
import sys
|
||||
import platform
|
||||
import logging
|
||||
|
|
@ -1687,7 +2026,7 @@ def get_system_details():
|
|||
"default_precision": DEFAULT_PRECISION,
|
||||
"python_version": sys.version
|
||||
}
|
||||
'''
|
||||
"""
|
||||
|
||||
# Create a temporary directory for the test
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
|
|
@ -1791,4 +2130,4 @@ except ImportError:
|
|||
```
|
||||
"""
|
||||
assert read_write_context.strip() == expected_read_write_context.strip()
|
||||
assert read_only_context.strip() == expected_read_only_context.strip()
|
||||
assert read_only_context.strip() == expected_read_only_context.strip()
|
||||
|
|
|
|||
Loading…
Reference in a new issue