Fix unit tests

This commit is contained in:
mohammed 2025-08-06 03:33:46 +03:00
parent 07a9365987
commit 989b1f30a2
No known key found for this signature in database
GPG key ID: 44F9B42770617B9B
9 changed files with 221 additions and 152 deletions

View file

@ -73,9 +73,9 @@ class AiServiceClient:
url = f"{self.base_url}/ai{endpoint}"
if method.upper() == "POST":
json_payload = json.dumps(payload, indent=None, default=pydantic_encoder)
print(f"========JSON PAYLOAD FOR {url}==============")
print(f"Payload: {json_payload}")
print("======================")
logger.debug(f"========JSON PAYLOAD FOR {url}==============")
logger.debug(json_payload)
logger.debug("======================")
headers = {**self.headers, "Content-Type": "application/json"}
response = requests.post(url, data=json_payload, headers=headers, timeout=timeout)
else:

View file

@ -19,7 +19,7 @@ if TYPE_CHECKING:
from pathlib import Path
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import CodeOptimizationContext, OptimizedCandidate, ValidCode
from codeflash.models.models import CodeOptimizationContext, CodeStringsMarkdown, OptimizedCandidate, ValidCode
ASTNodeT = TypeVar("ASTNodeT", bound=ast.AST)
@ -408,16 +408,17 @@ def replace_functions_and_add_imports(
def replace_function_definitions_in_module(
function_names: list[str],
optimized_code: str,
optimized_code: CodeStringsMarkdown,
module_abspath: Path,
preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]],
project_root_path: Path,
) -> bool:
source_code: str = module_abspath.read_text(encoding="utf8")
code_to_apply = get_optimized_code_for_module(module_abspath.relative_to(project_root_path), optimized_code)
new_code: str = replace_functions_and_add_imports(
add_global_assignments(optimized_code, source_code),
add_global_assignments(code_to_apply, source_code),
function_names,
optimized_code,
code_to_apply,
module_abspath,
preexisting_objects,
project_root_path,
@ -428,6 +429,19 @@ def replace_function_definitions_in_module(
return True
def get_optimized_code_for_module(relative_path: Path, optimized_code: CodeStringsMarkdown) -> str:
    """Return the optimized code block belonging to the module at *relative_path*.

    Args:
        relative_path: Module path relative to the project root, used as the
            lookup key into the markdown code context.
        optimized_code: Parsed markdown code context mapping file paths to code.

    Returns:
        str: The code for that module, or an empty string (after logging a
        warning) when the context contains no block for that path.
    """
    file_to_code_context = optimized_code.file_to_path()
    module_optimized_code = file_to_code_context.get(str(relative_path))
    if module_optimized_code is None:
        # Missing entry means the model's markdown did not include this file;
        # fall back to an empty string so callers replace nothing instead of crashing.
        logger.warning(
            "Optimized code not found for %s in the context\n-------\n%s\n-------\n"
            "re-check your 'markdown code structure'; existing files are %s",
            relative_path,
            optimized_code,
            list(file_to_code_context),
        )
        module_optimized_code = ""
    return module_optimized_code
def is_zero_diff(original_code: str, new_code: str) -> bool:
    """Report whether *original_code* and *new_code* are equivalent after normalization."""
    normalized_original = normalize_code(original_code)
    normalized_new = normalize_code(new_code)
    return normalized_original == normalized_new

View file

@ -3,16 +3,14 @@ from __future__ import annotations
import ast
from collections import defaultdict
from dataclasses import dataclass, field
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from pathlib import Path
from pathlib import Path
from typing import TYPE_CHECKING, Optional
import libcst as cst
from codeflash.cli_cmds.console import logger
from codeflash.code_utils.code_replacer import replace_function_definitions_in_module
from codeflash.models.models import CodeString, CodeStringsMarkdown
if TYPE_CHECKING:
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
@ -530,7 +528,11 @@ def revert_unused_helper_functions(
helper_names = [helper.qualified_name for helper in helpers_in_file]
reverted_code = replace_function_definitions_in_module(
function_names=helper_names,
optimized_code=original_code, # Use original code as the "optimized" code to revert
optimized_code=CodeStringsMarkdown(
code_strings=[
CodeString(code=original_code, file_path=Path(file_path).relative_to(project_root))
]
), # Use original code as the "optimized" code to revert
module_abspath=file_path,
preexisting_objects=set(), # Empty set since we're reverting
project_root_path=project_root,

View file

@ -174,6 +174,15 @@ class CodeStringsMarkdown(BaseModel):
@property
def flat(self) -> str:
"""Returns the combined Python module from all code blocks.
Each block is prefixed by a file path comment to indicate its origin.
This representation is syntactically valid Python code.
Returns:
str: The concatenated code of all blocks with file path annotations.
"""
if self._cache.get("flat") is not None:
return self._cache["flat"]
self._cache["flat"] = "\n".join(
@ -183,7 +192,15 @@ class CodeStringsMarkdown(BaseModel):
@property
def markdown(self) -> str:
"""Returns the markdown representation of the code, including the file path where possible."""
"""Returns a Markdown-formatted string containing all code blocks.
Each block is enclosed in a triple-backtick code block with an optional
file path suffix (e.g., ```python:filename.py).
Returns:
str: Markdown representation of the code blocks.
"""
return "\n".join(
[
f"```python{':' + str(code_string.file_path) if code_string.file_path else ''}\n{code_string.code.strip()}\n```"
@ -192,6 +209,12 @@ class CodeStringsMarkdown(BaseModel):
)
def file_to_path(self) -> dict[str, str]:
"""Return a dictionary mapping file paths to their corresponding code blocks.
Returns:
dict[str, str]: Mapping from file path (as string) to code.
"""
if self._cache.get("file_to_path") is not None:
return self._cache["file_to_path"]
self._cache["file_to_path"] = {
@ -201,6 +224,17 @@ class CodeStringsMarkdown(BaseModel):
@staticmethod
def parse_markdown_code(markdown_code: str) -> CodeStringsMarkdown:
"""Parse a Markdown string into a CodeStringsMarkdown object.
Extracts code blocks and their associated file paths and constructs a new CodeStringsMarkdown instance.
Args:
markdown_code (str): The Markdown-formatted string to parse.
Returns:
CodeStringsMarkdown: Parsed object containing code blocks.
"""
matches = markdown_pattern.findall(markdown_code)
results = CodeStringsMarkdown()
for file_path, code in matches:

View file

@ -720,29 +720,13 @@ class FunctionOptimizer:
read_writable_functions_by_file_path[self.function_to_optimize.file_path].add(
self.function_to_optimize.qualified_name
)
file_to_code_context = optimized_code.file_to_path()
for helper_function in code_context.helper_functions:
if helper_function.jedi_definition.type != "class":
read_writable_functions_by_file_path[helper_function.file_path].add(helper_function.qualified_name)
for module_abspath, qualified_names in read_writable_functions_by_file_path.items():
relative_module_path = str(module_abspath.relative_to(self.project_root))
logger.debug(f"applying optimized code to: {relative_module_path}")
scoped_optimized_code = file_to_code_context.get(relative_module_path)
if scoped_optimized_code is None:
logger.warning(
f"Optimized code not found for {relative_module_path} In the context\n-------\n{optimized_code}\n-------\n"
"re-check your 'split markers'"
f"existing files are {file_to_code_context.keys()}"
)
scoped_optimized_code = ""
did_update |= replace_function_definitions_in_module(
function_names=list(qualified_names),
optimized_code=scoped_optimized_code,
optimized_code=optimized_code,
module_abspath=module_abspath,
preexisting_objects=code_context.preexisting_objects,
project_root_path=self.project_root,

View file

@ -89,7 +89,7 @@ def test_code_replacement10() -> None:
hashing_context = code_ctx.hashing_code_context
expected_read_write_context = f"""
{get_code_block_splitter(file_path.relative_to(file_path.parent))}
```python:{file_path.relative_to(file_path.parent)}
from __future__ import annotations
class HelperClass:
@ -107,6 +107,7 @@ class MainClass:
def main_method(self):
self.name = HelperClass.NestedClass("test").nested_method()
return HelperClass(self.name).helper_method()
```
"""
expected_read_only_context = """
"""
@ -126,7 +127,7 @@ class MainClass:
```
"""
assert read_write_context.flat.strip() == expected_read_write_context.strip()
assert read_write_context.markdown.strip() == expected_read_write_context.strip()
assert read_only_context.strip() == expected_read_only_context.strip()
assert hashing_context.strip() == expected_hashing_context.strip()
@ -147,7 +148,7 @@ def test_class_method_dependencies() -> None:
hashing_context = code_ctx.hashing_code_context
expected_read_write_context = f"""
{get_code_block_splitter(file_path.relative_to(file_path.parent))}
```python:{file_path.relative_to(file_path.parent)}
from __future__ import annotations
from collections import defaultdict
@ -175,7 +176,7 @@ class Graph:
# Print contents of stack
return stack
```
"""
expected_read_only_context = ""
@ -200,7 +201,7 @@ class Graph:
```
"""
assert read_write_context.flat.strip() == expected_read_write_context.strip()
assert read_write_context.markdown.strip() == expected_read_write_context.strip()
assert read_only_context.strip() == expected_read_only_context.strip()
assert hashing_context.strip() == expected_hashing_context.strip()
@ -227,7 +228,7 @@ def test_bubble_sort_helper() -> None:
hashing_context = code_ctx.hashing_code_context
expected_read_write_context = f"""
{get_code_block_splitter("code_to_optimize/code_directories/retriever/bubble_sort_with_math.py")}
```python:code_to_optimize/code_directories/retriever/bubble_sort_with_math.py
import math
def sorter(arr):
@ -235,14 +236,14 @@ def sorter(arr):
x = math.sqrt(2)
print(x)
return arr
{get_code_block_splitter("code_to_optimize/code_directories/retriever/bubble_sort_imported.py")}
```
```python:code_to_optimize/code_directories/retriever/bubble_sort_imported.py
from bubble_sort_with_math import sorter
def sort_from_another_file(arr):
sorted_arr = sorter(arr)
return sorted_arr
```
"""
expected_read_only_context = ""
@ -260,7 +261,7 @@ def sort_from_another_file(arr):
return sorted_arr
```
"""
assert read_write_context.flat.strip() == expected_read_write_context.strip()
assert read_write_context.markdown.strip() == expected_read_write_context.strip()
assert read_only_context.strip() == expected_read_only_context.strip()
assert hashing_context.strip() == expected_hashing_context.strip()
@ -458,7 +459,7 @@ class _PersistentCache(Generic[_P, _R, _CacheBackendT]):
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
hashing_context = code_ctx.hashing_code_context
expected_read_write_context = f"""
{get_code_block_splitter(file_path.relative_to(opt.args.project_root))}
```python:{file_path.relative_to(opt.args.project_root)}
class AbstractCacheBackend(CacheBackend, Protocol[_KEY_T, _STORE_T]):
def __init__(self) -> None: ...
@ -553,7 +554,8 @@ class _PersistentCache(Generic[_P, _R, _CacheBackendT]):
kwargs=kwargs,
lifespan=self.__duration__,
)
"""
```
"""
expected_read_only_context = f'''
```python:{file_path.relative_to(opt.args.project_root)}
_P = ParamSpec("_P")
@ -647,7 +649,7 @@ class _PersistentCache(Generic[_P, _R, _CacheBackendT]):
return self.__backend__.get_cache_or_call(func=self.__wrapped__, args=args, kwargs=kwargs, lifespan=self.__duration__)
```
"""
assert read_write_context.flat.strip() == expected_read_write_context.strip()
assert read_write_context.markdown.strip() == expected_read_write_context.strip()
assert read_only_context.strip() == expected_read_only_context.strip()
assert hashing_context.strip() == expected_hashing_context.strip()
@ -700,7 +702,7 @@ class HelperClass:
hashing_context = code_ctx.hashing_code_context
expected_read_write_context = f"""
{get_code_block_splitter(file_path.relative_to(opt.args.project_root))}
```python:{file_path.relative_to(opt.args.project_root)}
class MyClass:
def __init__(self):
self.x = 1
@ -713,7 +715,8 @@ class HelperClass:
self.x = 1
def helper_method(self):
return self.x
"""
```
"""
expected_read_only_context = f"""
```python:{file_path.relative_to(opt.args.project_root)}
class MyClass:
@ -740,7 +743,7 @@ class HelperClass:
```
"""
assert read_write_context.flat.strip() == expected_read_write_context.strip()
assert read_write_context.markdown.strip() == expected_read_write_context.strip()
assert read_only_context.strip() == expected_read_only_context.strip()
assert hashing_context.strip() == expected_hashing_context.strip()
@ -798,7 +801,7 @@ class HelperClass:
hashing_context = code_ctx.hashing_code_context
# In this scenario, the read-only code context is too long, so the read-only docstrings are removed.
expected_read_write_context = f"""
{get_code_block_splitter(file_path.relative_to(opt.args.project_root))}
```python:{file_path.relative_to(opt.args.project_root)}
class MyClass:
def __init__(self):
self.x = 1
@ -812,7 +815,8 @@ class HelperClass:
self.x = 1
def helper_method(self):
return self.x
"""
```
"""
expected_read_only_context = f"""
```python:{file_path.relative_to(opt.args.project_root)}
class MyClass:
@ -836,7 +840,7 @@ class HelperClass:
return self.x
```
"""
assert read_write_context.flat.strip() == expected_read_write_context.strip()
assert read_write_context.markdown.strip() == expected_read_write_context.strip()
assert read_only_context.strip() == expected_read_only_context.strip()
assert hashing_context.strip() == expected_hashing_context.strip()
@ -894,7 +898,7 @@ class HelperClass:
hashing_context = code_ctx.hashing_code_context
# In this scenario, the read-only code context is too long even after removing docstrings, hence we remove it completely.
expected_read_write_context = f"""
{get_code_block_splitter(file_path.relative_to(opt.args.project_root))}
```python:{file_path.relative_to(opt.args.project_root)}
class MyClass:
def __init__(self):
self.x = 1
@ -908,7 +912,8 @@ class HelperClass:
self.x = 1
def helper_method(self):
return self.x
"""
```
"""
expected_read_only_context = ""
expected_hashing_context = f"""
```python:{file_path.relative_to(opt.args.project_root)}
@ -923,7 +928,7 @@ class HelperClass:
return self.x
```
"""
assert read_write_context.flat.strip() == expected_read_write_context.strip()
assert read_write_context.markdown.strip() == expected_read_write_context.strip()
assert read_only_context.strip() == expected_read_only_context.strip()
assert hashing_context.strip() == expected_hashing_context.strip()
@ -1048,7 +1053,7 @@ def test_repo_helper() -> None:
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
hashing_context = code_ctx.hashing_code_context
expected_read_write_context = f"""
{get_code_block_splitter(path_to_utils.relative_to(project_root))}
```python:{path_to_utils.relative_to(project_root)}
import math
class DataProcessor:
@ -1065,8 +1070,8 @@ class DataProcessor:
def add_prefix(self, data: str, prefix: str = "PREFIX_") -> str:
\"\"\"Add a prefix to the processed data.\"\"\"
return prefix + data
{get_code_block_splitter(path_to_file.relative_to(project_root))}
```
```python:{path_to_file.relative_to(project_root)}
import requests
from globals import API_URL
from utils import DataProcessor
@ -1084,6 +1089,7 @@ def fetch_and_process_data():
processed = processor.add_prefix(processed)
return processed
```
"""
expected_read_only_context = f"""
```python:{path_to_utils.relative_to(project_root)}
@ -1118,7 +1124,7 @@ def fetch_and_process_data():
return processed
```
"""
assert read_write_context.flat.strip() == expected_read_write_context.strip()
assert read_write_context.markdown.strip() == expected_read_write_context.strip()
assert read_only_context.strip() == expected_read_only_context.strip()
assert hashing_context.strip() == expected_hashing_context.strip()
@ -1140,7 +1146,7 @@ def test_repo_helper_of_helper() -> None:
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
hashing_context = code_ctx.hashing_code_context
expected_read_write_context = f"""
{get_code_block_splitter(path_to_utils.relative_to(project_root))}
```python:{path_to_utils.relative_to(project_root)}
import math
from transform_utils import DataTransformer
@ -1158,8 +1164,8 @@ class DataProcessor:
def transform_data(self, data: str) -> str:
\"\"\"Transform the processed data\"\"\"
return DataTransformer().transform(data)
{get_code_block_splitter(path_to_file.relative_to(project_root))}
```
```python:{path_to_file.relative_to(project_root)}
import requests
from globals import API_URL
from utils import DataProcessor
@ -1176,8 +1182,8 @@ def fetch_and_transform_data():
transformed = processor.transform_data(processed)
return transformed
"""
```
"""
expected_read_only_context = f"""
```python:{path_to_utils.relative_to(project_root)}
class DataProcessor:
@ -1217,7 +1223,7 @@ def fetch_and_transform_data():
return transformed
```
"""
assert read_write_context.flat.strip() == expected_read_write_context.strip()
assert read_write_context.markdown.strip() == expected_read_write_context.strip()
assert read_only_context.strip() == expected_read_only_context.strip()
assert hashing_context.strip() == expected_hashing_context.strip()
@ -1238,15 +1244,15 @@ def test_repo_helper_of_helper_same_class() -> None:
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
hashing_context = code_ctx.hashing_code_context
expected_read_write_context = f"""
{get_code_block_splitter(path_to_transform_utils.relative_to(project_root))}
```python:{path_to_transform_utils.relative_to(project_root)}
class DataTransformer:
def __init__(self):
self.data = None
def transform_using_own_method(self, data):
return self.transform(data)
{get_code_block_splitter(path_to_utils.relative_to(project_root))}
```
```python:{path_to_utils.relative_to(project_root)}
import math
from transform_utils import DataTransformer
@ -1260,7 +1266,7 @@ class DataProcessor:
def transform_data_own_method(self, data: str) -> str:
\"\"\"Transform the processed data using own method\"\"\"
return DataTransformer().transform_using_own_method(data)
```
"""
expected_read_only_context = f"""
```python:{path_to_transform_utils.relative_to(project_root)}
@ -1297,7 +1303,7 @@ class DataProcessor:
```
"""
assert read_write_context.flat.strip() == expected_read_write_context.strip()
assert read_write_context.markdown.strip() == expected_read_write_context.strip()
assert read_only_context.strip() == expected_read_only_context.strip()
assert hashing_context.strip() == expected_hashing_context.strip()
@ -1318,15 +1324,15 @@ def test_repo_helper_of_helper_same_file() -> None:
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
hashing_context = code_ctx.hashing_code_context
expected_read_write_context = f"""
{get_code_block_splitter(path_to_transform_utils.relative_to(project_root))}
```python:{path_to_transform_utils.relative_to(project_root)}
class DataTransformer:
def __init__(self):
self.data = None
def transform_using_same_file_function(self, data):
return update_data(data)
{get_code_block_splitter(path_to_utils.relative_to(project_root))}
```
```python:{path_to_utils.relative_to(project_root)}
import math
from transform_utils import DataTransformer
@ -1340,7 +1346,8 @@ class DataProcessor:
def transform_data_same_file_function(self, data: str) -> str:
\"\"\"Transform the processed data using a function from the same file\"\"\"
return DataTransformer().transform_using_same_file_function(data)
"""
```
"""
expected_read_only_context = f"""
```python:{path_to_transform_utils.relative_to(project_root)}
def update_data(data):
@ -1372,7 +1379,7 @@ class DataProcessor:
```
"""
assert read_write_context.flat.strip() == expected_read_write_context.strip()
assert read_write_context.markdown.strip() == expected_read_write_context.strip()
assert read_only_context.strip() == expected_read_only_context.strip()
assert hashing_context.strip() == expected_hashing_context.strip()
@ -1392,7 +1399,7 @@ def test_repo_helper_all_same_file() -> None:
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
hashing_context = code_ctx.hashing_code_context
expected_read_write_context = f"""
{get_code_block_splitter(path_to_transform_utils.relative_to(project_root))}
```python:{path_to_transform_utils.relative_to(project_root)}
class DataTransformer:
def __init__(self):
self.data = None
@ -1407,7 +1414,8 @@ class DataTransformer:
def update_data(data):
return data + " updated"
"""
```
"""
expected_read_only_context = f"""
```python:{path_to_transform_utils.relative_to(project_root)}
class DataTransformer:
@ -1434,7 +1442,7 @@ def update_data(data):
```
"""
assert read_write_context.flat.strip() == expected_read_write_context.strip()
assert read_write_context.markdown.strip() == expected_read_write_context.strip()
assert read_only_context.strip() == expected_read_only_context.strip()
assert hashing_context.strip() == expected_hashing_context.strip()
@ -1455,7 +1463,7 @@ def test_repo_helper_circular_dependency() -> None:
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
hashing_context = code_ctx.hashing_code_context
expected_read_write_context = f"""
{get_code_block_splitter(path_to_utils.relative_to(project_root))}
```python:{path_to_utils.relative_to(project_root)}
import math
from transform_utils import DataTransformer
@ -1469,8 +1477,8 @@ class DataProcessor:
def circular_dependency(self, data: str) -> str:
\"\"\"Test circular dependency\"\"\"
return DataTransformer().circular_dependency(data)
{get_code_block_splitter(path_to_transform_utils.relative_to(project_root))}
```
```python:{path_to_transform_utils.relative_to(project_root)}
from code_to_optimize.code_directories.retriever.utils import DataProcessor
class DataTransformer:
@ -1479,9 +1487,8 @@ class DataTransformer:
def circular_dependency(self, data):
return DataProcessor().circular_dependency(data)
"""
```
"""
expected_read_only_context = f"""
```python:{path_to_utils.relative_to(project_root)}
class DataProcessor:
@ -1510,7 +1517,7 @@ class DataTransformer:
```
"""
assert read_write_context.flat.strip() == expected_read_write_context.strip()
assert read_write_context.markdown.strip() == expected_read_write_context.strip()
assert read_only_context.strip() == expected_read_only_context.strip()
assert hashing_context.strip() == expected_hashing_context.strip()
@ -1554,13 +1561,14 @@ def outside_method():
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
hashing_context = code_ctx.hashing_code_context
expected_read_write_context = f"""
{get_code_block_splitter(file_path.relative_to(opt.args.project_root))}
```python:{file_path.relative_to(opt.args.project_root)}
class MyClass:
def __init__(self):
self.x = 1
self.y = outside_method()
def target_method(self):
return self.x + self.y
```
"""
expected_read_only_context = f"""
```python:{file_path.relative_to(opt.args.project_root)}
@ -1576,7 +1584,7 @@ class MyClass:
return self.x + self.y
```
"""
assert read_write_context.flat.strip() == expected_read_write_context.strip()
assert read_write_context.markdown.strip() == expected_read_write_context.strip()
assert read_only_context.strip() == expected_read_only_context.strip()
assert hashing_context.strip() == expected_hashing_context.strip()
@ -1634,7 +1642,7 @@ def function_to_optimize():
```
"""
expected_read_write_context = f"""
{get_code_block_splitter(path_to_main.relative_to(project_root))}
```python:{path_to_main.relative_to(project_root)}
import requests
from globals import API_URL
from utils import DataProcessor
@ -1651,14 +1659,15 @@ def fetch_and_transform_data():
transformed = processor.transform_data(processed)
return transformed
{get_code_block_splitter(path_to_fto.relative_to(project_root))}
```
```python:{path_to_fto.relative_to(project_root)}
import code_to_optimize.code_directories.retriever.main
def function_to_optimize():
return code_to_optimize.code_directories.retriever.main.fetch_and_transform_data()
```
"""
assert read_write_context.flat.strip() == expected_read_write_context.strip()
assert read_write_context.markdown.strip() == expected_read_write_context.strip()
assert read_only_context.strip() == expected_read_only_context.strip()
assert hashing_context.strip() == expected_hashing_context.strip()
@ -1818,7 +1827,7 @@ def get_system_details():
hashing_context = code_ctx.hashing_code_context
# The expected contexts
expected_read_write_context = f"""
{get_code_block_splitter(main_file_path.relative_to(opt.args.project_root))}
```python:{main_file_path.relative_to(opt.args.project_root)}
import utility_module
class Calculator:
@ -1845,6 +1854,7 @@ class Calculator:
return self.subtract(x, y)
else:
return None
```
"""
expected_read_only_context = """
```python:utility_module.py
@ -1902,7 +1912,7 @@ class Calculator:
```
"""
# Verify the contexts match the expected values
assert read_write_context.flat.strip() == expected_read_write_context.strip()
assert read_write_context.markdown.strip() == expected_read_write_context.strip()
assert read_only_context.strip() == expected_read_only_context.strip()
assert hashing_context.strip() == expected_hashing_context.strip()
@ -2061,7 +2071,7 @@ def get_system_details():
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
# The expected contexts
expected_read_write_context = f"""
{get_code_block_splitter("utility_module.py")}
```python:utility_module.py
# Function that will be used in the main code
def select_precision(precision, fallback_precision):
@ -2085,8 +2095,8 @@ def select_precision(precision, fallback_precision):
return precision.lower()
else:
return DEFAULT_PRECISION
{get_code_block_splitter(main_file_path.relative_to(opt.args.project_root))}
```
```python:{main_file_path.relative_to(opt.args.project_root)}
import utility_module
class Calculator:
@ -2099,6 +2109,7 @@ class Calculator:
self.backend = utility_module.CALCULATION_BACKEND
self.system = utility_module.SYSTEM_TYPE
self.default_precision = utility_module.DEFAULT_PRECISION
```
"""
expected_read_only_context = """
```python:utility_module.py
@ -2113,7 +2124,7 @@ except ImportError:
CALCULATION_BACKEND = "python"
```
"""
assert read_write_context.flat.strip() == expected_read_write_context.strip()
assert read_write_context.markdown.strip() == expected_read_write_context.strip()
assert read_only_context.strip() == expected_read_only_context.strip()

View file

@ -13,7 +13,7 @@ from codeflash.code_utils.code_replacer import (
replace_functions_in_file,
)
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import CodeOptimizationContext, CodeStringsMarkdown, FunctionParent, get_code_block_splitter
from codeflash.models.models import CodeOptimizationContext, CodeStringsMarkdown, FunctionParent
from codeflash.optimization.function_optimizer import FunctionOptimizer
from codeflash.verification.verification_utils import TestConfig
@ -43,12 +43,14 @@ class Args:
def test_code_replacement_global_statements():
project_root = Path(__file__).parent.parent.resolve()
code_path = (project_root / "code_to_optimize/bubble_sort_optimized.py").resolve()
optimized_code = f"""{get_code_block_splitter(code_path.relative_to(project_root))}
optimized_code = f"""```python:{code_path.relative_to(project_root)}
import numpy as np
inconsequential_var = '123'
def sorter(arr):
return arr.sort()"""
return arr.sort()
```
"""
original_code_str = (Path(__file__).parent.resolve() / "../code_to_optimize/bubble_sort.py").read_text(
encoding="utf-8"
)
@ -1684,7 +1686,7 @@ class NewClass:
def new_function2(value):
return cst.ensure_type(value, str)
"""
optimized_code = f"""{get_code_block_splitter(code_path.relative_to(root_dir))}
optimized_code = f"""```python:{code_path.relative_to(root_dir)}
import numpy as np
def some_fn():
@ -1699,7 +1701,8 @@ class NewClass:
return cst.ensure_type(value, str)
a=2
print("Hello world")
"""
```
"""
expected_code = """import numpy as np
print("Hello world")
@ -1760,7 +1763,7 @@ class NewClass:
return cst.ensure_type(value, str)
a=1
"""
optimized_code = f"""{get_code_block_splitter(code_path.relative_to(root_dir))}
optimized_code = f"""```python:{code_path.relative_to(root_dir)}
a=2
import numpy as np
def some_fn():
@ -1774,7 +1777,8 @@ class NewClass:
def new_function2(value):
return cst.ensure_type(value, str)
print("Hello world")
"""
```
"""
expected_code = """import numpy as np
print("Hello world")
@ -1837,7 +1841,7 @@ class NewClass:
def new_function2(value):
return cst.ensure_type(value, str)
"""
optimized_code = f"""{get_code_block_splitter(code_path.relative_to(root_dir))}
optimized_code = f"""```python:{code_path.relative_to(root_dir)}
import numpy as np
a=2
def some_fn():
@ -1852,7 +1856,8 @@ class NewClass:
return cst.ensure_type(value, str)
a=3
print("Hello world")
"""
```
"""
expected_code = """import numpy as np
print("Hello world")
@ -1915,7 +1920,7 @@ class NewClass:
def new_function2(value):
return cst.ensure_type(value, str)
"""
optimized_code = f"""{get_code_block_splitter(code_path.relative_to(root_dir))}
optimized_code = f"""```python:{code_path.relative_to(root_dir)}
a=2
import numpy as np
def some_fn():
@ -1929,7 +1934,8 @@ class NewClass:
def new_function2(value):
return cst.ensure_type(value, str)
print("Hello world")
"""
```
"""
expected_code = """import numpy as np
print("Hello world")
@ -1992,7 +1998,7 @@ class NewClass:
def new_function2(value):
return cst.ensure_type(value, str)
"""
optimized_code = f"""{get_code_block_splitter(code_path.relative_to(root_dir))}
optimized_code = f"""```python:{code_path.relative_to(root_dir)}
import numpy as np
a=2
def some_fn():
@ -2007,7 +2013,8 @@ class NewClass:
return cst.ensure_type(value, str)
a=3
print("Hello world")
"""
```
"""
expected_code = """import numpy as np
print("Hello world")
@ -2073,7 +2080,7 @@ class NewClass:
def new_function2(value):
return cst.ensure_type(value, str)
"""
optimized_code = f"""{get_code_block_splitter(code_path.relative_to(root_dir))}
optimized_code = f"""```python:{code_path.relative_to(root_dir)}
import numpy as np
if 1<2:
a=2
@ -2091,6 +2098,7 @@ class NewClass:
def new_function2(value):
return cst.ensure_type(value, str)
print("Hello world")
```
"""
expected_code = """import numpy as np

View file

@ -1,6 +1,6 @@
from pathlib import Path
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import CodeOptimizationContext, CodeStringsMarkdown, get_code_block_splitter
from codeflash.models.models import CodeOptimizationContext, CodeStringsMarkdown
from codeflash.optimization.function_optimizer import FunctionOptimizer
from codeflash.verification.verification_utils import TestConfig
@ -50,7 +50,7 @@ def _get_string_usage(text: str) -> Usage:
"""
main_file.write_text(original_main, encoding="utf-8")
optimized_code = f"""{get_code_block_splitter(helper_file.relative_to(root_dir))}
optimized_code = f"""```python:{helper_file.relative_to(root_dir)}
import re
from collections.abc import Sequence
@ -83,14 +83,15 @@ def _estimate_string_tokens(content: str | Sequence[UserContent]) -> int:
tokens += len(part.data)
return tokens
{get_code_block_splitter(main_file.relative_to(root_dir))}
```
```python:{main_file.relative_to(root_dir)}
from temp_helper import _estimate_string_tokens
from pydantic_ai_slim.pydantic_ai.usage import Usage
def _get_string_usage(text: str) -> Usage:
response_tokens = _estimate_string_tokens(text)
return Usage(response_tokens=response_tokens, total_tokens=response_tokens)
```
"""

View file

@ -6,7 +6,7 @@ from pathlib import Path
import pytest
from codeflash.context.unused_definition_remover import detect_unused_helper_functions
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import CodeStringsMarkdown, get_code_block_splitter
from codeflash.models.models import CodeStringsMarkdown
from codeflash.optimization.function_optimizer import FunctionOptimizer
from codeflash.verification.verification_utils import TestConfig
@ -56,8 +56,8 @@ def test_detect_unused_helper_functions(temp_project):
temp_dir, main_file, test_cfg = temp_project
# Optimized version that only calls one helper
optimized_code = f"""
{get_code_block_splitter("main.py")}
optimized_code = """
```python:main.py
def entrypoint_function(n):
\"\"\"Optimized function that only calls one helper.\"\"\"
result1 = helper_function_1(n)
@ -70,6 +70,7 @@ def helper_function_1(x):
def helper_function_2(x):
\"\"\"Second helper function - MODIFIED VERSION should be reverted.\"\"\"
return x * 4 # This change should be reverted to original x * 3
```
"""
# Create FunctionToOptimize instance
@ -91,7 +92,7 @@ def helper_function_2(x):
code_context = ctx_result.unwrap()
# Test unused helper detection
unused_helpers = detect_unused_helper_functions(optimizer.function_to_optimize, code_context, optimized_code)
unused_helpers = detect_unused_helper_functions(optimizer.function_to_optimize, code_context, CodeStringsMarkdown.parse_markdown_code(optimized_code).flat)
# Should detect helper_function_2 as unused
unused_names = {uh.qualified_name for uh in unused_helpers}
@ -101,8 +102,8 @@ def helper_function_2(x):
# Also test the complete replace_function_and_helpers_with_optimized_code workflow
# First modify the optimized code to include a MODIFIED unused helper
optimized_code_with_modified_helper = f"""
{get_code_block_splitter("main.py")}
optimized_code_with_modified_helper = """
```python:main.py
def entrypoint_function(n):
\"\"\"Optimized function that only calls one helper.\"\"\"
result1 = helper_function_1(n)
@ -115,6 +116,7 @@ def helper_function_1(x):
def helper_function_2(x):
\"\"\"Second helper function - MODIFIED VERSION should be reverted.\"\"\"
return x * 7 # This should be reverted to x * 3
```
"""
original_helper_code = {main_file: main_file.read_text()}
@ -161,8 +163,8 @@ def test_revert_unused_helper_functions(temp_project):
temp_dir, main_file, test_cfg = temp_project
# Optimized version that only calls one helper and modifies the unused one
optimized_code = f"""
{get_code_block_splitter("main.py") }
optimized_code = """
```python:main.py
def entrypoint_function(n):
\"\"\"Optimized function that only calls one helper.\"\"\"
result1 = helper_function_1(n)
@ -175,6 +177,7 @@ def helper_function_1(x):
def helper_function_2(x):
\"\"\"Modified helper function - should be reverted.\"\"\"
return x * 4 # This change should be reverted
```
"""
# Create FunctionToOptimize instance
@ -224,8 +227,8 @@ def test_no_unused_helpers_no_revert(temp_project):
temp_dir, main_file, test_cfg = temp_project
# Optimized version that still calls both helpers
optimized_code = f"""
{get_code_block_splitter("main.py")}
optimized_code = """
```python:main.py
def entrypoint_function(n):
\"\"\"Optimized function that still calls both helpers.\"\"\"
result1 = helper_function_1(n)
@ -239,6 +242,7 @@ def helper_function_1(x):
def helper_function_2(x):
\"\"\"Second helper function - optimized.\"\"\"
return x * 3
```
"""
# Create FunctionToOptimize instance
@ -263,7 +267,7 @@ def helper_function_2(x):
original_helper_code = {main_file: main_file.read_text()}
# Test detection - should find no unused helpers
unused_helpers = detect_unused_helper_functions(optimizer.function_to_optimize, code_context, optimized_code)
unused_helpers = detect_unused_helper_functions(optimizer.function_to_optimize, code_context, CodeStringsMarkdown.parse_markdown_code(optimized_code).flat)
assert len(unused_helpers) == 0, "No helpers should be detected as unused"
# Apply optimization
@ -307,14 +311,15 @@ def helper_function_2(x):
""")
# Optimized version that only calls one helper
optimized_code = f"""
{get_code_block_splitter("main.py")}
optimized_code = """
```python:main.py
from helpers import helper_function_1
def entrypoint_function(n):
\"\"\"Optimized function that only calls one helper.\"\"\"
result1 = helper_function_1(n)
return result1 + n * 3 # Inlined helper_function_2
```
"""
# Create test config
@ -345,7 +350,7 @@ def entrypoint_function(n):
code_context = ctx_result.unwrap()
# Test unused helper detection
unused_helpers = detect_unused_helper_functions(optimizer.function_to_optimize, code_context, optimized_code)
unused_helpers = detect_unused_helper_functions(optimizer.function_to_optimize, code_context, CodeStringsMarkdown.parse_markdown_code(optimized_code).flat)
# Should detect helper_function_2 as unused
unused_names = {uh.qualified_name for uh in unused_helpers}
@ -482,8 +487,8 @@ class Calculator:
""")
# Optimized version that only calls one helper method
optimized_code = f"""
{get_code_block_splitter("main.py") }
optimized_code = """
```python:main.py
class Calculator:
def entrypoint_method(self, n):
\"\"\"Optimized method that only calls one helper.\"\"\"
@ -497,6 +502,7 @@ class Calculator:
def helper_method_2(self, x):
\"\"\"Second helper method - should be reverted.\"\"\"
return x * 4
```
"""
# Create test config
@ -532,7 +538,7 @@ class Calculator:
code_context = ctx_result.unwrap()
# Test unused helper detection
unused_helpers = detect_unused_helper_functions(optimizer.function_to_optimize, code_context, optimized_code)
unused_helpers = detect_unused_helper_functions(optimizer.function_to_optimize, code_context, CodeStringsMarkdown.parse_markdown_code(optimized_code).flat)
# Should detect Calculator.helper_method_2 as unused
unused_names = {uh.qualified_name for uh in unused_helpers}
@ -542,8 +548,8 @@ class Calculator:
# Also test the complete replace_function_and_helpers_with_optimized_code workflow
# Update optimized code to include a MODIFIED unused helper
optimized_code_with_modified_helper = f"""
{get_code_block_splitter("main.py")}
optimized_code_with_modified_helper = """
```python:main.py
class Calculator:
def entrypoint_method(self, n):
\"\"\"Optimized method that only calls one helper.\"\"\"
@ -557,6 +563,7 @@ class Calculator:
def helper_method_2(self, x):
\"\"\"Second helper method - MODIFIED VERSION should be reverted.\"\"\"
return x * 8 # This should be reverted to x * 3
```
"""
original_helper_code = {main_file: main_file.read_text()}
@ -625,8 +632,8 @@ class Processor:
""")
# Optimized version that only calls one external helper
optimized_code = f"""
{get_code_block_splitter("main.py") }
optimized_code = """
```python:main.py
def external_helper_1(x):
\"\"\"External helper function.\"\"\"
return x * 2
@ -640,6 +647,7 @@ class Processor:
\"\"\"Optimized method that only calls one helper.\"\"\"
result1 = external_helper_1(n)
return result1 + n * 3 # Inlined external_helper_2
```
"""
# Create test config
@ -675,7 +683,7 @@ class Processor:
code_context = ctx_result.unwrap()
# Test unused helper detection
unused_helpers = detect_unused_helper_functions(optimizer.function_to_optimize, code_context, optimized_code)
unused_helpers = detect_unused_helper_functions(optimizer.function_to_optimize, code_context, CodeStringsMarkdown.parse_markdown_code(optimized_code).flat)
# Should detect external_helper_2 as unused
unused_names = {uh.qualified_name for uh in unused_helpers}
@ -685,8 +693,8 @@ class Processor:
# Also test the complete replace_function_and_helpers_with_optimized_code workflow
# Update optimized code to include a MODIFIED unused helper
optimized_code_with_modified_helper = f"""
{get_code_block_splitter("main.py")}
optimized_code_with_modified_helper = """
```python:main.py
def external_helper_1(x):
\"\"\"External helper function.\"\"\"
return x * 2
@ -700,6 +708,7 @@ class Processor:
\"\"\"Optimized method that only calls one helper.\"\"\"
result1 = external_helper_1(n)
return result1 + n * 3 # Inlined external_helper_2
```
"""
original_helper_code = {main_file: main_file.read_text()}
@ -724,8 +733,8 @@ class Processor:
# Also test the complete replace_function_and_helpers_with_optimized_code workflow
# Update optimized code to include a MODIFIED unused helper
optimized_code_with_modified_helper = f"""
{get_code_block_splitter("main.py")}
optimized_code_with_modified_helper = """
```python:main.py
def external_helper_1(x):
\"\"\"External helper function.\"\"\"
return x * 2
@ -739,6 +748,7 @@ class Processor:
\"\"\"Optimized method that only calls one helper.\"\"\"
result1 = external_helper_1(n)
return result1 + n * 3 # Inlined external_helper_2
```
"""
original_helper_code = {main_file: main_file.read_text()}
@ -795,8 +805,8 @@ class OuterClass:
""")
# Optimized version that inlines one helper
optimized_code = f"""
{get_code_block_splitter("main.py") }
optimized_code = """
```python:main.py
def global_helper_1(x):
return x * 2
@ -812,6 +822,7 @@ class OuterClass:
def local_helper(self, x):
return x + 1
```
"""
# Create test config
@ -878,7 +889,7 @@ class OuterClass:
]
},
)(),
optimized_code,
CodeStringsMarkdown.parse_markdown_code(optimized_code).flat,
)
# Should detect global_helper_2 as unused
@ -964,8 +975,8 @@ def clean_data(x):
""")
# Optimized version that only uses some functions
optimized_code = f"""
{get_code_block_splitter("main.py") }
optimized_code = """
```python:main.py
import utils
from math_helpers import add
@ -976,6 +987,7 @@ def entrypoint_function(n):
# Inlined multiply: result3 = n * 2
# Inlined process_data: result4 = n ** 2
return result1 + result2 + (n * 2) + (n ** 2)
```
"""
# Create test config
@ -1006,7 +1018,7 @@ def entrypoint_function(n):
code_context = ctx_result.unwrap()
# Test unused helper detection
unused_helpers = detect_unused_helper_functions(optimizer.function_to_optimize, code_context, optimized_code)
unused_helpers = detect_unused_helper_functions(optimizer.function_to_optimize, code_context, CodeStringsMarkdown.parse_markdown_code(optimized_code).flat)
# Should detect multiply, process_data as unused (at minimum)
unused_names = {uh.qualified_name for uh in unused_helpers}
@ -1126,8 +1138,8 @@ def divide_numbers(x, y):
""")
# Optimized version that only uses add_numbers
optimized_code = f"""
{get_code_block_splitter("main.py") }
optimized_code = """
```python:main.py
import calculator
def entrypoint_function(n):
@ -1135,6 +1147,7 @@ def entrypoint_function(n):
result1 = calculator.add_numbers(n, 10)
# Inlined: result2 = n * 5
return result1 + (n * 5)
```
"""
# Create test config
@ -1165,7 +1178,7 @@ def entrypoint_function(n):
code_context = ctx_result.unwrap()
# Test unused helper detection
unused_helpers = detect_unused_helper_functions(optimizer.function_to_optimize, code_context, optimized_code)
unused_helpers = detect_unused_helper_functions(optimizer.function_to_optimize, code_context, CodeStringsMarkdown.parse_markdown_code(optimized_code).flat)
# Should detect multiply_numbers and divide_numbers as unused
unused_names = {uh.qualified_name for uh in unused_helpers}
@ -1329,8 +1342,8 @@ class MathUtils:
""")
# Optimized static method that inlines one utility
optimized_static_code = f"""
{get_code_block_splitter("main.py")}
optimized_static_code = """
```python:main.py
def utility_function_1(x):
return x * 2
@ -1350,6 +1363,7 @@ class MathUtils:
result1 = utility_function_1(n)
result2 = utility_function_2(n)
return result1 - result2
```
"""
# Create test config
@ -1386,7 +1400,7 @@ class MathUtils:
# Test unused helper detection for static method
unused_helpers = detect_unused_helper_functions(
optimizer.function_to_optimize, code_context, optimized_static_code
optimizer.function_to_optimize, code_context, CodeStringsMarkdown.parse_markdown_code(optimized_static_code).flat
)
# Should detect utility_function_2 as unused
@ -1397,8 +1411,8 @@ class MathUtils:
# Also test the complete replace_function_and_helpers_with_optimized_code workflow
# Update optimized code to include a MODIFIED unused helper
optimized_static_code_with_modified_helper = f"""
{get_code_block_splitter("main.py")}
optimized_static_code_with_modified_helper = """
```python:main.py
def utility_function_1(x):
return x * 2
@ -1418,6 +1432,7 @@ class MathUtils:
result1 = utility_function_1(n)
result2 = utility_function_2(n)
return result1 - result2
```
"""
original_helper_code = {main_file: main_file.read_text()}