added docstring removal
This commit is contained in:
parent
693c150262
commit
e10f13a83a
5 changed files with 429 additions and 22 deletions
|
|
@ -6,6 +6,7 @@ from pathlib import Path
|
|||
|
||||
import jedi
|
||||
import libcst as cst
|
||||
import tiktoken
|
||||
from jedi.api.classes import Name
|
||||
|
||||
from codeflash.cli_cmds.console import logger
|
||||
|
|
@ -106,20 +107,49 @@ def get_code_optimization_context(
|
|||
)
|
||||
if read_only_code_with_imports.code:
|
||||
read_only_code_markdown.code_strings.append(read_only_code_with_imports)
|
||||
# final_read_writable_codestring = CodeString(code=final_read_writable_code)
|
||||
# tokenizer = tiktoken.encoding_for_model("gpt-4o")
|
||||
# final_read_writable_tokens = len(tokenizer.encode(final_read_writable_code))
|
||||
# if final_read_writable_tokens > token_limit:
|
||||
# logger.debug(
|
||||
# "Read writable code exceeded token limit, removing helper functions and only keeping function to optimize"
|
||||
# )
|
||||
# try:
|
||||
# read_writable_code = get_read_writable_code(og_code_containing_helpers, qualified_function_names)
|
||||
# except ValueError as e:
|
||||
# logger.debug(f"Error while getting read-writable code: {e}")
|
||||
# continue
|
||||
|
||||
tokenizer = tiktoken.encoding_for_model("gpt-4o")
|
||||
final_read_writable_tokens = len(tokenizer.encode(final_read_writable_code))
|
||||
print("final_read_writable_tokens", final_read_writable_tokens)
|
||||
if final_read_writable_tokens > token_limit:
|
||||
raise ValueError("Read-writable code has exceeded token limit, cannot proceed")
|
||||
read_only_code_markdown_tokens = len(tokenizer.encode(read_only_code_markdown.markdown))
|
||||
|
||||
total_tokens = final_read_writable_tokens + read_only_code_markdown_tokens
|
||||
print("total_tokens", total_tokens)
|
||||
print(read_only_code_markdown.markdown)
|
||||
return CodeString(code=final_read_writable_code).code, read_only_code_markdown.markdown
|
||||
if total_tokens <= token_limit:
|
||||
return CodeString(code=final_read_writable_code).code, read_only_code_markdown.markdown
|
||||
|
||||
logger.debug("Code context has exceeded token limit, removing docstrings from read-only code")
|
||||
for file_path, qualified_function_names in file_path_to_qualified_function_names.items():
|
||||
try:
|
||||
read_only_code = get_read_only_code(og_code_containing_helpers, qualified_function_names)
|
||||
except ValueError as e:
|
||||
logger.debug(f"Error while getting read-only code: {e}")
|
||||
continue
|
||||
|
||||
read_only_code_with_imports = CodeString(
|
||||
code=add_needed_imports_from_module(
|
||||
src_module_code=og_code_containing_helpers,
|
||||
dst_module_code=read_only_code,
|
||||
src_path=file_path,
|
||||
dst_path=file_path,
|
||||
project_root=project_root_path,
|
||||
helper_functions_fqn=qualified_function_names,
|
||||
),
|
||||
file_path=Path(file_path),
|
||||
)
|
||||
if read_only_code_with_imports.code:
|
||||
read_only_code_markdown.code_strings.append(read_only_code_with_imports)
|
||||
read_only_code_markdown_tokens = len(tokenizer.encode(read_only_code_markdown.markdown))
|
||||
total_tokens = final_read_writable_tokens + read_only_code_markdown_tokens
|
||||
print("total_tokens after removal", total_tokens)
|
||||
if total_tokens <= token_limit:
|
||||
return CodeString(code=final_read_writable_code).code, read_only_code_markdown.markdown
|
||||
|
||||
logger.debug("Code context has exceeded token limit, removing read-only code")
|
||||
return CodeString(code=final_read_writable_code).code, ""
|
||||
|
||||
|
||||
def is_dunder_method(name: str) -> bool:
|
||||
|
|
@ -132,6 +162,17 @@ def get_section_names(node: cst.CSTNode) -> list[str]:
|
|||
return [sec for sec in possible_sections if hasattr(node, sec)]
|
||||
|
||||
|
||||
def remove_docstring_from_body(indented_block: cst.IndentedBlock) -> cst.CSTNode:
|
||||
"""Removes the docstring from a body of statements if present."""
|
||||
print(indented_block)
|
||||
if not isinstance(indented_block.body[0], cst.SimpleStatementLine):
|
||||
return indented_block
|
||||
first_stmt = indented_block.body[0].body[0]
|
||||
if isinstance(first_stmt, cst.Expr) and isinstance(first_stmt.value, cst.SimpleString):
|
||||
return indented_block.with_changes(body=indented_block.body[1:])
|
||||
return indented_block
|
||||
|
||||
|
||||
def prune_cst_for_read_writable_code(
|
||||
node: cst.CSTNode, target_functions: set[str], prefix: str = ""
|
||||
) -> tuple[cst.CSTNode | None, bool]:
|
||||
|
|
@ -238,6 +279,9 @@ def prune_cst_for_read_only_code(
|
|||
return None, True
|
||||
# Keep only dunder methods
|
||||
if is_dunder_method(node.name.value):
|
||||
if remove_docstrings and isinstance(node.body, cst.IndentedBlock):
|
||||
new_body = remove_docstring_from_body(node.body)
|
||||
return node.with_changes(body=new_body), False
|
||||
return node, False
|
||||
return None, False
|
||||
|
||||
|
|
@ -256,7 +300,7 @@ def prune_cst_for_read_only_code(
|
|||
new_body = []
|
||||
for stmt in node.body.body:
|
||||
filtered, found_target = prune_cst_for_read_only_code(
|
||||
stmt, target_functions, class_prefix, remove_docstrings
|
||||
stmt, target_functions, class_prefix, remove_docstrings=remove_docstrings
|
||||
)
|
||||
found_in_class |= found_target
|
||||
|
||||
|
|
@ -271,6 +315,10 @@ def prune_cst_for_read_only_code(
|
|||
if not found_in_class:
|
||||
return None, False
|
||||
|
||||
if remove_docstrings:
|
||||
return node.with_changes(
|
||||
body=remove_docstring_from_body(node.body.with_changes(body=new_body))
|
||||
) if new_body else None, True
|
||||
return node.with_changes(body=node.body.with_changes(body=new_body)) if new_body else None, True
|
||||
|
||||
# For other nodes, keep the node and recursively filter children
|
||||
|
|
@ -288,7 +336,7 @@ def prune_cst_for_read_only_code(
|
|||
section_found_target = False
|
||||
for child in original_content:
|
||||
filtered, found_target = prune_cst_for_read_only_code(
|
||||
child, target_functions, prefix, remove_docstrings
|
||||
child, target_functions, prefix, remove_docstrings=remove_docstrings
|
||||
)
|
||||
if filtered:
|
||||
new_children.append(filtered)
|
||||
|
|
@ -299,7 +347,7 @@ def prune_cst_for_read_only_code(
|
|||
updates[section] = new_children
|
||||
elif original_content is not None:
|
||||
filtered, found_target = prune_cst_for_read_only_code(
|
||||
original_content, target_functions, prefix, remove_docstrings
|
||||
original_content, target_functions, prefix, remove_docstrings=remove_docstrings
|
||||
)
|
||||
found_any_target |= found_target
|
||||
if filtered:
|
||||
|
|
@ -316,7 +364,9 @@ def get_read_only_code(code: str, target_functions: set[str], remove_docstrings:
|
|||
class contextual information, and other module scoped variables.
|
||||
"""
|
||||
module = cst.parse_module(code)
|
||||
filtered_node, found_target = prune_cst_for_read_only_code(module, target_functions, remove_docstrings)
|
||||
filtered_node, found_target = prune_cst_for_read_only_code(
|
||||
module, target_functions, remove_docstrings=remove_docstrings
|
||||
)
|
||||
if not found_target:
|
||||
raise ValueError("No target functions found in the provided code")
|
||||
if filtered_node and isinstance(filtered_node, cst.Module):
|
||||
|
|
|
|||
|
|
@ -715,13 +715,17 @@ class Optimizer:
|
|||
contextual_dunder_methods.update(helper_dunder_methods)
|
||||
|
||||
# Will eventually refactor to use this function instead of the above
|
||||
read_writable_code, read_only_context_code = code_context_extractor.get_code_optimization_context(
|
||||
function_to_optimize, project_root
|
||||
)
|
||||
try:
|
||||
read_writable_code, read_only_context_code = code_context_extractor.get_code_optimization_context(
|
||||
function_to_optimize, project_root
|
||||
)
|
||||
except ValueError as e:
|
||||
return Failure(str(e))
|
||||
|
||||
logger.info("Read-writable code:")
|
||||
code_print(read_writable_code)
|
||||
logger.info("Read-only context code:")
|
||||
# code_print(read_only_context_code)
|
||||
code_print(read_only_context_code)
|
||||
return Success(
|
||||
CodeOptimizationContext(
|
||||
code_to_optimize_with_helpers=code_to_optimize_with_helpers_and_imports,
|
||||
|
|
|
|||
|
|
@ -6,10 +6,10 @@ from collections import defaultdict
|
|||
from pathlib import Path
|
||||
from textwrap import dedent
|
||||
|
||||
from codeflash.context.code_context_extractor import get_code_optimization_context
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.models.models import FunctionParent
|
||||
from codeflash.optimization.optimizer import Optimizer
|
||||
from codeflash.context.code_context_extractor import get_code_optimization_context
|
||||
|
||||
|
||||
class HelperClass:
|
||||
|
|
@ -609,3 +609,159 @@ class _PersistentCache(Generic[_P, _R, _CacheBackendT]):
|
|||
'''
|
||||
assert read_write_context.strip() == dedent(expected_read_write_context).strip()
|
||||
assert read_only_context.strip() == dedent(expected_read_only_context).strip()
|
||||
|
||||
|
||||
def test_example_class() -> None:
|
||||
code = """
|
||||
class MyClass:
|
||||
\"\"\"A class with a helper method.\"\"\"
|
||||
def __init__(self):
|
||||
self.x = 1
|
||||
def target_method(self):
|
||||
y = HelperClass().helper_method()
|
||||
|
||||
class HelperClass:
|
||||
\"\"\"A helper class for MyClass.\"\"\"
|
||||
def __init__(self):
|
||||
\"\"\"Initialize the HelperClass.\"\"\"
|
||||
self.x = 1
|
||||
def __repr__(self):
|
||||
\"\"\"Return a string representation of the HelperClass.\"\"\"
|
||||
return "HelperClass" + str(self.x)
|
||||
def helper_method(self):
|
||||
return self.x
|
||||
"""
|
||||
with tempfile.NamedTemporaryFile(mode="w") as f:
|
||||
f.write(code)
|
||||
f.flush()
|
||||
file_path = Path(f.name).resolve()
|
||||
opt = Optimizer(
|
||||
Namespace(
|
||||
project_root=file_path.parent.resolve(),
|
||||
disable_telemetry=True,
|
||||
tests_root="tests",
|
||||
test_framework="pytest",
|
||||
pytest_cmd="pytest",
|
||||
experiment_id=None,
|
||||
test_project_root=Path().resolve(),
|
||||
)
|
||||
)
|
||||
function_to_optimize = FunctionToOptimize(
|
||||
function_name="target_method",
|
||||
file_path=file_path,
|
||||
parents=[FunctionParent(name="MyClass", type="ClassDef")],
|
||||
starting_line=None,
|
||||
ending_line=None,
|
||||
)
|
||||
|
||||
read_write_context, read_only_context = get_code_optimization_context(
|
||||
function_to_optimize, opt.args.project_root
|
||||
)
|
||||
expected_read_write_context = """
|
||||
class MyClass:
|
||||
def target_method(self):
|
||||
y = HelperClass().helper_method()
|
||||
|
||||
class HelperClass:
|
||||
def helper_method(self):
|
||||
return self.x
|
||||
"""
|
||||
expected_read_only_context = f"""
|
||||
```python:{file_path}
|
||||
class MyClass:
|
||||
\"\"\"A class with a helper method.\"\"\"
|
||||
def __init__(self):
|
||||
self.x = 1
|
||||
|
||||
class HelperClass:
|
||||
\"\"\"A helper class for MyClass.\"\"\"
|
||||
def __init__(self):
|
||||
\"\"\"Initialize the HelperClass.\"\"\"
|
||||
self.x = 1
|
||||
def __repr__(self):
|
||||
\"\"\"Return a string representation of the HelperClass.\"\"\"
|
||||
return "HelperClass" + str(self.x)
|
||||
```
|
||||
"""
|
||||
assert read_write_context.strip() == dedent(expected_read_write_context).strip()
|
||||
assert read_only_context.strip() == dedent(expected_read_only_context).strip()
|
||||
|
||||
|
||||
def test_example_class_token_limit_1() -> None:
|
||||
docstring_filler = "This is a long docstring that will be used to fill up the token limit."
|
||||
docstring_filler_multiplied = " ".join([docstring_filler for _ in range(1000)])
|
||||
code = f"""
|
||||
class MyClass:
|
||||
\"\"\"A class with a helper method.
|
||||
{docstring_filler_multiplied}\"\"\"
|
||||
def __init__(self):
|
||||
self.x = 1
|
||||
def target_method(self):
|
||||
y = HelperClass().helper_method()
|
||||
|
||||
class HelperClass:
|
||||
\"\"\"A helper class for MyClass.\"\"\"
|
||||
def __init__(self):
|
||||
\"\"\"Initialize the HelperClass.\"\"\"
|
||||
self.x = 1
|
||||
def __repr__(self):
|
||||
\"\"\"Return a string representation of the HelperClass.\"\"\"
|
||||
return "HelperClass" + str(self.x)
|
||||
def helper_method(self):
|
||||
return self.x
|
||||
"""
|
||||
with tempfile.NamedTemporaryFile(mode="w") as f:
|
||||
f.write(code)
|
||||
f.flush()
|
||||
file_path = Path(f.name).resolve()
|
||||
opt = Optimizer(
|
||||
Namespace(
|
||||
project_root=file_path.parent.resolve(),
|
||||
disable_telemetry=True,
|
||||
tests_root="tests",
|
||||
test_framework="pytest",
|
||||
pytest_cmd="pytest",
|
||||
experiment_id=None,
|
||||
test_project_root=Path().resolve(),
|
||||
)
|
||||
)
|
||||
function_to_optimize = FunctionToOptimize(
|
||||
function_name="target_method",
|
||||
file_path=file_path,
|
||||
parents=[FunctionParent(name="MyClass", type="ClassDef")],
|
||||
starting_line=None,
|
||||
ending_line=None,
|
||||
)
|
||||
|
||||
read_write_context, read_only_context = get_code_optimization_context(
|
||||
function_to_optimize, opt.args.project_root
|
||||
)
|
||||
expected_read_write_context = """
|
||||
class MyClass:
|
||||
def target_method(self):
|
||||
y = HelperClass().helper_method()
|
||||
|
||||
class HelperClass:
|
||||
def helper_method(self):
|
||||
return self.x
|
||||
"""
|
||||
expected_read_only_context = f"""
|
||||
```python:{file_path}
|
||||
class MyClass:
|
||||
\"\"\"A class with a helper method.
|
||||
{docstring_filler_multiplied}\"\"\"
|
||||
def __init__(self):
|
||||
self.x = 1
|
||||
|
||||
class HelperClass:
|
||||
\"\"\"A helper class for MyClass.\"\"\"
|
||||
def __init__(self):
|
||||
\"\"\"Initialize the HelperClass.\"\"\"
|
||||
self.x = 1
|
||||
def __repr__(self):
|
||||
\"\"\"Return a string representation of the HelperClass.\"\"\"
|
||||
return "HelperClass" + str(self.x)
|
||||
```
|
||||
"""
|
||||
assert read_write_context.strip() == dedent(expected_read_write_context).strip()
|
||||
assert read_only_context.strip() == dedent(expected_read_only_context).strip()
|
||||
|
|
|
|||
|
|
@ -51,6 +51,90 @@ def test_dunder_methods() -> None:
|
|||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
def test_dunder_methods_remove_docstring() -> None:
|
||||
code = """
|
||||
class TestClass:
|
||||
def __init__(self):
|
||||
\"\"\"Constructor for TestClass.\"\"\"
|
||||
self.x = 42
|
||||
|
||||
def __str__(self):
|
||||
\"\"\"String representation of TestClass.\"\"\"
|
||||
return f"Value: {self.x}"
|
||||
|
||||
def target_method(self):
|
||||
print("stub me")
|
||||
"""
|
||||
|
||||
expected = """
|
||||
class TestClass:
|
||||
def __init__(self):
|
||||
self.x = 42
|
||||
|
||||
def __str__(self):
|
||||
return f"Value: {self.x}"
|
||||
"""
|
||||
|
||||
output = get_read_only_code(dedent(code), {"TestClass.target_method"}, remove_docstrings=True)
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
def test_class_remove_docstring() -> None:
|
||||
code = """
|
||||
class TestClass:
|
||||
\"\"\"Class docstring.\"\"\"
|
||||
def __init__(self):
|
||||
self.x = 42
|
||||
|
||||
def __str__(self):
|
||||
return f"Value: {self.x}"
|
||||
|
||||
def target_method(self):
|
||||
print("stub me")
|
||||
"""
|
||||
|
||||
expected = """
|
||||
class TestClass:
|
||||
def __init__(self):
|
||||
self.x = 42
|
||||
|
||||
def __str__(self):
|
||||
return f"Value: {self.x}"
|
||||
"""
|
||||
|
||||
output = get_read_only_code(dedent(code), {"TestClass.target_method"}, remove_docstrings=True)
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
def test_mixed_remove_docstring() -> None:
|
||||
code = """
|
||||
class TestClass:
|
||||
\"\"\"Class docstring.\"\"\"
|
||||
def __init__(self):
|
||||
self.x = 42
|
||||
|
||||
def __str__(self):
|
||||
\"\"\"String representation of TestClass.\"\"\"
|
||||
return f"Value: {self.x}"
|
||||
|
||||
def target_method(self):
|
||||
\"\"\"target method docstring.\"\"\"
|
||||
print("stub me")
|
||||
"""
|
||||
|
||||
expected = """
|
||||
class TestClass:
|
||||
def __init__(self):
|
||||
self.x = 42
|
||||
|
||||
def __str__(self):
|
||||
return f"Value: {self.x}"
|
||||
"""
|
||||
|
||||
output = get_read_only_code(dedent(code), {"TestClass.target_method"}, remove_docstrings=True)
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
def test_target_in_nested_class() -> None:
|
||||
"""Test that attempting to find a target in a nested class raises an error."""
|
||||
code = """
|
||||
|
|
@ -603,3 +687,114 @@ def test_simplified_complete_implementation() -> None:
|
|||
|
||||
output = get_read_only_code(dedent(code), {"DataProcessor.target_method", "ResultHandler.target_method"})
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
def test_simplified_complete_implementation_no_docstring() -> None:
|
||||
code = """
|
||||
class DataProcessor:
|
||||
\"\"\"A simple data processing class.\"\"\"
|
||||
|
||||
def __init__(self, data: Dict[str, Any]) -> None:
|
||||
self.data = data
|
||||
self._processed = False
|
||||
self.result = None
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"DataProcessor(processed={self._processed})"
|
||||
|
||||
def target_method(self, key: str) -> Optional[Any]:
|
||||
\"\"\"Process and retrieve a specific key from the data.\"\"\"
|
||||
if not self._processed:
|
||||
self._process_data()
|
||||
return self.result.get(key) if self.result else None
|
||||
|
||||
def _process_data(self) -> None:
|
||||
\"\"\"Internal method to process the data.\"\"\"
|
||||
processed = {}
|
||||
for key, value in self.data.items():
|
||||
if isinstance(value, (int, float)):
|
||||
processed[key] = value * 2
|
||||
elif isinstance(value, str):
|
||||
processed[key] = value.upper()
|
||||
else:
|
||||
processed[key] = value
|
||||
self.result = processed
|
||||
self._processed = True
|
||||
|
||||
def to_json(self) -> str:
|
||||
\"\"\"Convert the processed data to JSON string.\"\"\"
|
||||
if not self._processed:
|
||||
self._process_data()
|
||||
return json.dumps(self.result)
|
||||
|
||||
try:
|
||||
sample_data = {"number": 42, "text": "hello"}
|
||||
processor = DataProcessor(sample_data)
|
||||
|
||||
class ResultHandler:
|
||||
def __init__(self, processor: DataProcessor):
|
||||
self.processor = processor
|
||||
self.cache = {}
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"ResultHandler(cache_size={len(self.cache)})"
|
||||
|
||||
def target_method(self, key: str) -> Optional[Any]:
|
||||
\"\"\"Retrieve and cache results for a key.\"\"\"
|
||||
if key not in self.cache:
|
||||
self.cache[key] = self.processor.target_method(key)
|
||||
return self.cache[key]
|
||||
|
||||
def clear_cache(self) -> None:
|
||||
\"\"\"Clear the internal cache.\"\"\"
|
||||
self.cache.clear()
|
||||
|
||||
def get_stats(self) -> Dict[str, int]:
|
||||
\"\"\"Get cache statistics.\"\"\"
|
||||
return {
|
||||
"cache_size": len(self.cache),
|
||||
"hits": sum(1 for v in self.cache.values() if v is not None)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
class ResultHandler:
|
||||
def __init__(self):
|
||||
self.error = str(e)
|
||||
|
||||
def target_method(self, key: str) -> None:
|
||||
raise RuntimeError(f"Failed to initialize: {self.error}")
|
||||
"""
|
||||
|
||||
expected = """
|
||||
class DataProcessor:
|
||||
|
||||
def __init__(self, data: Dict[str, Any]) -> None:
|
||||
self.data = data
|
||||
self._processed = False
|
||||
self.result = None
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"DataProcessor(processed={self._processed})"
|
||||
|
||||
try:
|
||||
sample_data = {"number": 42, "text": "hello"}
|
||||
processor = DataProcessor(sample_data)
|
||||
|
||||
class ResultHandler:
|
||||
def __init__(self, processor: DataProcessor):
|
||||
self.processor = processor
|
||||
self.cache = {}
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"ResultHandler(cache_size={len(self.cache)})"
|
||||
|
||||
except Exception as e:
|
||||
class ResultHandler:
|
||||
def __init__(self):
|
||||
self.error = str(e)
|
||||
"""
|
||||
|
||||
output = get_read_only_code(
|
||||
dedent(code), {"DataProcessor.target_method", "ResultHandler.target_method"}, remove_docstrings=True
|
||||
)
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
|
|
|||
|
|
@ -221,3 +221,5 @@ def test_module_var() -> None:
|
|||
|
||||
output = get_read_writable_code(dedent(code), {"target_function"})
|
||||
assert dedent(expected).strip() == output.strip()
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue