fix bugs with docstring removal

This commit is contained in:
Saurabh Misra 2025-06-08 00:30:47 -07:00
parent b48ed5c89d
commit 9e14cfe7a0
2 changed files with 356 additions and 18 deletions

View file

@ -613,7 +613,7 @@ def prune_cst_for_code_hashing( # noqa: PLR0911
if isinstance(node, cst.FunctionDef):
qualified_name = f"{prefix}.{node.name.value}" if prefix else node.name.value
if qualified_name in target_functions:
new_body = remove_docstring_from_body(node.body)
new_body = remove_docstring_from_body(node.body) if isinstance(node.body, cst.IndentedBlock) else node.body
return node.with_changes(body=new_body), True
return None, False
@ -632,14 +632,13 @@ def prune_cst_for_code_hashing( # noqa: PLR0911
if isinstance(stmt, cst.FunctionDef):
qualified_name = f"{class_prefix}.{stmt.name.value}"
if qualified_name in target_functions:
new_body.append(stmt)
stmt_with_changes = stmt.with_changes(body=remove_docstring_from_body(stmt.body))
new_body.append(stmt_with_changes)
found_target = True
# If no target functions found, remove the class entirely
if not new_body or not found_target:
return None, False
return node.with_changes(
body=remove_docstring_from_body(node.body.with_changes(body=new_body))
) if new_body else None, True
return node.with_changes(body=cst.IndentedBlock(new_body)) if new_body else None, found_target
# For other nodes, we preserve them only if they contain target functions in their children.
section_names = get_section_names(node)

View file

@ -6,7 +6,6 @@ from collections import defaultdict
from pathlib import Path
import pytest
from codeflash.context.code_context_extractor import get_code_optimization_context
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import FunctionParent
@ -30,6 +29,7 @@ class HelperClass:
def nested_method(self):
return self.name
def main_method():
return "hello"
@ -81,8 +81,9 @@ def test_code_replacement10() -> None:
code_ctx = get_code_optimization_context(function_to_optimize=func_top_optimize, project_root_path=file_path.parent)
qualified_names = {func.qualified_name for func in code_ctx.helper_functions}
assert qualified_names == {"HelperClass.helper_method"} # Nested method should not be in here
assert qualified_names == {"HelperClass.helper_method"} # Nested method should not be in here
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
hashing_context = code_ctx.hashing_code_context
expected_read_write_context = """
from __future__ import annotations
@ -106,8 +107,26 @@ class MainClass:
expected_read_only_context = """
"""
expected_hashing_context = f"""
```python:{file_path.relative_to(file_path.parent)}
class HelperClass:
def helper_method(self):
return self.name
class MainClass:
def main_method(self):
self.name = HelperClass.NestedClass("test").nested_method()
return HelperClass(self.name).helper_method()
```
"""
assert read_write_context.strip() == expected_read_write_context.strip()
assert read_only_context.strip() == expected_read_only_context.strip()
assert hashing_context.strip() == expected_hashing_context.strip()
def test_class_method_dependencies() -> None:
file_path = Path(__file__).resolve()
@ -122,6 +141,8 @@ def test_class_method_dependencies() -> None:
code_ctx = get_code_optimization_context(function_to_optimize, file_path.parent.resolve())
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
hashing_context = code_ctx.hashing_code_context
expected_read_write_context = """
from __future__ import annotations
from collections import defaultdict
@ -153,8 +174,36 @@ class Graph:
"""
expected_read_only_context = ""
expected_hashing_context = f"""
```python:{file_path.relative_to(file_path.parent.resolve())}
class Graph:
def topologicalSortUtil(self, v, visited, stack):
visited[v] = True
for i in self.graph[v]:
if visited[i] == False:
self.topologicalSortUtil(i, visited, stack)
stack.insert(0, v)
def topologicalSort(self):
visited = [False] * self.V
stack = []
for i in range(self.V):
if visited[i] == False:
self.topologicalSortUtil(i, visited, stack)
# Print contents of stack
return stack
```
"""
assert read_write_context.strip() == expected_read_write_context.strip()
assert read_only_context.strip() == expected_read_only_context.strip()
assert hashing_context.strip() == expected_hashing_context.strip()
def test_bubble_sort_helper() -> None:
@ -176,6 +225,7 @@ def test_bubble_sort_helper() -> None:
code_ctx = get_code_optimization_context(function_to_optimize, Path(__file__).resolve().parent.parent)
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
hashing_context = code_ctx.hashing_code_context
expected_read_write_context = """
import math
@ -196,8 +246,24 @@ def sort_from_another_file(arr):
"""
expected_read_only_context = ""
expected_hashing_context = """
```python:code_to_optimize/code_directories/retriever/bubble_sort_with_math.py
def sorter(arr):
arr.sort()
x = math.sqrt(2)
print(x)
return arr
```
```python:code_to_optimize/code_directories/retriever/bubble_sort_imported.py
def sort_from_another_file(arr):
sorted_arr = sorter(arr)
return sorted_arr
```
"""
assert read_write_context.strip() == expected_read_write_context.strip()
assert read_only_context.strip() == expected_read_only_context.strip()
assert hashing_context.strip() == expected_hashing_context.strip()
def test_flavio_typed_code_helper() -> None:
@ -366,7 +432,7 @@ class _PersistentCache(Generic[_P, _R, _CacheBackendT]):
lifespan=self.__duration__,
)
'''
with tempfile.NamedTemporaryFile(mode="w") as f:
with tempfile.NamedTemporaryFile(mode="w", suffix=".py") as f:
f.write(code)
f.flush()
file_path = Path(f.name).resolve()
@ -391,6 +457,7 @@ class _PersistentCache(Generic[_P, _R, _CacheBackendT]):
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
hashing_context = code_ctx.hashing_code_context
expected_read_write_context = """
class AbstractCacheBackend(CacheBackend, Protocol[_KEY_T, _STORE_T]):
@ -543,8 +610,67 @@ class _PersistentCache(Generic[_P, _R, _CacheBackendT]):
__backend__: _CacheBackendT
```
'''
expected_hashing_context = f"""
```python:{file_path.relative_to(opt.args.project_root)}
class AbstractCacheBackend(CacheBackend, Protocol[_KEY_T, _STORE_T]):
def get_cache_or_call(
self,
*,
func: Callable[_P, Any],
args: tuple[Any, ...],
kwargs: dict[str, Any],
lifespan: datetime.timedelta,
) -> Any: # noqa: ANN401
if os.environ.get("NO_CACHE"):
return func(*args, **kwargs)
try:
key = self.hash_key(func=func, args=args, kwargs=kwargs)
except: # noqa: E722
# If we can't create a cache key, we should just call the function.
logging.warning("Failed to hash cache key for function: %s", func)
return func(*args, **kwargs)
result_pair = self.get(key=key)
if result_pair is not None:
cached_time, result = result_pair
if not os.environ.get("RE_CACHE") and (
datetime.datetime.now() < (cached_time + lifespan) # noqa: DTZ005
):
try:
return self.decode(data=result)
except CacheBackendDecodeError as e:
logging.warning("Failed to decode cache data: %s", e)
# If decoding fails we will treat this as a cache miss.
# This might happens if underlying class definition of the data changes.
self.delete(key=key)
result = func(*args, **kwargs)
try:
self.put(key=key, data=self.encode(data=result))
except CacheBackendEncodeError as e:
logging.warning("Failed to encode cache data: %s", e)
# If encoding fails, we should still return the result.
return result
class _PersistentCache(Generic[_P, _R, _CacheBackendT]):
def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R:
if "NO_CACHE" in os.environ:
return self.__wrapped__(*args, **kwargs)
os.makedirs(DEFAULT_CACHE_LOCATION, exist_ok=True)
return self.__backend__.get_cache_or_call(
func=self.__wrapped__,
args=args,
kwargs=kwargs,
lifespan=self.__duration__,
)
```
"""
assert read_write_context.strip() == expected_read_write_context.strip()
assert read_only_context.strip() == expected_read_only_context.strip()
assert hashing_context.strip() == expected_hashing_context.strip()
def test_example_class() -> None:
@ -592,6 +718,8 @@ class HelperClass:
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
hashing_context = code_ctx.hashing_code_context
expected_read_write_context = """
class MyClass:
def __init__(self):
@ -618,8 +746,21 @@ class HelperClass:
return "HelperClass" + str(self.x)
```
"""
expected_hashing_context = f"""
```python:{file_path.relative_to(opt.args.project_root)}
class MyClass:
def target_method(self):
y = HelperClass().helper_method()
class HelperClass:
def helper_method(self):
return self.x
```
"""
assert read_write_context.strip() == expected_read_write_context.strip()
assert read_only_context.strip() == expected_read_only_context.strip()
assert hashing_context.strip() == expected_hashing_context.strip()
def test_example_class_token_limit_1() -> None:
@ -672,6 +813,7 @@ class HelperClass:
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
hashing_context = code_ctx.hashing_code_context
# In this scenario, the read-only code context is too long, so the read-only docstrings are removed.
expected_read_write_context = """
class MyClass:
@ -697,9 +839,21 @@ class HelperClass:
def __repr__(self):
return "HelperClass" + str(self.x)
```
"""
expected_hashing_context = f"""
```python:{file_path.relative_to(opt.args.project_root)}
class MyClass:
def target_method(self):
y = HelperClass().helper_method()
class HelperClass:
def helper_method(self):
return self.x
```
"""
assert read_write_context.strip() == expected_read_write_context.strip()
assert read_only_context.strip() == expected_read_only_context.strip()
assert hashing_context.strip() == expected_hashing_context.strip()
def test_example_class_token_limit_2() -> None:
@ -752,6 +906,7 @@ class HelperClass:
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root, 8000, 100000)
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
hashing_context = code_ctx.hashing_code_context
# In this scenario, the read-only code context is too long even after removing docstrings, hence we remove it completely.
expected_read_write_context = """
class MyClass:
@ -769,8 +924,20 @@ class HelperClass:
return self.x
"""
expected_read_only_context = ""
expected_hashing_context = f"""
```python:{file_path.relative_to(opt.args.project_root)}
class MyClass:
def target_method(self):
y = HelperClass().helper_method()
class HelperClass:
def helper_method(self):
return self.x
```
"""
assert read_write_context.strip() == expected_read_write_context.strip()
assert read_only_context.strip() == expected_read_only_context.strip()
assert hashing_context.strip() == expected_hashing_context.strip()
def test_example_class_token_limit_3() -> None:
@ -823,6 +990,7 @@ class HelperClass:
with pytest.raises(ValueError, match="Read-writable code has exceeded token limit, cannot proceed"):
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
def test_example_class_token_limit_4() -> None:
string_filler = " ".join(
["This is a long string that will be used to fill up the token limit." for _ in range(1000)]
@ -875,6 +1043,7 @@ class HelperClass:
with pytest.raises(ValueError, match="Testgen code context has exceeded token limit, cannot proceed"):
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
def test_repo_helper() -> None:
project_root = Path(__file__).resolve().parent.parent / "code_to_optimize" / "code_directories" / "retriever"
path_to_file = project_root / "main.py"
@ -889,6 +1058,7 @@ def test_repo_helper() -> None:
code_ctx = get_code_optimization_context(function_to_optimize, project_root)
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
hashing_context = code_ctx.hashing_code_context
expected_read_write_context = """
import math
import requests
@ -938,9 +1108,38 @@ class DataProcessor:
\"\"\"Return a string representation of the DataProcessor.\"\"\"
return f"DataProcessor(default_prefix={{self.default_prefix!r}})"
```
"""
expected_hashing_context = f"""
```python:{path_to_utils.relative_to(project_root)}
class DataProcessor:
def process_data(self, raw_data: str) -> str:
\"\"\"Process raw data by converting it to uppercase.\"\"\"
return raw_data.upper()
def add_prefix(self, data: str, prefix: str = "PREFIX_") -> str:
\"\"\"Add a prefix to the processed data.\"\"\"
return prefix + data
```
```python:{path_to_file.relative_to(project_root)}
def fetch_and_process_data():
# Use the global variable for the request
response = requests.get(API_URL)
response.raise_for_status()
raw_data = response.text
# Use code from another file (utils.py)
processor = DataProcessor()
processed = processor.process_data(raw_data)
processed = processor.add_prefix(processed)
return processed
```
"""
assert read_write_context.strip() == expected_read_write_context.strip()
assert read_only_context.strip() == expected_read_only_context.strip()
assert hashing_context.strip() == expected_hashing_context.strip()
def test_repo_helper_of_helper() -> None:
@ -958,6 +1157,7 @@ def test_repo_helper_of_helper() -> None:
code_ctx = get_code_optimization_context(function_to_optimize, project_root)
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
hashing_context = code_ctx.hashing_code_context
expected_read_write_context = """
import math
from transform_utils import DataTransformer
@ -1014,10 +1214,38 @@ class DataTransformer:
self.data = data
return self.data
```
"""
expected_hashing_context = f"""
```python:{path_to_utils.relative_to(project_root)}
class DataProcessor:
def process_data(self, raw_data: str) -> str:
\"\"\"Process raw data by converting it to uppercase.\"\"\"
return raw_data.upper()
def transform_data(self, data: str) -> str:
\"\"\"Transform the processed data\"\"\"
return DataTransformer().transform(data)
```
```python:{path_to_file.relative_to(project_root)}
def fetch_and_transform_data():
# Use the global variable for the request
response = requests.get(API_URL)
raw_data = response.text
# Use code from another file (utils.py)
processor = DataProcessor()
processed = processor.process_data(raw_data)
transformed = processor.transform_data(processed)
return transformed
```
"""
assert read_write_context.strip() == expected_read_write_context.strip()
assert read_only_context.strip() == expected_read_only_context.strip()
assert hashing_context.strip() == expected_hashing_context.strip()
def test_repo_helper_of_helper_same_class() -> None:
@ -1034,6 +1262,7 @@ def test_repo_helper_of_helper_same_class() -> None:
code_ctx = get_code_optimization_context(function_to_optimize, project_root)
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
hashing_context = code_ctx.hashing_code_context
expected_read_write_context = """
import math
from transform_utils import DataTransformer
@ -1078,10 +1307,20 @@ class DataProcessor:
return f"DataProcessor(default_prefix={{self.default_prefix!r}})"
```
"""
expected_hashing_context = f"""
```python:{path_to_utils.relative_to(project_root)}
class DataProcessor:
def transform_data_own_method(self, data: str) -> str:
\"\"\"Transform the processed data using own method\"\"\"
return DataTransformer().transform_using_own_method(data)
```
"""
assert read_write_context.strip() == expected_read_write_context.strip()
assert read_only_context.strip() == expected_read_only_context.strip()
assert hashing_context.strip() == expected_hashing_context.strip()
def test_repo_helper_of_helper_same_file() -> None:
@ -1098,6 +1337,7 @@ def test_repo_helper_of_helper_same_file() -> None:
code_ctx = get_code_optimization_context(function_to_optimize, project_root)
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
hashing_context = code_ctx.hashing_code_context
expected_read_write_context = """
import math
from transform_utils import DataTransformer
@ -1137,10 +1377,20 @@ class DataProcessor:
\"\"\"Return a string representation of the DataProcessor.\"\"\"
return f"DataProcessor(default_prefix={{self.default_prefix!r}})"
```
"""
expected_hashing_context = f"""
```python:{path_to_utils.relative_to(project_root)}
class DataProcessor:
def transform_data_same_file_function(self, data: str) -> str:
\"\"\"Transform the processed data using a function from the same file\"\"\"
return DataTransformer().transform_using_same_file_function(data)
```
"""
assert read_write_context.strip() == expected_read_write_context.strip()
assert read_only_context.strip() == expected_read_only_context.strip()
assert hashing_context.strip() == expected_hashing_context.strip()
def test_repo_helper_all_same_file() -> None:
@ -1156,6 +1406,7 @@ def test_repo_helper_all_same_file() -> None:
code_ctx = get_code_optimization_context(function_to_optimize, project_root)
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
hashing_context = code_ctx.hashing_code_context
expected_read_write_context = """
class DataTransformer:
def __init__(self):
@ -1181,10 +1432,27 @@ class DataTransformer:
return self.data
```
"""
expected_hashing_context = f"""
```python:{path_to_transform_utils.relative_to(project_root)}
class DataTransformer:
def transform_using_own_method(self, data):
return self.transform(data)
def transform_data_all_same_file(self, data):
new_data = update_data(data)
return self.transform_using_own_method(new_data)
def update_data(data):
return data + " updated"
```
"""
assert read_write_context.strip() == expected_read_write_context.strip()
assert read_only_context.strip() == expected_read_only_context.strip()
assert hashing_context.strip() == expected_hashing_context.strip()
def test_repo_helper_circular_dependency() -> None:
@ -1201,6 +1469,7 @@ def test_repo_helper_circular_dependency() -> None:
code_ctx = get_code_optimization_context(function_to_optimize, project_root)
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
hashing_context = code_ctx.hashing_code_context
expected_read_write_context = """
import math
from transform_utils import DataTransformer
@ -1240,10 +1509,26 @@ class DataProcessor:
return f"DataProcessor(default_prefix={{self.default_prefix!r}})"
```
"""
expected_hashing_context = f"""
```python:utils.py
class DataProcessor:
def circular_dependency(self, data: str) -> str:
return DataTransformer().circular_dependency(data)
```
```python:{path_to_transform_utils.relative_to(project_root)}
class DataTransformer:
def circular_dependency(self, data):
return DataProcessor().circular_dependency(data)
```
"""
assert read_write_context.strip() == expected_read_write_context.strip()
assert read_only_context.strip() == expected_read_only_context.strip()
assert hashing_context.strip() == expected_hashing_context.strip()
def test_indirect_init_helper() -> None:
code = """
@ -1282,6 +1567,7 @@ def outside_method():
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
hashing_context = code_ctx.hashing_code_context
expected_read_write_context = """
class MyClass:
def __init__(self):
@ -1295,9 +1581,18 @@ class MyClass:
def outside_method():
return 1
```
"""
expected_hashing_context = f"""
```python:{file_path.relative_to(opt.args.project_root)}
class MyClass:
def target_method(self):
return self.x + self.y
```
"""
assert read_write_context.strip() == expected_read_write_context.strip()
assert read_only_context.strip() == expected_read_only_context.strip()
assert hashing_context.strip() == expected_hashing_context.strip()
def test_direct_module_import() -> None:
project_root = Path(__file__).resolve().parent.parent / "code_to_optimize" / "code_directories" / "retriever"
@ -1311,9 +1606,9 @@ def test_direct_module_import() -> None:
ending_line=None,
)
code_ctx = get_code_optimization_context(function_to_optimize, project_root)
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
hashing_context = code_ctx.hashing_code_context
expected_read_only_context = """
```python:utils.py
@ -1336,6 +1631,26 @@ class DataProcessor:
\"\"\"Transform the processed data\"\"\"
return DataTransformer().transform(data)
```"""
expected_hashing_context = """
```python:main.py
def fetch_and_transform_data():
# Use the global variable for the request
response = requests.get(API_URL)
raw_data = response.text
# Use code from another file (utils.py)
processor = DataProcessor()
processed = processor.process_data(raw_data)
transformed = processor.transform_data(processed)
return transformed
```
```python:import_test.py
def function_to_optimize():
return code_to_optimize.code_directories.retriever.main.fetch_and_transform_data()
```
"""
expected_read_write_context = """
import requests
from globals import API_URL
@ -1362,9 +1677,11 @@ def function_to_optimize():
"""
assert read_write_context.strip() == expected_read_write_context.strip()
assert read_only_context.strip() == expected_read_only_context.strip()
assert hashing_context.strip() == expected_hashing_context.strip()
def test_module_import_optimization() -> None:
main_code = '''
main_code = """
import utility_module
class Calculator:
@ -1391,9 +1708,9 @@ class Calculator:
return self.subtract(x, y)
else:
return None
'''
"""
utility_module_code = '''
utility_module_code = """
import sys
import platform
import logging
@ -1466,7 +1783,7 @@ def get_system_details():
"default_precision": DEFAULT_PRECISION,
"python_version": sys.version
}
'''
"""
# Create a temporary directory for the test
with tempfile.TemporaryDirectory() as temp_dir:
@ -1515,6 +1832,7 @@ def get_system_details():
# Get the code optimization context
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
hashing_context = code_ctx.hashing_code_context
# The expected contexts
expected_read_write_context = """
import utility_module
@ -1579,13 +1897,34 @@ def select_precision(precision, fallback_precision):
else:
return DEFAULT_PRECISION
```
"""
expected_hashing_context = """
```python:main_module.py
class Calculator:
def add(self, a, b):
return a + b
def subtract(self, a, b):
return a - b
def calculate(self, operation, x, y):
if operation == "add":
return self.add(x, y)
elif operation == "subtract":
return self.subtract(x, y)
else:
return None
```
"""
# Verify the contexts match the expected values
assert read_write_context.strip() == expected_read_write_context.strip()
assert read_only_context.strip() == expected_read_only_context.strip()
assert hashing_context.strip() == expected_hashing_context.strip()
def test_module_import_init_fto() -> None:
main_code = '''
main_code = """
import utility_module
class Calculator:
@ -1612,9 +1951,9 @@ class Calculator:
return self.subtract(x, y)
else:
return None
'''
"""
utility_module_code = '''
utility_module_code = """
import sys
import platform
import logging
@ -1687,7 +2026,7 @@ def get_system_details():
"default_precision": DEFAULT_PRECISION,
"python_version": sys.version
}
'''
"""
# Create a temporary directory for the test
with tempfile.TemporaryDirectory() as temp_dir:
@ -1791,4 +2130,4 @@ except ImportError:
```
"""
assert read_write_context.strip() == expected_read_write_context.strip()
assert read_only_context.strip() == expected_read_only_context.strip()
assert read_only_context.strip() == expected_read_only_context.strip()