Start using markdown representation for read-writable context

This commit is contained in:
mohammed 2025-07-16 18:48:31 +03:00
parent 553a192607
commit 216eb7e794
No known key found for this signature in database
GPG key ID: 44F9B42770617B9B
3 changed files with 26 additions and 10 deletions

View file

@ -61,13 +61,14 @@ def get_code_optimization_context(
)
# Extract code context for optimization
final_read_writable_code = extract_code_string_context_from_files(
final_read_writable_code = extract_code_markdown_context_from_files(
helpers_of_fto_dict,
{},
helpers_of_helpers_dict,
project_root_path,
remove_docstrings=False,
code_context_type=CodeContextType.READ_WRITABLE,
).code
)
read_only_code_markdown = extract_code_markdown_context_from_files(
helpers_of_fto_dict,
helpers_of_helpers_dict,
@ -84,14 +85,14 @@ def get_code_optimization_context(
)
# Handle token limits
final_read_writable_tokens = encoded_tokens_len(final_read_writable_code)
final_read_writable_tokens = encoded_tokens_len(final_read_writable_code.__str__)
if final_read_writable_tokens > optim_token_limit:
raise ValueError("Read-writable code has exceeded token limit, cannot proceed")
# Setup preexisting objects for code replacer
preexisting_objects = set(
chain(
find_preexisting_objects(final_read_writable_code),
find_preexisting_objects(final_read_writable_code.__str__),
*(find_preexisting_objects(codestring.code) for codestring in read_only_code_markdown.code_strings),
)
)

View file

@ -19,7 +19,7 @@ from re import Pattern
from typing import Annotated, Optional, cast
from jedi.api.classes import Name
from pydantic import AfterValidator, BaseModel, ConfigDict, Field
from pydantic import AfterValidator, BaseModel, ConfigDict
from pydantic.dataclasses import dataclass
from codeflash.cli_cmds.console import console, logger
@ -139,8 +139,22 @@ class CodeString(BaseModel):
file_path: Optional[Path] = None
def get_code_block_splitter(file_path: Path) -> str:
    """Build the marker comment line that introduces a file's code block.

    Args:
        file_path: Path to embed in the marker; it is rendered with ``str()``.

    Returns:
        The text ``# codeflash-splitter__`` immediately followed by the path.
    """
    marker_prefix = "# codeflash-splitter__"
    return marker_prefix + str(file_path)
class CodeStringsMarkdown(BaseModel):
code_strings: list[CodeString] = []
cached_code: str | None = None
@property
def __str__(self) -> str:
    """Return every code string as one splitter-delimited block of text.

    Each entry of ``code_strings`` is rendered as the marker line produced by
    ``get_code_block_splitter(entry.file_path)``, a newline, and the entry's
    code. Rendered entries are joined with a blank line between them. The
    text is built on first access and stored in ``cached_code``; later
    accesses return the stored value without rebuilding it, so this property
    does not pick up changes made to ``code_strings`` afterwards.

    NOTE(review): this is a property that shadows the ``__str__`` dunder, so
    the text is read as ``obj.__str__`` (no call). ``str(obj)`` would try to
    call the property object and fail. Confirm this is intended.
    """
    if self.cached_code is None:
        rendered_blocks = []
        for code_string in self.code_strings:
            header = get_code_block_splitter(code_string.file_path)
            rendered_blocks.append(header + "\n" + code_string.code)
        self.cached_code = "\n\n".join(rendered_blocks)
    return self.cached_code
@property
def markdown(self) -> str:
@ -155,7 +169,7 @@ class CodeStringsMarkdown(BaseModel):
class CodeOptimizationContext(BaseModel):
testgen_context_code: str = ""
read_writable_code: str = Field(min_length=1)
read_writable_code: CodeStringsMarkdown
read_only_context_code: str = ""
hashing_code_context: str = ""
hashing_code_context_hash: str = ""

View file

@ -9,7 +9,7 @@ from pathlib import Path
import pytest
from codeflash.context.code_context_extractor import get_code_optimization_context
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import FunctionParent
from codeflash.models.models import FunctionParent, get_code_block_splitter
from codeflash.optimization.optimizer import Optimizer
from codeflash.code_utils.code_replacer import replace_functions_and_add_imports
from codeflash.code_utils.code_extractor import add_global_assignments
@ -88,7 +88,8 @@ def test_code_replacement10() -> None:
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
hashing_context = code_ctx.hashing_code_context
expected_read_write_context = """
expected_read_write_context = f"""
{get_code_block_splitter(file_path.relative_to(file_path.parent))}
from __future__ import annotations
class HelperClass:
@ -125,7 +126,7 @@ class MainClass:
```
"""
assert read_write_context.strip() == expected_read_write_context.strip()
assert read_write_context.__str__.strip() == expected_read_write_context.strip()
assert read_only_context.strip() == expected_read_only_context.strip()
assert hashing_context.strip() == expected_hashing_context.strip()