start using markdown representation for read writable context
This commit is contained in:
parent
553a192607
commit
216eb7e794
3 changed files with 26 additions and 10 deletions
|
|
@ -61,13 +61,14 @@ def get_code_optimization_context(
|
|||
)
|
||||
|
||||
# Extract code context for optimization
|
||||
final_read_writable_code = extract_code_string_context_from_files(
|
||||
final_read_writable_code = extract_code_markdown_context_from_files(
|
||||
helpers_of_fto_dict,
|
||||
{},
|
||||
helpers_of_helpers_dict,
|
||||
project_root_path,
|
||||
remove_docstrings=False,
|
||||
code_context_type=CodeContextType.READ_WRITABLE,
|
||||
).code
|
||||
)
|
||||
|
||||
read_only_code_markdown = extract_code_markdown_context_from_files(
|
||||
helpers_of_fto_dict,
|
||||
helpers_of_helpers_dict,
|
||||
|
|
@ -84,14 +85,14 @@ def get_code_optimization_context(
|
|||
)
|
||||
|
||||
# Handle token limits
|
||||
final_read_writable_tokens = encoded_tokens_len(final_read_writable_code)
|
||||
final_read_writable_tokens = encoded_tokens_len(final_read_writable_code.__str__)
|
||||
if final_read_writable_tokens > optim_token_limit:
|
||||
raise ValueError("Read-writable code has exceeded token limit, cannot proceed")
|
||||
|
||||
# Setup preexisting objects for code replacer
|
||||
preexisting_objects = set(
|
||||
chain(
|
||||
find_preexisting_objects(final_read_writable_code),
|
||||
find_preexisting_objects(final_read_writable_code.__str__),
|
||||
*(find_preexisting_objects(codestring.code) for codestring in read_only_code_markdown.code_strings),
|
||||
)
|
||||
)
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ from re import Pattern
|
|||
from typing import Annotated, Optional, cast
|
||||
|
||||
from jedi.api.classes import Name
|
||||
from pydantic import AfterValidator, BaseModel, ConfigDict, Field
|
||||
from pydantic import AfterValidator, BaseModel, ConfigDict
|
||||
from pydantic.dataclasses import dataclass
|
||||
|
||||
from codeflash.cli_cmds.console import console, logger
|
||||
|
|
@ -139,8 +139,22 @@ class CodeString(BaseModel):
|
|||
file_path: Optional[Path] = None
|
||||
|
||||
|
||||
def get_code_block_splitter(file_path: Path) -> str:
|
||||
return f"# codeflash-splitter__{file_path}"
|
||||
|
||||
|
||||
class CodeStringsMarkdown(BaseModel):
|
||||
code_strings: list[CodeString] = []
|
||||
cached_code: str | None = None
|
||||
|
||||
@property
|
||||
def __str__(self) -> str:
|
||||
if self.cached_code is not None:
|
||||
return self.cached_code
|
||||
self.cached_code = "\n\n".join(
|
||||
get_code_block_splitter(block.file_path) + "\n" + block.code for block in self.code_strings
|
||||
)
|
||||
return self.cached_code
|
||||
|
||||
@property
|
||||
def markdown(self) -> str:
|
||||
|
|
@ -155,7 +169,7 @@ class CodeStringsMarkdown(BaseModel):
|
|||
|
||||
class CodeOptimizationContext(BaseModel):
|
||||
testgen_context_code: str = ""
|
||||
read_writable_code: str = Field(min_length=1)
|
||||
read_writable_code: CodeStringsMarkdown
|
||||
read_only_context_code: str = ""
|
||||
hashing_code_context: str = ""
|
||||
hashing_code_context_hash: str = ""
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from pathlib import Path
|
|||
import pytest
|
||||
from codeflash.context.code_context_extractor import get_code_optimization_context
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.models.models import FunctionParent
|
||||
from codeflash.models.models import FunctionParent, get_code_block_splitter
|
||||
from codeflash.optimization.optimizer import Optimizer
|
||||
from codeflash.code_utils.code_replacer import replace_functions_and_add_imports
|
||||
from codeflash.code_utils.code_extractor import add_global_assignments
|
||||
|
|
@ -88,7 +88,8 @@ def test_code_replacement10() -> None:
|
|||
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
|
||||
hashing_context = code_ctx.hashing_code_context
|
||||
|
||||
expected_read_write_context = """
|
||||
expected_read_write_context = f"""
|
||||
{get_code_block_splitter(file_path.relative_to(file_path.parent))}
|
||||
from __future__ import annotations
|
||||
|
||||
class HelperClass:
|
||||
|
|
@ -125,7 +126,7 @@ class MainClass:
|
|||
```
|
||||
"""
|
||||
|
||||
assert read_write_context.strip() == expected_read_write_context.strip()
|
||||
assert read_write_context.__str__.strip() == expected_read_write_context.strip()
|
||||
assert read_only_context.strip() == expected_read_only_context.strip()
|
||||
assert hashing_context.strip() == expected_hashing_context.strip()
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue