mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
fix code replacement tests
This commit is contained in:
parent
99cd9dc706
commit
330bf91e73
3 changed files with 35 additions and 16 deletions
|
|
@ -169,8 +169,11 @@ class CodeStringsMarkdown(BaseModel):
|
|||
]
|
||||
)
|
||||
|
||||
def path_to_code_string(self) -> dict[str, str]:
|
||||
return {code_string.file_path: code_string.code for code_string in self.code_strings}
|
||||
|
||||
@staticmethod
|
||||
def from_str_with_markers(code_with_markers: str) -> list[CodeString]:
|
||||
def from_str_with_markers(code_with_markers: str) -> CodeStringsMarkdown:
|
||||
pattern = rf"{SPLITTER_MARKER}([^\n]+)\n"
|
||||
matches = list(re.finditer(pattern, code_with_markers))
|
||||
|
||||
|
|
@ -181,7 +184,7 @@ class CodeStringsMarkdown(BaseModel):
|
|||
file_path = match.group(1).strip()
|
||||
code = code_with_markers[start:end].lstrip("\n")
|
||||
results.append(CodeString(file_path=file_path, code=code))
|
||||
return results
|
||||
return CodeStringsMarkdown(code_strings=results)
|
||||
|
||||
|
||||
class CodeOptimizationContext(BaseModel):
|
||||
|
|
|
|||
|
|
@ -621,18 +621,22 @@ class FunctionOptimizer:
|
|||
read_writable_functions_by_file_path[self.function_to_optimize.file_path].add(
|
||||
self.function_to_optimize.qualified_name
|
||||
)
|
||||
code_strings = CodeStringsMarkdown.from_str_with_markers(optimized_code)
|
||||
optimized_code_dict = {code_string.file_path: code_string.code for code_string in code_strings}
|
||||
logger.debug(f"Optimized code: {optimized_code_dict}")
|
||||
file_to_code_context = CodeStringsMarkdown.from_str_with_markers(optimized_code).path_to_code_string()
|
||||
logger.debug(f"Optimized code: {file_to_code_context}")
|
||||
for helper_function in code_context.helper_functions:
|
||||
if helper_function.jedi_definition.type != "class":
|
||||
read_writable_functions_by_file_path[helper_function.file_path].add(helper_function.qualified_name)
|
||||
for module_abspath, qualified_names in read_writable_functions_by_file_path.items():
|
||||
relative_module_path = module_abspath.relative_to(self.project_root)
|
||||
logger.debug(f"applying optimized code to: {relative_module_path}")
|
||||
|
||||
optimized_code = file_to_code_context.get(relative_module_path)
|
||||
if not optimized_code:
|
||||
msg = f"Optimized code not found for {relative_module_path}, existing files in the context are: {list(file_to_code_context.keys())}, re-check your 'split markers'"
|
||||
raise ValueError(msg)
|
||||
did_update |= replace_function_definitions_in_module(
|
||||
function_names=list(qualified_names),
|
||||
optimized_code=optimized_code_dict.get(relative_module_path),
|
||||
optimized_code=optimized_code,
|
||||
module_abspath=module_abspath,
|
||||
preexisting_objects=code_context.preexisting_objects,
|
||||
project_root_path=self.project_root,
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ from codeflash.code_utils.code_replacer import (
|
|||
replace_functions_in_file,
|
||||
)
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.models.models import CodeOptimizationContext, FunctionParent
|
||||
from codeflash.models.models import CodeOptimizationContext, FunctionParent, get_code_block_splitter
|
||||
from codeflash.optimization.function_optimizer import FunctionOptimizer
|
||||
from codeflash.verification.verification_utils import TestConfig
|
||||
|
||||
|
|
@ -41,11 +41,14 @@ class Args:
|
|||
|
||||
|
||||
def test_code_replacement_global_statements():
|
||||
optimized_code = """import numpy as np
|
||||
project_root = Path(__file__).parent.parent.resolve()
|
||||
code_path = (project_root / "code_to_optimize/bubble_sort_optimized.py").resolve()
|
||||
optimized_code = f"""{get_code_block_splitter(code_path.relative_to(project_root))}
|
||||
import numpy as np
|
||||
|
||||
inconsequential_var = '123'
|
||||
def sorter(arr):
|
||||
return arr.sort()"""
|
||||
code_path = (Path(__file__).parent.resolve() / "../code_to_optimize/bubble_sort_optimized.py").resolve()
|
||||
original_code_str = (Path(__file__).parent.resolve() / "../code_to_optimize/bubble_sort.py").read_text(
|
||||
encoding="utf-8"
|
||||
)
|
||||
|
|
@ -1666,6 +1669,9 @@ print("Hello world")
|
|||
|
||||
|
||||
def test_global_reassignment() -> None:
|
||||
root_dir = Path(__file__).parent.parent.resolve()
|
||||
code_path = (root_dir / "code_to_optimize/global_var_original.py").resolve()
|
||||
|
||||
original_code = """a=1
|
||||
print("Hello world")
|
||||
def some_fn():
|
||||
|
|
@ -1678,7 +1684,9 @@ class NewClass:
|
|||
def new_function2(value):
|
||||
return cst.ensure_type(value, str)
|
||||
"""
|
||||
optimized_code = """import numpy as np
|
||||
optimized_code = f"""{get_code_block_splitter(code_path.relative_to(root_dir))}
|
||||
import numpy as np
|
||||
|
||||
def some_fn():
|
||||
a=np.zeros(10)
|
||||
print("did something")
|
||||
|
|
@ -1713,7 +1721,6 @@ class NewClass:
|
|||
return "I am still old"
|
||||
def new_function2(value):
|
||||
return cst.ensure_type(value, str)"""
|
||||
code_path = (Path(__file__).parent.resolve() / "../code_to_optimize/global_var_original.py").resolve()
|
||||
code_path.write_text(original_code, encoding="utf-8")
|
||||
tests_root = Path("/Users/codeflash/Downloads/codeflash-dev/codeflash/code_to_optimize/tests/pytest/")
|
||||
project_root_path = (Path(__file__).parent / "..").resolve()
|
||||
|
|
@ -1753,7 +1760,8 @@ class NewClass:
|
|||
return cst.ensure_type(value, str)
|
||||
a=1
|
||||
"""
|
||||
optimized_code = """a=2
|
||||
optimized_code = f"""{get_code_block_splitter(code_path.relative_to(root_dir))}
|
||||
a=2
|
||||
import numpy as np
|
||||
def some_fn():
|
||||
a=np.zeros(10)
|
||||
|
|
@ -1829,7 +1837,8 @@ class NewClass:
|
|||
def new_function2(value):
|
||||
return cst.ensure_type(value, str)
|
||||
"""
|
||||
optimized_code = """import numpy as np
|
||||
optimized_code = f"""{get_code_block_splitter(code_path.relative_to(root_dir))}
|
||||
import numpy as np
|
||||
a=2
|
||||
def some_fn():
|
||||
a=np.zeros(10)
|
||||
|
|
@ -1906,7 +1915,8 @@ class NewClass:
|
|||
def new_function2(value):
|
||||
return cst.ensure_type(value, str)
|
||||
"""
|
||||
optimized_code = """a=2
|
||||
optimized_code = f"""{get_code_block_splitter(code_path.relative_to(root_dir))}
|
||||
a=2
|
||||
import numpy as np
|
||||
def some_fn():
|
||||
a=np.zeros(10)
|
||||
|
|
@ -1982,7 +1992,8 @@ class NewClass:
|
|||
def new_function2(value):
|
||||
return cst.ensure_type(value, str)
|
||||
"""
|
||||
optimized_code = """import numpy as np
|
||||
optimized_code = f"""{get_code_block_splitter(code_path.relative_to(root_dir))}
|
||||
import numpy as np
|
||||
a=2
|
||||
def some_fn():
|
||||
a=np.zeros(10)
|
||||
|
|
@ -2062,7 +2073,8 @@ class NewClass:
|
|||
def new_function2(value):
|
||||
return cst.ensure_type(value, str)
|
||||
"""
|
||||
optimized_code = """import numpy as np
|
||||
optimized_code = f"""{get_code_block_splitter(code_path.relative_to(root_dir))}
|
||||
import numpy as np
|
||||
if 1<2:
|
||||
a=2
|
||||
else:
|
||||
|
|
|
|||
Loading…
Reference in a new issue