added docstring removal

This commit is contained in:
Alvin Ryanputra 2024-12-26 14:06:05 -08:00
parent 693c150262
commit e10f13a83a
5 changed files with 429 additions and 22 deletions

View file

@ -6,6 +6,7 @@ from pathlib import Path
import jedi
import libcst as cst
import tiktoken
from jedi.api.classes import Name
from codeflash.cli_cmds.console import logger
@ -106,20 +107,49 @@ def get_code_optimization_context(
)
if read_only_code_with_imports.code:
read_only_code_markdown.code_strings.append(read_only_code_with_imports)
# final_read_writable_codestring = CodeString(code=final_read_writable_code)
# tokenizer = tiktoken.encoding_for_model("gpt-4o")
# final_read_writable_tokens = len(tokenizer.encode(final_read_writable_code))
# if final_read_writable_tokens > token_limit:
# logger.debug(
# "Read writable code exceeded token limit, removing helper functions and only keeping function to optimize"
# )
# try:
# read_writable_code = get_read_writable_code(og_code_containing_helpers, qualified_function_names)
# except ValueError as e:
# logger.debug(f"Error while getting read-writable code: {e}")
# continue
tokenizer = tiktoken.encoding_for_model("gpt-4o")
final_read_writable_tokens = len(tokenizer.encode(final_read_writable_code))
print("final_read_writable_tokens", final_read_writable_tokens)
if final_read_writable_tokens > token_limit:
raise ValueError("Read-writable code has exceeded token limit, cannot proceed")
read_only_code_markdown_tokens = len(tokenizer.encode(read_only_code_markdown.markdown))
total_tokens = final_read_writable_tokens + read_only_code_markdown_tokens
print("total_tokens", total_tokens)
print(read_only_code_markdown.markdown)
return CodeString(code=final_read_writable_code).code, read_only_code_markdown.markdown
if total_tokens <= token_limit:
return CodeString(code=final_read_writable_code).code, read_only_code_markdown.markdown
logger.debug("Code context has exceeded token limit, removing docstrings from read-only code")
for file_path, qualified_function_names in file_path_to_qualified_function_names.items():
try:
read_only_code = get_read_only_code(og_code_containing_helpers, qualified_function_names)
except ValueError as e:
logger.debug(f"Error while getting read-only code: {e}")
continue
read_only_code_with_imports = CodeString(
code=add_needed_imports_from_module(
src_module_code=og_code_containing_helpers,
dst_module_code=read_only_code,
src_path=file_path,
dst_path=file_path,
project_root=project_root_path,
helper_functions_fqn=qualified_function_names,
),
file_path=Path(file_path),
)
if read_only_code_with_imports.code:
read_only_code_markdown.code_strings.append(read_only_code_with_imports)
read_only_code_markdown_tokens = len(tokenizer.encode(read_only_code_markdown.markdown))
total_tokens = final_read_writable_tokens + read_only_code_markdown_tokens
print("total_tokens after removal", total_tokens)
if total_tokens <= token_limit:
return CodeString(code=final_read_writable_code).code, read_only_code_markdown.markdown
logger.debug("Code context has exceeded token limit, removing read-only code")
return CodeString(code=final_read_writable_code).code, ""
def is_dunder_method(name: str) -> bool:
@ -132,6 +162,17 @@ def get_section_names(node: cst.CSTNode) -> list[str]:
return [sec for sec in possible_sections if hasattr(node, sec)]
def remove_docstring_from_body(indented_block: cst.IndentedBlock) -> cst.CSTNode:
"""Removes the docstring from a body of statements if present."""
print(indented_block)
if not isinstance(indented_block.body[0], cst.SimpleStatementLine):
return indented_block
first_stmt = indented_block.body[0].body[0]
if isinstance(first_stmt, cst.Expr) and isinstance(first_stmt.value, cst.SimpleString):
return indented_block.with_changes(body=indented_block.body[1:])
return indented_block
def prune_cst_for_read_writable_code(
node: cst.CSTNode, target_functions: set[str], prefix: str = ""
) -> tuple[cst.CSTNode | None, bool]:
@ -238,6 +279,9 @@ def prune_cst_for_read_only_code(
return None, True
# Keep only dunder methods
if is_dunder_method(node.name.value):
if remove_docstrings and isinstance(node.body, cst.IndentedBlock):
new_body = remove_docstring_from_body(node.body)
return node.with_changes(body=new_body), False
return node, False
return None, False
@ -256,7 +300,7 @@ def prune_cst_for_read_only_code(
new_body = []
for stmt in node.body.body:
filtered, found_target = prune_cst_for_read_only_code(
stmt, target_functions, class_prefix, remove_docstrings
stmt, target_functions, class_prefix, remove_docstrings=remove_docstrings
)
found_in_class |= found_target
@ -271,6 +315,10 @@ def prune_cst_for_read_only_code(
if not found_in_class:
return None, False
if remove_docstrings:
return node.with_changes(
body=remove_docstring_from_body(node.body.with_changes(body=new_body))
) if new_body else None, True
return node.with_changes(body=node.body.with_changes(body=new_body)) if new_body else None, True
# For other nodes, keep the node and recursively filter children
@ -288,7 +336,7 @@ def prune_cst_for_read_only_code(
section_found_target = False
for child in original_content:
filtered, found_target = prune_cst_for_read_only_code(
child, target_functions, prefix, remove_docstrings
child, target_functions, prefix, remove_docstrings=remove_docstrings
)
if filtered:
new_children.append(filtered)
@ -299,7 +347,7 @@ def prune_cst_for_read_only_code(
updates[section] = new_children
elif original_content is not None:
filtered, found_target = prune_cst_for_read_only_code(
original_content, target_functions, prefix, remove_docstrings
original_content, target_functions, prefix, remove_docstrings=remove_docstrings
)
found_any_target |= found_target
if filtered:
@ -316,7 +364,9 @@ def get_read_only_code(code: str, target_functions: set[str], remove_docstrings:
class contextual information, and other module scoped variables.
"""
module = cst.parse_module(code)
filtered_node, found_target = prune_cst_for_read_only_code(module, target_functions, remove_docstrings)
filtered_node, found_target = prune_cst_for_read_only_code(
module, target_functions, remove_docstrings=remove_docstrings
)
if not found_target:
raise ValueError("No target functions found in the provided code")
if filtered_node and isinstance(filtered_node, cst.Module):

View file

@ -715,13 +715,17 @@ class Optimizer:
contextual_dunder_methods.update(helper_dunder_methods)
# Will eventually refactor to use this function instead of the above
read_writable_code, read_only_context_code = code_context_extractor.get_code_optimization_context(
function_to_optimize, project_root
)
try:
read_writable_code, read_only_context_code = code_context_extractor.get_code_optimization_context(
function_to_optimize, project_root
)
except ValueError as e:
return Failure(str(e))
logger.info("Read-writable code:")
code_print(read_writable_code)
logger.info("Read-only context code:")
# code_print(read_only_context_code)
code_print(read_only_context_code)
return Success(
CodeOptimizationContext(
code_to_optimize_with_helpers=code_to_optimize_with_helpers_and_imports,

View file

@ -6,10 +6,10 @@ from collections import defaultdict
from pathlib import Path
from textwrap import dedent
from codeflash.context.code_context_extractor import get_code_optimization_context
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import FunctionParent
from codeflash.optimization.optimizer import Optimizer
from codeflash.context.code_context_extractor import get_code_optimization_context
class HelperClass:
@ -609,3 +609,159 @@ class _PersistentCache(Generic[_P, _R, _CacheBackendT]):
'''
assert read_write_context.strip() == dedent(expected_read_write_context).strip()
assert read_only_context.strip() == dedent(expected_read_only_context).strip()
def test_example_class() -> None:
code = """
class MyClass:
\"\"\"A class with a helper method.\"\"\"
def __init__(self):
self.x = 1
def target_method(self):
y = HelperClass().helper_method()
class HelperClass:
\"\"\"A helper class for MyClass.\"\"\"
def __init__(self):
\"\"\"Initialize the HelperClass.\"\"\"
self.x = 1
def __repr__(self):
\"\"\"Return a string representation of the HelperClass.\"\"\"
return "HelperClass" + str(self.x)
def helper_method(self):
return self.x
"""
with tempfile.NamedTemporaryFile(mode="w") as f:
f.write(code)
f.flush()
file_path = Path(f.name).resolve()
opt = Optimizer(
Namespace(
project_root=file_path.parent.resolve(),
disable_telemetry=True,
tests_root="tests",
test_framework="pytest",
pytest_cmd="pytest",
experiment_id=None,
test_project_root=Path().resolve(),
)
)
function_to_optimize = FunctionToOptimize(
function_name="target_method",
file_path=file_path,
parents=[FunctionParent(name="MyClass", type="ClassDef")],
starting_line=None,
ending_line=None,
)
read_write_context, read_only_context = get_code_optimization_context(
function_to_optimize, opt.args.project_root
)
expected_read_write_context = """
class MyClass:
def target_method(self):
y = HelperClass().helper_method()
class HelperClass:
def helper_method(self):
return self.x
"""
expected_read_only_context = f"""
```python:{file_path}
class MyClass:
\"\"\"A class with a helper method.\"\"\"
def __init__(self):
self.x = 1
class HelperClass:
\"\"\"A helper class for MyClass.\"\"\"
def __init__(self):
\"\"\"Initialize the HelperClass.\"\"\"
self.x = 1
def __repr__(self):
\"\"\"Return a string representation of the HelperClass.\"\"\"
return "HelperClass" + str(self.x)
```
"""
assert read_write_context.strip() == dedent(expected_read_write_context).strip()
assert read_only_context.strip() == dedent(expected_read_only_context).strip()
def test_example_class_token_limit_1() -> None:
docstring_filler = "This is a long docstring that will be used to fill up the token limit."
docstring_filler_multiplied = " ".join([docstring_filler for _ in range(1000)])
code = f"""
class MyClass:
\"\"\"A class with a helper method.
{docstring_filler_multiplied}\"\"\"
def __init__(self):
self.x = 1
def target_method(self):
y = HelperClass().helper_method()
class HelperClass:
\"\"\"A helper class for MyClass.\"\"\"
def __init__(self):
\"\"\"Initialize the HelperClass.\"\"\"
self.x = 1
def __repr__(self):
\"\"\"Return a string representation of the HelperClass.\"\"\"
return "HelperClass" + str(self.x)
def helper_method(self):
return self.x
"""
with tempfile.NamedTemporaryFile(mode="w") as f:
f.write(code)
f.flush()
file_path = Path(f.name).resolve()
opt = Optimizer(
Namespace(
project_root=file_path.parent.resolve(),
disable_telemetry=True,
tests_root="tests",
test_framework="pytest",
pytest_cmd="pytest",
experiment_id=None,
test_project_root=Path().resolve(),
)
)
function_to_optimize = FunctionToOptimize(
function_name="target_method",
file_path=file_path,
parents=[FunctionParent(name="MyClass", type="ClassDef")],
starting_line=None,
ending_line=None,
)
read_write_context, read_only_context = get_code_optimization_context(
function_to_optimize, opt.args.project_root
)
expected_read_write_context = """
class MyClass:
def target_method(self):
y = HelperClass().helper_method()
class HelperClass:
def helper_method(self):
return self.x
"""
expected_read_only_context = f"""
```python:{file_path}
class MyClass:
\"\"\"A class with a helper method.
{docstring_filler_multiplied}\"\"\"
def __init__(self):
self.x = 1
class HelperClass:
\"\"\"A helper class for MyClass.\"\"\"
def __init__(self):
\"\"\"Initialize the HelperClass.\"\"\"
self.x = 1
def __repr__(self):
\"\"\"Return a string representation of the HelperClass.\"\"\"
return "HelperClass" + str(self.x)
```
"""
assert read_write_context.strip() == dedent(expected_read_write_context).strip()
assert read_only_context.strip() == dedent(expected_read_only_context).strip()

View file

@ -51,6 +51,90 @@ def test_dunder_methods() -> None:
assert dedent(expected).strip() == output.strip()
def test_dunder_methods_remove_docstring() -> None:
code = """
class TestClass:
def __init__(self):
\"\"\"Constructor for TestClass.\"\"\"
self.x = 42
def __str__(self):
\"\"\"String representation of TestClass.\"\"\"
return f"Value: {self.x}"
def target_method(self):
print("stub me")
"""
expected = """
class TestClass:
def __init__(self):
self.x = 42
def __str__(self):
return f"Value: {self.x}"
"""
output = get_read_only_code(dedent(code), {"TestClass.target_method"}, remove_docstrings=True)
assert dedent(expected).strip() == output.strip()
def test_class_remove_docstring() -> None:
code = """
class TestClass:
\"\"\"Class docstring.\"\"\"
def __init__(self):
self.x = 42
def __str__(self):
return f"Value: {self.x}"
def target_method(self):
print("stub me")
"""
expected = """
class TestClass:
def __init__(self):
self.x = 42
def __str__(self):
return f"Value: {self.x}"
"""
output = get_read_only_code(dedent(code), {"TestClass.target_method"}, remove_docstrings=True)
assert dedent(expected).strip() == output.strip()
def test_mixed_remove_docstring() -> None:
code = """
class TestClass:
\"\"\"Class docstring.\"\"\"
def __init__(self):
self.x = 42
def __str__(self):
\"\"\"String representation of TestClass.\"\"\"
return f"Value: {self.x}"
def target_method(self):
\"\"\"target method docstring.\"\"\"
print("stub me")
"""
expected = """
class TestClass:
def __init__(self):
self.x = 42
def __str__(self):
return f"Value: {self.x}"
"""
output = get_read_only_code(dedent(code), {"TestClass.target_method"}, remove_docstrings=True)
assert dedent(expected).strip() == output.strip()
def test_target_in_nested_class() -> None:
"""Test that attempting to find a target in a nested class raises an error."""
code = """
@ -603,3 +687,114 @@ def test_simplified_complete_implementation() -> None:
output = get_read_only_code(dedent(code), {"DataProcessor.target_method", "ResultHandler.target_method"})
assert dedent(expected).strip() == output.strip()
def test_simplified_complete_implementation_no_docstring() -> None:
code = """
class DataProcessor:
\"\"\"A simple data processing class.\"\"\"
def __init__(self, data: Dict[str, Any]) -> None:
self.data = data
self._processed = False
self.result = None
def __repr__(self) -> str:
return f"DataProcessor(processed={self._processed})"
def target_method(self, key: str) -> Optional[Any]:
\"\"\"Process and retrieve a specific key from the data.\"\"\"
if not self._processed:
self._process_data()
return self.result.get(key) if self.result else None
def _process_data(self) -> None:
\"\"\"Internal method to process the data.\"\"\"
processed = {}
for key, value in self.data.items():
if isinstance(value, (int, float)):
processed[key] = value * 2
elif isinstance(value, str):
processed[key] = value.upper()
else:
processed[key] = value
self.result = processed
self._processed = True
def to_json(self) -> str:
\"\"\"Convert the processed data to JSON string.\"\"\"
if not self._processed:
self._process_data()
return json.dumps(self.result)
try:
sample_data = {"number": 42, "text": "hello"}
processor = DataProcessor(sample_data)
class ResultHandler:
def __init__(self, processor: DataProcessor):
self.processor = processor
self.cache = {}
def __str__(self) -> str:
return f"ResultHandler(cache_size={len(self.cache)})"
def target_method(self, key: str) -> Optional[Any]:
\"\"\"Retrieve and cache results for a key.\"\"\"
if key not in self.cache:
self.cache[key] = self.processor.target_method(key)
return self.cache[key]
def clear_cache(self) -> None:
\"\"\"Clear the internal cache.\"\"\"
self.cache.clear()
def get_stats(self) -> Dict[str, int]:
\"\"\"Get cache statistics.\"\"\"
return {
"cache_size": len(self.cache),
"hits": sum(1 for v in self.cache.values() if v is not None)
}
except Exception as e:
class ResultHandler:
def __init__(self):
self.error = str(e)
def target_method(self, key: str) -> None:
raise RuntimeError(f"Failed to initialize: {self.error}")
"""
expected = """
class DataProcessor:
def __init__(self, data: Dict[str, Any]) -> None:
self.data = data
self._processed = False
self.result = None
def __repr__(self) -> str:
return f"DataProcessor(processed={self._processed})"
try:
sample_data = {"number": 42, "text": "hello"}
processor = DataProcessor(sample_data)
class ResultHandler:
def __init__(self, processor: DataProcessor):
self.processor = processor
self.cache = {}
def __str__(self) -> str:
return f"ResultHandler(cache_size={len(self.cache)})"
except Exception as e:
class ResultHandler:
def __init__(self):
self.error = str(e)
"""
output = get_read_only_code(
dedent(code), {"DataProcessor.target_method", "ResultHandler.target_method"}, remove_docstrings=True
)
assert dedent(expected).strip() == output.strip()

View file

@ -221,3 +221,5 @@ def test_module_var() -> None:
output = get_read_writable_code(dedent(code), {"target_function"})
assert dedent(expected).strip() == output.strip()