diff --git a/codeflash/code_utils/code_replacer.py b/codeflash/code_utils/code_replacer.py index 25cac0ab1..a8620819f 100644 --- a/codeflash/code_utils/code_replacer.py +++ b/codeflash/code_utils/code_replacer.py @@ -26,7 +26,6 @@ class OptimFunctionCollector(cst.CSTVisitor): self.optim_body: FunctionDef | None = None self.optim_new_class_functions: list[cst.FunctionDef] = [] self.optim_new_functions: list[cst.FunctionDef] = [] - self.optim_imports: list[cst.SimpleStatementLine] = [] self.preexisting_functions = preexisting_functions self.contextual_functions = contextual_functions.union( {(self.class_name, self.function_name)}, @@ -64,10 +63,6 @@ class OptimFunctionCollector(cst.CSTVisitor): ): self.optim_new_class_functions.append(child_node) - def leave_SimpleStatementLine(self, original_node: cst.SimpleStatementLine) -> None: - if isinstance(original_node.body[0], (cst.Import, cst.ImportFrom)): - self.optim_imports.append(original_node) - class OptimFunctionReplacer(cst.CSTTransformer): def __init__( @@ -75,7 +70,6 @@ class OptimFunctionReplacer(cst.CSTTransformer): function_name: str, optim_body: cst.FunctionDef, optim_new_class_functions: list[cst.FunctionDef], - optim_imports: list[cst.SimpleStatementLine], optim_new_functions: list[cst.FunctionDef], class_name: str | None = None, ) -> None: @@ -83,7 +77,6 @@ class OptimFunctionReplacer(cst.CSTTransformer): self.function_name = function_name self.optim_body = optim_body self.optim_new_class_functions = optim_new_class_functions - self.optim_new_imports = optim_imports self.optim_new_functions = optim_new_functions self.class_name = class_name self.depth: int = 0 @@ -126,10 +119,7 @@ class OptimFunctionReplacer(cst.CSTTransformer): return updated_node def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: - if len(self.optim_new_imports) == 0: - node = updated_node - else: - node = updated_node.with_changes(body=(*self.optim_new_imports, *updated_node.body)) + node = updated_node max_function_index = None class_index = None for index, _node in enumerate(node.body): @@ -190,13 +180,11 @@ def replace_functions_in_file( continue if visitor.optim_body is None: raise ValueError(f"Did not find the function {function_name} in the optimized code") - optim_imports: list[cst.SimpleStatementLine] = [] if i > 0 else visitor.optim_imports transformer = OptimFunctionReplacer( visitor.function_name, visitor.optim_body, visitor.optim_new_class_functions, - optim_imports, visitor.optim_new_functions, class_name=class_name, ) @@ -207,6 +195,31 @@ def replace_functions_in_file( return source_code +def replace_functions_and_add_imports( + source_code: str, + function_names: list[str], + optimized_code: str, + file_path_of_module_with_function_to_optimize: str, + module_abspath: str, + preexisting_functions: list[str], + contextual_functions: set[tuple[str, str]], + project_root_path: str, +) -> str: + return add_needed_imports_from_module( + optimized_code, + replace_functions_in_file( + source_code, + function_names, + optimized_code, + preexisting_functions, + contextual_functions, + ), + file_path_of_module_with_function_to_optimize, + module_abspath, + project_root_path, + ) + + def replace_function_definitions_in_module( function_names: list[str], optimized_code: str, @@ -219,17 +232,14 @@ def replace_function_definitions_in_module( file: IO[str] with open(module_abspath, encoding="utf8") as file: source_code: str = file.read() - new_code: str = add_needed_imports_from_module( + new_code: str = replace_functions_and_add_imports( + source_code, + function_names, optimized_code, - replace_functions_in_file( - source_code, - function_names, - optimized_code, - preexisting_functions, - contextual_functions, - ), file_path_of_module_with_function_to_optimize, module_abspath, + preexisting_functions, + contextual_functions, project_root_path, ) with open(module_abspath, "w", encoding="utf8") as file: diff --git a/codeflash/optimization/optimizer.py b/codeflash/optimization/optimizer.py index f06bbce50..562d932e1 100644 --- a/codeflash/optimization/optimizer.py +++ b/codeflash/optimization/optimizer.py @@ -173,7 +173,6 @@ class Optimizer: elif self.args.all: logging.info("✨ All functions have been optimized! ✨") finally: - # TODO @afik.cohen: Also revert the file/function being optimized if the process did not succeed for test_file in self.instrumented_unittests_created: pathlib.Path(test_file).unlink(missing_ok=True) for test_file in self.test_files_created: @@ -384,110 +383,121 @@ class Optimizer: logging.info( f"Determining best optimized candidate (out of {len(candidates)}) for {function_to_optimize.qualified_name} ...", ) - for i, candidate in enumerate(candidates): - j = i + 1 - if candidate.source_code is None: - continue - # remove left overs from previous run - pathlib.Path(get_run_tmp_file(f"test_return_values_{j}.bin")).unlink( - missing_ok=True, - ) - pathlib.Path(get_run_tmp_file(f"test_return_values_{j}.sqlite")).unlink( - missing_ok=True, - ) - logging.info(f"Optimized candidate {j}/{len(candidates)}:") - logging.info(candidate.source_code) - try: - replace_function_definitions_in_module( - function_names=[function_to_optimize.qualified_name], - optimized_code=candidate.source_code, - file_path_of_module_with_function_to_optimize=function_to_optimize.file_path, - module_abspath=function_to_optimize.file_path, - preexisting_functions=code_context.preexisting_functions, - contextual_functions=code_context.contextual_dunder_methods, - project_root_path=self.args.project_root, + try: + for i, candidate in enumerate(candidates): + j = i + 1 + if candidate.source_code is None: + continue + # remove left overs from previous run + pathlib.Path(get_run_tmp_file(f"test_return_values_{j}.bin")).unlink( + missing_ok=True, ) - for ( - module_abspath, - qualified_names, - ) in helper_functions_by_module_abspath.items(): + pathlib.Path(get_run_tmp_file(f"test_return_values_{j}.sqlite")).unlink( + missing_ok=True, + ) + logging.info(f"Optimized candidate {j}/{len(candidates)}:") + logging.info(candidate.source_code) + try: replace_function_definitions_in_module( - function_names=list(qualified_names), + function_names=[function_to_optimize.qualified_name], optimized_code=candidate.source_code, file_path_of_module_with_function_to_optimize=function_to_optimize.file_path, - module_abspath=module_abspath, - preexisting_functions=[], + module_abspath=function_to_optimize.file_path, + preexisting_functions=code_context.preexisting_functions, contextual_functions=code_context.contextual_dunder_methods, project_root_path=self.args.project_root, ) - except ( - ValueError, - SyntaxError, - cst.ParserSyntaxError, - AttributeError, - ) as e: - logging.error(e) # noqa: TRY400 + for ( + module_abspath, + qualified_names, + ) in helper_functions_by_module_abspath.items(): + replace_function_definitions_in_module( + function_names=list(qualified_names), + optimized_code=candidate.source_code, + file_path_of_module_with_function_to_optimize=function_to_optimize.file_path, + module_abspath=module_abspath, + preexisting_functions=[], + contextual_functions=code_context.contextual_dunder_methods, + project_root_path=self.args.project_root, + ) + except ( + ValueError, + SyntaxError, + cst.ParserSyntaxError, + AttributeError, + ) as e: + logging.error(e) # noqa: TRY400 + self.write_code_and_helpers( + original_code, + original_helper_code, + function_to_optimize.file_path, + helper_functions_by_module_abspath, + ) + continue + + # Run generated tests if at least one of them passed + run_generated_tests = False + if original_code_baseline.generated_test_results: + for test_result in original_code_baseline.generated_test_results.test_results: + if test_result.did_pass: + run_generated_tests = True + break + + run_results = self.run_optimized_candidate( + optimization_index=j, + instrumented_unittests_created_for_function=instrumented_unittests_created_for_function, + overall_original_test_results=original_code_baseline.overall_test_results, + original_existing_test_results=original_code_baseline.existing_test_results, + original_generated_test_results=original_code_baseline.generated_test_results, + generated_tests_path=generated_tests_path, + best_runtime_until_now=best_runtime_until_now, + tests_in_file=only_run_this_test_function, + run_generated_tests=run_generated_tests, + ) + if not is_successful(run_results): + optimized_runtimes[candidate.optimization_id] = None + is_correct[candidate.optimization_id] = False + speedup_ratios[candidate.optimization_id] = None + else: + candidate_result: OptimizedCandidateResult = run_results.unwrap() + best_test_runtime = candidate_result.best_test_runtime + optimized_runtimes[candidate.optimization_id] = best_test_runtime + is_correct[candidate.optimization_id] = True + speedup_ratios[candidate.optimization_id] = ( + original_code_baseline.runtime - best_test_runtime + ) / best_test_runtime + logging.info( + f"Candidate runtime measured over {candidate_result.times_run} run{'s' if candidate_result.times_run > 1 else ''}: " + f"{humanize_runtime(best_test_runtime)}, speedup ratio = " + f"{((original_code_baseline.runtime - best_test_runtime) / best_test_runtime):.3f}", + ) + + if speedup_critic(candidate_result, original_code_baseline.runtime, best_runtime_until_now): + best_optimization = BestOptimization( + candidate=candidate, + helper_functions=code_context.helper_functions, + runtime=best_test_runtime, + winning_test_results=candidate_result.best_test_results, + ) + best_runtime_until_now = best_test_runtime + self.write_code_and_helpers( original_code, original_helper_code, function_to_optimize.file_path, helper_functions_by_module_abspath, ) - continue - - # Run generated tests if at least one of them passed - run_generated_tests = False - if original_code_baseline.generated_test_results: - for test_result in original_code_baseline.generated_test_results.test_results: - if test_result.did_pass: - run_generated_tests = True - break - - run_results = self.run_optimized_candidate( - optimization_index=j, - instrumented_unittests_created_for_function=instrumented_unittests_created_for_function, - overall_original_test_results=original_code_baseline.overall_test_results, - original_existing_test_results=original_code_baseline.existing_test_results, - original_generated_test_results=original_code_baseline.generated_test_results, - generated_tests_path=generated_tests_path, - best_runtime_until_now=best_runtime_until_now, - tests_in_file=only_run_this_test_function, - run_generated_tests=run_generated_tests, - ) - if not is_successful(run_results): - optimized_runtimes[candidate.optimization_id] = None - is_correct[candidate.optimization_id] = False - speedup_ratios[candidate.optimization_id] = None - else: - candidate_result: OptimizedCandidateResult = run_results.unwrap() - best_test_runtime = candidate_result.best_test_runtime - optimized_runtimes[candidate.optimization_id] = best_test_runtime - is_correct[candidate.optimization_id] = True - speedup_ratios[candidate.optimization_id] = ( - original_code_baseline.runtime - best_test_runtime - ) / best_test_runtime - logging.info( - f"Candidate runtime measured over {candidate_result.times_run} run{'s' if candidate_result.times_run > 1 else ''}: " - f"{humanize_runtime(best_test_runtime)}, speedup ratio = " - f"{((original_code_baseline.runtime - best_test_runtime) / best_test_runtime):.3f}", - ) - - if speedup_critic(candidate_result, original_code_baseline.runtime, best_runtime_until_now): - best_optimization = BestOptimization( - candidate=candidate, - helper_functions=code_context.helper_functions, - runtime=best_test_runtime, - winning_test_results=candidate_result.best_test_results, - ) - best_runtime_until_now = best_test_runtime - + logging.info("----------------") + except KeyboardInterrupt as e: self.write_code_and_helpers( original_code, original_helper_code, function_to_optimize.file_path, helper_functions_by_module_abspath, ) - logging.info("----------------") + logging.error(f"Optimization interrupted: {e}") + raise e + self.aiservice_client.log_results( function_trace_id=function_trace_id, speedup_ratio=speedup_ratios, diff --git a/tests/test_code_replacement.py b/tests/test_code_replacement.py index 20cf653aa..3ed333f2f 100644 --- a/tests/test_code_replacement.py +++ b/tests/test_code_replacement.py @@ -4,10 +4,8 @@ import os from argparse import Namespace from pathlib import Path -from returns.pipeline import is_successful - -from codeflash.code_utils.code_replacer import replace_functions_in_file -from codeflash.discovery.functions_to_optimize import FunctionToOptimize, FunctionParent +from codeflash.code_utils.code_replacer import replace_functions_and_add_imports +from codeflash.discovery.functions_to_optimize import FunctionParent, FunctionToOptimize from codeflash.optimization.optimizer import Optimizer os.environ["CODEFLASH_API_KEY"] = "cf-test-key" @@ -37,9 +35,7 @@ class NewClass: print("Hello world") """ - expected = """import libcst as cst -from typing import Optional -class NewClass: + expected = """class NewClass: def __init__(self, name): self.name = name def new_function(self, value): @@ -56,12 +52,15 @@ print("Hello world") function_name: str = "NewClass.new_function" preexisting_functions: list[str] = ["new_function"] contextual_functions: set[tuple[str, str]] = {("NewClass", "__init__")} - new_code: str = replace_functions_in_file( - original_code, - [function_name], - optim_code, - preexisting_functions, - contextual_functions, + new_code: str = replace_functions_and_add_imports( + source_code=original_code, + function_names=[function_name], + optimized_code=optim_code, + file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()), + module_abspath=str(Path(__file__).resolve()), + preexisting_functions=preexisting_functions, + contextual_functions=contextual_functions, + project_root_path=str(Path(__file__).resolve().parent.resolve()), ) assert new_code == expected @@ -95,9 +94,7 @@ class NewClass: print("Hello world") """ - expected = """import libcst as cst -from typing import Optional -from OtherModule import other_function + expected = """from OtherModule import other_function class NewClass: def __init__(self, name): @@ -116,12 +113,15 @@ print("Hello world") function_name: str = "NewClass.new_function" preexisting_functions: list[str] = ["new_function", "other_function"] contextual_functions: set[tuple[str, str]] = {("NewClass", "__init__")} - new_code: str = replace_functions_in_file( - original_code, - [function_name], - optim_code, - preexisting_functions, - contextual_functions, + new_code: str = replace_functions_and_add_imports( + source_code=original_code, + function_names=[function_name], + optimized_code=optim_code, + file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()), + module_abspath=str(Path(__file__).resolve()), + preexisting_functions=preexisting_functions, + contextual_functions=contextual_functions, + project_root_path=str(Path(__file__).resolve().parent.resolve()), ) assert new_code == expected @@ -139,7 +139,7 @@ def other_function(st): class NewClass: def __init__(self, name): self.name = name - def new_function(self, value): + def new_function(self, value: cst.Name): return other_function(self.name) def new_function2(value): return value @@ -158,10 +158,7 @@ def other_function(st): print("Salut monde") """ - expected = """import libcst as cst -from typing import Optional -import libcst as cst -from typing import Mandatory + expected = """from typing import Mandatory print("Au revoir") @@ -177,12 +174,15 @@ print("Salut monde") function_names: list[str] = ["module.other_function"] preexisting_functions: list[str] = [] contextual_functions: set[tuple[str, str]] = set() - new_code: str = replace_functions_in_file( - original_code, - function_names, - optim_code, - preexisting_functions, - contextual_functions, + new_code: str = replace_functions_and_add_imports( + source_code=original_code, + function_names=function_names, + optimized_code=optim_code, + file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()), + module_abspath=str(Path(__file__).resolve()), + preexisting_functions=preexisting_functions, + contextual_functions=contextual_functions, + project_root_path=str(Path(__file__).resolve().parent.resolve()), ) assert new_code == expected @@ -194,7 +194,7 @@ from typing import Optional def totally_new_function(value): return value -def yet_another_function(values): +def yet_another_function(values: Optional[str]): return len(values) + 2 def other_function(st): @@ -222,14 +222,11 @@ def other_function(st): print("Salut monde") """ - expected = """import libcst as cst -from typing import Optional -import libcst as cst -from typing import Mandatory + expected = """from typing import Optional, Mandatory print("Au revoir") -def yet_another_function(values): +def yet_another_function(values: Optional[str]): return len(values) + 2 def other_function(st): @@ -241,12 +238,15 @@ print("Salut monde") function_names: list[str] = ["module.yet_another_function", "module.other_function"] preexisting_functions: list[str] = [] contextual_functions: set[tuple[str, str]] = set() - new_code: str = replace_functions_in_file( - original_code, - function_names, - optim_code, - preexisting_functions, - contextual_functions, + new_code: str = replace_functions_and_add_imports( + source_code=original_code, + function_names=function_names, + optimized_code=optim_code, + file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()), + module_abspath=str(Path(__file__).resolve()), + preexisting_functions=preexisting_functions, + contextual_functions=contextual_functions, + project_root_path=str(Path(__file__).resolve().parent.resolve()), ) assert new_code == expected @@ -291,12 +291,15 @@ def supersort(doink): function_names: list[str] = ["sorter_deps"] preexisting_functions: list[str] = ["sorter_deps"] contextual_functions: set[tuple[str, str]] = set() - new_code: str = replace_functions_in_file( - original_code, - function_names, - optim_code, - preexisting_functions, - contextual_functions, + new_code: str = replace_functions_and_add_imports( + source_code=original_code, + function_names=function_names, + optimized_code=optim_code, + file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()), + module_abspath=str(Path(__file__).resolve()), + preexisting_functions=preexisting_functions, + contextual_functions=contextual_functions, + project_root_path=str(Path(__file__).resolve().parent.resolve()), ) assert new_code == expected @@ -338,10 +341,7 @@ def blab(st): print("Not cool") """ - expected_main = """import libcst as cst -from typing import Optional -import libcst as cst -from typing import Mandatory + expected_main = """from typing import Mandatory from helper import blob print("Au revoir") @@ -355,9 +355,7 @@ def other_function(st): print("Salut monde") """ - expected_helper = """import libcst as cst -from typing import Optional -import numpy as np + expected_helper = """import numpy as np print("Cool") @@ -369,21 +367,27 @@ def blab(st): print("Not cool") """ - new_main_code: str = replace_functions_in_file( - original_code_main, - ["other_function"], - optim_code, - ["other_function", "yet_another_function", "blob"], - set(), + new_main_code: str = replace_functions_and_add_imports( + source_code=original_code_main, + function_names=["other_function"], + optimized_code=optim_code, + file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()), + module_abspath=str(Path(__file__).resolve()), + preexisting_functions=["other_function", "yet_another_function", "blob"], + contextual_functions=set(), + project_root_path=str(Path(__file__).resolve().parent.resolve()), ) assert new_main_code == expected_main - new_helper_code: str = replace_functions_in_file( - original_code_helper, - ["blob"], - optim_code, - [], - set(), + new_helper_code: str = replace_functions_and_add_imports( + source_code=original_code_helper, + function_names=["blob"], + optimized_code=optim_code, + file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()), + module_abspath=str(Path(__file__).resolve()), + preexisting_functions=[], + contextual_functions=set(), + project_root_path=str(Path(__file__).resolve().parent.resolve()), ) assert new_helper_code == expected_helper @@ -578,12 +582,15 @@ class CacheConfig(BaseConfig): ("CacheConfig", "__init__"), ("CacheInitConfig", "__init__"), } - new_code: str = replace_functions_in_file( - original_code, - function_names, - optim_code, - preexisting_functions, - contextual_functions, + new_code: str = replace_functions_and_add_imports( + source_code=original_code, + function_names=function_names, + optimized_code=optim_code, + file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()), + module_abspath=str(Path(__file__).resolve()), + preexisting_functions=preexisting_functions, + contextual_functions=contextual_functions, + project_root_path=str(Path(__file__).resolve().parent.resolve()), ) assert new_code == expected @@ -649,12 +656,15 @@ def test_test_libcst_code_replacement8() -> None: "_hamming_distance", ] contextual_functions: set[tuple[str, str]] = set() - new_code: str = replace_functions_in_file( - original_code, - function_names, - optim_code, - preexisting_functions, - contextual_functions, + new_code: str = replace_functions_and_add_imports( + source_code=original_code, + function_names=function_names, + optimized_code=optim_code, + file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()), + module_abspath=str(Path(__file__).resolve()), + preexisting_functions=preexisting_functions, + contextual_functions=contextual_functions, + project_root_path=str(Path(__file__).resolve().parent.resolve()), ) assert new_code == expected @@ -663,7 +673,7 @@ def test_test_libcst_code_replacement9() -> None: optim_code = """import libcst as cst from typing import Optional -def totally_new_function(value): +def totally_new_function(value: Optional[str]): return value class NewClass: @@ -672,7 +682,7 @@ class NewClass: def __call__(self, value): return self.name def new_function2(value): - return value + return cst.ensure_type(value, str) """ original_code = """class NewClass: @@ -685,15 +695,16 @@ print("Hello world") """ expected = """import libcst as cst from typing import Optional + class NewClass: def __init__(self, name): self.name = str(name) def __call__(self, value): return "I am still old" def new_function2(value): - return value + return cst.ensure_type(value, str) -def totally_new_function(value): +def totally_new_function(value: Optional[str]): return value print("Hello world") @@ -705,12 +716,15 @@ print("Hello world") ("NewClass", "__init__"), ("NewClass", "__call__"), } - new_code: str = replace_functions_in_file( - original_code, - [function_name], - optim_code, - preexisting_functions, - contextual_functions, + new_code: str = replace_functions_and_add_imports( + source_code=original_code, + function_names=[function_name], + optimized_code=optim_code, + file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()), + module_abspath=str(Path(__file__).resolve()), + preexisting_functions=preexisting_functions, + contextual_functions=contextual_functions, + project_root_path=str(Path(__file__).resolve().parent.resolve()), ) assert new_code == expected @@ -764,11 +778,16 @@ class MainClass: experiment_id=None, ), ) - func_top_optimize = FunctionToOptimize(function_name="main_method", file_path=str(file_path), - parents=[FunctionParent("MainClass", "ClassDef")]) + func_top_optimize = FunctionToOptimize( + function_name="main_method", + file_path=str(file_path), + parents=[FunctionParent("MainClass", "ClassDef")], + ) with open(file_path) as f: original_code = f.read() - code_context = opt.get_code_optimization_context(function_to_optimize=func_top_optimize, - project_root=str(file_path.parent), - original_source_code=original_code).unwrap() + code_context = opt.get_code_optimization_context( + function_to_optimize=func_top_optimize, + project_root=str(file_path.parent), + original_source_code=original_code, + ).unwrap() assert code_context.code_to_optimize_with_helpers == get_code_output