handle token limits

This commit is contained in:
Alvin Ryanputra 2024-12-26 14:29:32 -08:00
parent e10f13a83a
commit 0204ef2fcb
2 changed files with 145 additions and 24 deletions

View file

@ -108,43 +108,44 @@ def get_code_optimization_context(
if read_only_code_with_imports.code:
read_only_code_markdown.code_strings.append(read_only_code_with_imports)
# Handle token limits
tokenizer = tiktoken.encoding_for_model("gpt-4o")
final_read_writable_tokens = len(tokenizer.encode(final_read_writable_code))
print("final_read_writable_tokens", final_read_writable_tokens)
if final_read_writable_tokens > token_limit:
raise ValueError("Read-writable code has exceeded token limit, cannot proceed")
read_only_code_markdown_tokens = len(tokenizer.encode(read_only_code_markdown.markdown))
read_only_code_markdown_tokens = len(tokenizer.encode(read_only_code_markdown.markdown))
total_tokens = final_read_writable_tokens + read_only_code_markdown_tokens
print("total_tokens", total_tokens)
print(read_only_code_markdown.markdown)
if total_tokens <= token_limit:
return CodeString(code=final_read_writable_code).code, read_only_code_markdown.markdown
logger.debug("Code context has exceeded token limit, removing docstrings from read-only code")
# Get read-only code context again, this time without docstrings
read_only_code_markdown = CodeStringsMarkdown()
for file_path, qualified_function_names in file_path_to_qualified_function_names.items():
try:
read_only_code = get_read_only_code(og_code_containing_helpers, qualified_function_names)
read_only_code = get_read_only_code(
og_code_containing_helpers, qualified_function_names, remove_docstrings=True
)
except ValueError as e:
logger.debug(f"Error while getting read-only code: {e}")
continue
read_only_code_with_imports = CodeString(
code=add_needed_imports_from_module(
src_module_code=og_code_containing_helpers,
dst_module_code=read_only_code,
src_path=file_path,
dst_path=file_path,
project_root=project_root_path,
helper_functions_fqn=qualified_function_names,
),
file_path=Path(file_path),
)
read_only_code_with_imports = CodeString(
code=add_needed_imports_from_module(
src_module_code=og_code_containing_helpers,
dst_module_code=read_only_code,
src_path=file_path,
dst_path=file_path,
project_root=project_root_path,
helper_functions_fqn=qualified_function_names,
),
file_path=Path(file_path),
)
if read_only_code_with_imports.code:
read_only_code_markdown.code_strings.append(read_only_code_with_imports)
read_only_code_markdown_tokens = len(tokenizer.encode(read_only_code_markdown.markdown))
total_tokens = final_read_writable_tokens + read_only_code_markdown_tokens
print("total_tokens after removal", total_tokens)
if total_tokens <= token_limit:
return CodeString(code=final_read_writable_code).code, read_only_code_markdown.markdown

View file

@ -6,6 +6,7 @@ from collections import defaultdict
from pathlib import Path
from textwrap import dedent
import pytest
from codeflash.context.code_context_extractor import get_code_optimization_context
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import FunctionParent
@ -688,15 +689,17 @@ class HelperClass:
def test_example_class_token_limit_1() -> None:
docstring_filler = "This is a long docstring that will be used to fill up the token limit."
docstring_filler_multiplied = " ".join([docstring_filler for _ in range(1000)])
docstring_filler = " ".join(
["This is a long docstring that will be used to fill up the token limit." for _ in range(1000)]
)
code = f"""
class MyClass:
\"\"\"A class with a helper method.
{docstring_filler_multiplied}\"\"\"
{docstring_filler}\"\"\"
def __init__(self):
self.x = 1
def target_method(self):
\"\"\"Docstring for target method\"\"\"
y = HelperClass().helper_method()
class HelperClass:
@ -736,9 +739,11 @@ class HelperClass:
read_write_context, read_only_context = get_code_optimization_context(
function_to_optimize, opt.args.project_root
)
# In this scenario, the read-only code context is too long, so the read-only docstrings are removed.
expected_read_write_context = """
class MyClass:
def target_method(self):
\"\"\"Docstring for target method\"\"\"
y = HelperClass().helper_method()
class HelperClass:
@ -748,11 +753,34 @@ class HelperClass:
expected_read_only_context = f"""
```python:{file_path}
class MyClass:
\"\"\"A class with a helper method.
{docstring_filler_multiplied}\"\"\"
def __init__(self):
self.x = 1
class HelperClass:
def __init__(self):
self.x = 1
def __repr__(self):
return "HelperClass" + str(self.x)
```
"""
assert read_write_context.strip() == dedent(expected_read_write_context).strip()
assert read_only_context.strip() == dedent(expected_read_only_context).strip()
def test_example_class_token_limit_2() -> None:
string_filler = " ".join(
["This is a long string that will be used to fill up the token limit." for _ in range(1000)]
)
code = f"""
class MyClass:
\"\"\"A class with a helper method. \"\"\"
def __init__(self):
self.x = 1
def target_method(self):
\"\"\"Docstring for target method\"\"\"
y = HelperClass().helper_method()
x = '{string_filler}'
class HelperClass:
\"\"\"A helper class for MyClass.\"\"\"
def __init__(self):
@ -761,7 +789,99 @@ class HelperClass:
def __repr__(self):
\"\"\"Return a string representation of the HelperClass.\"\"\"
return "HelperClass" + str(self.x)
```
def helper_method(self):
return self.x
"""
with tempfile.NamedTemporaryFile(mode="w") as f:
f.write(code)
f.flush()
file_path = Path(f.name).resolve()
opt = Optimizer(
Namespace(
project_root=file_path.parent.resolve(),
disable_telemetry=True,
tests_root="tests",
test_framework="pytest",
pytest_cmd="pytest",
experiment_id=None,
test_project_root=Path().resolve(),
)
)
function_to_optimize = FunctionToOptimize(
function_name="target_method",
file_path=file_path,
parents=[FunctionParent(name="MyClass", type="ClassDef")],
starting_line=None,
ending_line=None,
)
read_write_context, read_only_context = get_code_optimization_context(
function_to_optimize, opt.args.project_root
)
# In this scenario, the read-only code context is too long even after removing docstrings, hence we remove it completely.
expected_read_write_context = """
class MyClass:
def target_method(self):
\"\"\"Docstring for target method\"\"\"
y = HelperClass().helper_method()
class HelperClass:
def helper_method(self):
return self.x
"""
expected_read_only_context = ""
assert read_write_context.strip() == dedent(expected_read_write_context).strip()
assert read_only_context.strip() == dedent(expected_read_only_context).strip()
def test_example_class_token_limit_3() -> None:
    """Verify that an oversized read-writable context aborts extraction.

    The fixture module defines ``MyClass.target_method`` with a docstring made
    of 1000 repetitions of a filler sentence. ``target_method`` is the function
    being optimized, so the filler ends up in the read-writable code, and
    ``get_code_optimization_context`` is expected to raise ``ValueError``
    instead of returning a context.
    """
    string_filler = " ".join(
        ["This is a long string that will be used to fill up the token limit." for _ in range(1000)]
    )
    # The filler is interpolated into the docstring of the *target* method, so
    # it cannot be dropped the way read-only docstrings can.
    code = f"""
class MyClass:
    \"\"\"A class with a helper method. \"\"\"
    def __init__(self):
        self.x = 1
    def target_method(self):
        \"\"\"{string_filler}\"\"\"
        y = HelperClass().helper_method()
class HelperClass:
    \"\"\"A helper class for MyClass.\"\"\"
    def __init__(self):
        \"\"\"Initialize the HelperClass.\"\"\"
        self.x = 1
    def __repr__(self):
        \"\"\"Return a string representation of the HelperClass.\"\"\"
        return "HelperClass" + str(self.x)
    def helper_method(self):
        return self.x
"""
    with tempfile.NamedTemporaryFile(mode="w") as f:
        f.write(code)
        # Flush so the fixture is on disk before it is looked up by path
        # (presumably the extractor reads the file itself -- confirm).
        f.flush()
        file_path = Path(f.name).resolve()
        opt = Optimizer(
            Namespace(
                project_root=file_path.parent.resolve(),
                disable_telemetry=True,
                tests_root="tests",
                test_framework="pytest",
                pytest_cmd="pytest",
                experiment_id=None,
                test_project_root=Path().resolve(),
            )
        )
        function_to_optimize = FunctionToOptimize(
            function_name="target_method",
            file_path=file_path,
            parents=[FunctionParent(name="MyClass", type="ClassDef")],
            starting_line=None,
            ending_line=None,
        )

        # In this scenario, the read-writable code is too long, so we abort.
        # The call never returns, so its result is deliberately not bound.
        with pytest.raises(ValueError, match="Read-writable code has exceeded token limit, cannot proceed"):
            get_code_optimization_context(function_to_optimize, opt.args.project_root)