fix code replacement tests

This commit is contained in:
mohammed 2025-07-25 15:39:47 +03:00
parent 99cd9dc706
commit 330bf91e73
No known key found for this signature in database
GPG key ID: 44F9B42770617B9B
3 changed files with 35 additions and 16 deletions

View file

@ -169,8 +169,11 @@ class CodeStringsMarkdown(BaseModel):
]
)
def path_to_code_string(self) -> dict[str, str]:
return {code_string.file_path: code_string.code for code_string in self.code_strings}
@staticmethod
def from_str_with_markers(code_with_markers: str) -> list[CodeString]:
def from_str_with_markers(code_with_markers: str) -> CodeStringsMarkdown:
pattern = rf"{SPLITTER_MARKER}([^\n]+)\n"
matches = list(re.finditer(pattern, code_with_markers))
@ -181,7 +184,7 @@ class CodeStringsMarkdown(BaseModel):
file_path = match.group(1).strip()
code = code_with_markers[start:end].lstrip("\n")
results.append(CodeString(file_path=file_path, code=code))
return results
return CodeStringsMarkdown(code_strings=results)
class CodeOptimizationContext(BaseModel):

View file

@ -621,18 +621,22 @@ class FunctionOptimizer:
read_writable_functions_by_file_path[self.function_to_optimize.file_path].add(
self.function_to_optimize.qualified_name
)
code_strings = CodeStringsMarkdown.from_str_with_markers(optimized_code)
optimized_code_dict = {code_string.file_path: code_string.code for code_string in code_strings}
logger.debug(f"Optimized code: {optimized_code_dict}")
file_to_code_context = CodeStringsMarkdown.from_str_with_markers(optimized_code).path_to_code_string()
logger.debug(f"Optimized code: {file_to_code_context}")
for helper_function in code_context.helper_functions:
if helper_function.jedi_definition.type != "class":
read_writable_functions_by_file_path[helper_function.file_path].add(helper_function.qualified_name)
for module_abspath, qualified_names in read_writable_functions_by_file_path.items():
relative_module_path = module_abspath.relative_to(self.project_root)
logger.debug(f"applying optimized code to: {relative_module_path}")
optimized_code = file_to_code_context.get(relative_module_path)
if not optimized_code:
msg = f"Optimized code not found for {relative_module_path}, existing files in the context are: {list(file_to_code_context.keys())}, re-check your 'split markers'"
raise ValueError(msg)
did_update |= replace_function_definitions_in_module(
function_names=list(qualified_names),
optimized_code=optimized_code_dict.get(relative_module_path),
optimized_code=optimized_code,
module_abspath=module_abspath,
preexisting_objects=code_context.preexisting_objects,
project_root_path=self.project_root,

View file

@ -13,7 +13,7 @@ from codeflash.code_utils.code_replacer import (
replace_functions_in_file,
)
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import CodeOptimizationContext, FunctionParent
from codeflash.models.models import CodeOptimizationContext, FunctionParent, get_code_block_splitter
from codeflash.optimization.function_optimizer import FunctionOptimizer
from codeflash.verification.verification_utils import TestConfig
@ -41,11 +41,14 @@ class Args:
def test_code_replacement_global_statements():
optimized_code = """import numpy as np
project_root = Path(__file__).parent.parent.resolve()
code_path = (project_root / "code_to_optimize/bubble_sort_optimized.py").resolve()
optimized_code = f"""{get_code_block_splitter(code_path.relative_to(project_root))}
import numpy as np
inconsequential_var = '123'
def sorter(arr):
return arr.sort()"""
code_path = (Path(__file__).parent.resolve() / "../code_to_optimize/bubble_sort_optimized.py").resolve()
original_code_str = (Path(__file__).parent.resolve() / "../code_to_optimize/bubble_sort.py").read_text(
encoding="utf-8"
)
@ -1666,6 +1669,9 @@ print("Hello world")
def test_global_reassignment() -> None:
root_dir = Path(__file__).parent.parent.resolve()
code_path = (root_dir / "code_to_optimize/global_var_original.py").resolve()
original_code = """a=1
print("Hello world")
def some_fn():
@ -1678,7 +1684,9 @@ class NewClass:
def new_function2(value):
return cst.ensure_type(value, str)
"""
optimized_code = """import numpy as np
optimized_code = f"""{get_code_block_splitter(code_path.relative_to(root_dir))}
import numpy as np
def some_fn():
a=np.zeros(10)
print("did something")
@ -1713,7 +1721,6 @@ class NewClass:
return "I am still old"
def new_function2(value):
return cst.ensure_type(value, str)"""
code_path = (Path(__file__).parent.resolve() / "../code_to_optimize/global_var_original.py").resolve()
code_path.write_text(original_code, encoding="utf-8")
tests_root = Path("/Users/codeflash/Downloads/codeflash-dev/codeflash/code_to_optimize/tests/pytest/")
project_root_path = (Path(__file__).parent / "..").resolve()
@ -1753,7 +1760,8 @@ class NewClass:
return cst.ensure_type(value, str)
a=1
"""
optimized_code = """a=2
optimized_code = f"""{get_code_block_splitter(code_path.relative_to(root_dir))}
a=2
import numpy as np
def some_fn():
a=np.zeros(10)
@ -1829,7 +1837,8 @@ class NewClass:
def new_function2(value):
return cst.ensure_type(value, str)
"""
optimized_code = """import numpy as np
optimized_code = f"""{get_code_block_splitter(code_path.relative_to(root_dir))}
import numpy as np
a=2
def some_fn():
a=np.zeros(10)
@ -1906,7 +1915,8 @@ class NewClass:
def new_function2(value):
return cst.ensure_type(value, str)
"""
optimized_code = """a=2
optimized_code = f"""{get_code_block_splitter(code_path.relative_to(root_dir))}
a=2
import numpy as np
def some_fn():
a=np.zeros(10)
@ -1982,7 +1992,8 @@ class NewClass:
def new_function2(value):
return cst.ensure_type(value, str)
"""
optimized_code = """import numpy as np
optimized_code = f"""{get_code_block_splitter(code_path.relative_to(root_dir))}
import numpy as np
a=2
def some_fn():
a=np.zeros(10)
@ -2062,7 +2073,8 @@ class NewClass:
def new_function2(value):
return cst.ensure_type(value, str)
"""
optimized_code = """import numpy as np
optimized_code = f"""{get_code_block_splitter(code_path.relative_to(root_dir))}
import numpy as np
if 1<2:
a=2
else: