Remove the older imports addition logic, which should prevent duplicate imports.

Also reset the original file if there is a keyboard interrupt during the optimization loop.
This commit is contained in:
Saurabh Misra 2024-06-16 19:17:45 -07:00
parent 96a2a02e26
commit 74db766a43
3 changed files with 243 additions and 204 deletions

View file

@ -26,7 +26,6 @@ class OptimFunctionCollector(cst.CSTVisitor):
self.optim_body: FunctionDef | None = None
self.optim_new_class_functions: list[cst.FunctionDef] = []
self.optim_new_functions: list[cst.FunctionDef] = []
self.optim_imports: list[cst.SimpleStatementLine] = []
self.preexisting_functions = preexisting_functions
self.contextual_functions = contextual_functions.union(
{(self.class_name, self.function_name)},
@ -64,10 +63,6 @@ class OptimFunctionCollector(cst.CSTVisitor):
):
self.optim_new_class_functions.append(child_node)
def leave_SimpleStatementLine(self, original_node: cst.SimpleStatementLine) -> None:
if isinstance(original_node.body[0], (cst.Import, cst.ImportFrom)):
self.optim_imports.append(original_node)
class OptimFunctionReplacer(cst.CSTTransformer):
def __init__(
@ -75,7 +70,6 @@ class OptimFunctionReplacer(cst.CSTTransformer):
function_name: str,
optim_body: cst.FunctionDef,
optim_new_class_functions: list[cst.FunctionDef],
optim_imports: list[cst.SimpleStatementLine],
optim_new_functions: list[cst.FunctionDef],
class_name: str | None = None,
) -> None:
@ -83,7 +77,6 @@ class OptimFunctionReplacer(cst.CSTTransformer):
self.function_name = function_name
self.optim_body = optim_body
self.optim_new_class_functions = optim_new_class_functions
self.optim_new_imports = optim_imports
self.optim_new_functions = optim_new_functions
self.class_name = class_name
self.depth: int = 0
@ -126,10 +119,7 @@ class OptimFunctionReplacer(cst.CSTTransformer):
return updated_node
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
if len(self.optim_new_imports) == 0:
node = updated_node
else:
node = updated_node.with_changes(body=(*self.optim_new_imports, *updated_node.body))
node = updated_node
max_function_index = None
class_index = None
for index, _node in enumerate(node.body):
@ -190,13 +180,11 @@ def replace_functions_in_file(
continue
if visitor.optim_body is None:
raise ValueError(f"Did not find the function {function_name} in the optimized code")
optim_imports: list[cst.SimpleStatementLine] = [] if i > 0 else visitor.optim_imports
transformer = OptimFunctionReplacer(
visitor.function_name,
visitor.optim_body,
visitor.optim_new_class_functions,
optim_imports,
visitor.optim_new_functions,
class_name=class_name,
)
@ -207,6 +195,31 @@ def replace_functions_in_file(
return source_code
def replace_functions_and_add_imports(
source_code: str,
function_names: list[str],
optimized_code: str,
file_path_of_module_with_function_to_optimize: str,
module_abspath: str,
preexisting_functions: list[str],
contextual_functions: set[tuple[str, str]],
project_root_path: str,
) -> str:
return add_needed_imports_from_module(
optimized_code,
replace_functions_in_file(
source_code,
function_names,
optimized_code,
preexisting_functions,
contextual_functions,
),
file_path_of_module_with_function_to_optimize,
module_abspath,
project_root_path,
)
def replace_function_definitions_in_module(
function_names: list[str],
optimized_code: str,
@ -219,17 +232,14 @@ def replace_function_definitions_in_module(
file: IO[str]
with open(module_abspath, encoding="utf8") as file:
source_code: str = file.read()
new_code: str = add_needed_imports_from_module(
new_code: str = replace_functions_and_add_imports(
source_code,
function_names,
optimized_code,
replace_functions_in_file(
source_code,
function_names,
optimized_code,
preexisting_functions,
contextual_functions,
),
file_path_of_module_with_function_to_optimize,
module_abspath,
preexisting_functions,
contextual_functions,
project_root_path,
)
with open(module_abspath, "w", encoding="utf8") as file:

View file

@ -173,7 +173,6 @@ class Optimizer:
elif self.args.all:
logging.info("✨ All functions have been optimized! ✨")
finally:
# TODO @afik.cohen: Also revert the file/function being optimized if the process did not succeed
for test_file in self.instrumented_unittests_created:
pathlib.Path(test_file).unlink(missing_ok=True)
for test_file in self.test_files_created:
@ -384,110 +383,121 @@ class Optimizer:
logging.info(
f"Determining best optimized candidate (out of {len(candidates)}) for {function_to_optimize.qualified_name} ...",
)
for i, candidate in enumerate(candidates):
j = i + 1
if candidate.source_code is None:
continue
# remove left overs from previous run
pathlib.Path(get_run_tmp_file(f"test_return_values_{j}.bin")).unlink(
missing_ok=True,
)
pathlib.Path(get_run_tmp_file(f"test_return_values_{j}.sqlite")).unlink(
missing_ok=True,
)
logging.info(f"Optimized candidate {j}/{len(candidates)}:")
logging.info(candidate.source_code)
try:
replace_function_definitions_in_module(
function_names=[function_to_optimize.qualified_name],
optimized_code=candidate.source_code,
file_path_of_module_with_function_to_optimize=function_to_optimize.file_path,
module_abspath=function_to_optimize.file_path,
preexisting_functions=code_context.preexisting_functions,
contextual_functions=code_context.contextual_dunder_methods,
project_root_path=self.args.project_root,
try:
for i, candidate in enumerate(candidates):
j = i + 1
if candidate.source_code is None:
continue
# remove left overs from previous run
pathlib.Path(get_run_tmp_file(f"test_return_values_{j}.bin")).unlink(
missing_ok=True,
)
for (
module_abspath,
qualified_names,
) in helper_functions_by_module_abspath.items():
pathlib.Path(get_run_tmp_file(f"test_return_values_{j}.sqlite")).unlink(
missing_ok=True,
)
logging.info(f"Optimized candidate {j}/{len(candidates)}:")
logging.info(candidate.source_code)
try:
replace_function_definitions_in_module(
function_names=list(qualified_names),
function_names=[function_to_optimize.qualified_name],
optimized_code=candidate.source_code,
file_path_of_module_with_function_to_optimize=function_to_optimize.file_path,
module_abspath=module_abspath,
preexisting_functions=[],
module_abspath=function_to_optimize.file_path,
preexisting_functions=code_context.preexisting_functions,
contextual_functions=code_context.contextual_dunder_methods,
project_root_path=self.args.project_root,
)
except (
ValueError,
SyntaxError,
cst.ParserSyntaxError,
AttributeError,
) as e:
logging.error(e) # noqa: TRY400
for (
module_abspath,
qualified_names,
) in helper_functions_by_module_abspath.items():
replace_function_definitions_in_module(
function_names=list(qualified_names),
optimized_code=candidate.source_code,
file_path_of_module_with_function_to_optimize=function_to_optimize.file_path,
module_abspath=module_abspath,
preexisting_functions=[],
contextual_functions=code_context.contextual_dunder_methods,
project_root_path=self.args.project_root,
)
except (
ValueError,
SyntaxError,
cst.ParserSyntaxError,
AttributeError,
) as e:
logging.error(e) # noqa: TRY400
self.write_code_and_helpers(
original_code,
original_helper_code,
function_to_optimize.file_path,
helper_functions_by_module_abspath,
)
continue
# Run generated tests if at least one of them passed
run_generated_tests = False
if original_code_baseline.generated_test_results:
for test_result in original_code_baseline.generated_test_results.test_results:
if test_result.did_pass:
run_generated_tests = True
break
run_results = self.run_optimized_candidate(
optimization_index=j,
instrumented_unittests_created_for_function=instrumented_unittests_created_for_function,
overall_original_test_results=original_code_baseline.overall_test_results,
original_existing_test_results=original_code_baseline.existing_test_results,
original_generated_test_results=original_code_baseline.generated_test_results,
generated_tests_path=generated_tests_path,
best_runtime_until_now=best_runtime_until_now,
tests_in_file=only_run_this_test_function,
run_generated_tests=run_generated_tests,
)
if not is_successful(run_results):
optimized_runtimes[candidate.optimization_id] = None
is_correct[candidate.optimization_id] = False
speedup_ratios[candidate.optimization_id] = None
else:
candidate_result: OptimizedCandidateResult = run_results.unwrap()
best_test_runtime = candidate_result.best_test_runtime
optimized_runtimes[candidate.optimization_id] = best_test_runtime
is_correct[candidate.optimization_id] = True
speedup_ratios[candidate.optimization_id] = (
original_code_baseline.runtime - best_test_runtime
) / best_test_runtime
logging.info(
f"Candidate runtime measured over {candidate_result.times_run} run{'s' if candidate_result.times_run > 1 else ''}: "
f"{humanize_runtime(best_test_runtime)}, speedup ratio = "
f"{((original_code_baseline.runtime - best_test_runtime) / best_test_runtime):.3f}",
)
if speedup_critic(candidate_result, original_code_baseline.runtime, best_runtime_until_now):
best_optimization = BestOptimization(
candidate=candidate,
helper_functions=code_context.helper_functions,
runtime=best_test_runtime,
winning_test_results=candidate_result.best_test_results,
)
best_runtime_until_now = best_test_runtime
self.write_code_and_helpers(
original_code,
original_helper_code,
function_to_optimize.file_path,
helper_functions_by_module_abspath,
)
continue
# Run generated tests if at least one of them passed
run_generated_tests = False
if original_code_baseline.generated_test_results:
for test_result in original_code_baseline.generated_test_results.test_results:
if test_result.did_pass:
run_generated_tests = True
break
run_results = self.run_optimized_candidate(
optimization_index=j,
instrumented_unittests_created_for_function=instrumented_unittests_created_for_function,
overall_original_test_results=original_code_baseline.overall_test_results,
original_existing_test_results=original_code_baseline.existing_test_results,
original_generated_test_results=original_code_baseline.generated_test_results,
generated_tests_path=generated_tests_path,
best_runtime_until_now=best_runtime_until_now,
tests_in_file=only_run_this_test_function,
run_generated_tests=run_generated_tests,
)
if not is_successful(run_results):
optimized_runtimes[candidate.optimization_id] = None
is_correct[candidate.optimization_id] = False
speedup_ratios[candidate.optimization_id] = None
else:
candidate_result: OptimizedCandidateResult = run_results.unwrap()
best_test_runtime = candidate_result.best_test_runtime
optimized_runtimes[candidate.optimization_id] = best_test_runtime
is_correct[candidate.optimization_id] = True
speedup_ratios[candidate.optimization_id] = (
original_code_baseline.runtime - best_test_runtime
) / best_test_runtime
logging.info(
f"Candidate runtime measured over {candidate_result.times_run} run{'s' if candidate_result.times_run > 1 else ''}: "
f"{humanize_runtime(best_test_runtime)}, speedup ratio = "
f"{((original_code_baseline.runtime - best_test_runtime) / best_test_runtime):.3f}",
)
if speedup_critic(candidate_result, original_code_baseline.runtime, best_runtime_until_now):
best_optimization = BestOptimization(
candidate=candidate,
helper_functions=code_context.helper_functions,
runtime=best_test_runtime,
winning_test_results=candidate_result.best_test_results,
)
best_runtime_until_now = best_test_runtime
logging.info("----------------")
except KeyboardInterrupt as e:
self.write_code_and_helpers(
original_code,
original_helper_code,
function_to_optimize.file_path,
helper_functions_by_module_abspath,
)
logging.info("----------------")
logging.error(f"Optimization interrupted: {e}")
raise e
self.aiservice_client.log_results(
function_trace_id=function_trace_id,
speedup_ratio=speedup_ratios,

View file

@ -4,10 +4,8 @@ import os
from argparse import Namespace
from pathlib import Path
from returns.pipeline import is_successful
from codeflash.code_utils.code_replacer import replace_functions_in_file
from codeflash.discovery.functions_to_optimize import FunctionToOptimize, FunctionParent
from codeflash.code_utils.code_replacer import replace_functions_and_add_imports
from codeflash.discovery.functions_to_optimize import FunctionParent, FunctionToOptimize
from codeflash.optimization.optimizer import Optimizer
os.environ["CODEFLASH_API_KEY"] = "cf-test-key"
@ -37,9 +35,7 @@ class NewClass:
print("Hello world")
"""
expected = """import libcst as cst
from typing import Optional
class NewClass:
expected = """class NewClass:
def __init__(self, name):
self.name = name
def new_function(self, value):
@ -56,12 +52,15 @@ print("Hello world")
function_name: str = "NewClass.new_function"
preexisting_functions: list[str] = ["new_function"]
contextual_functions: set[tuple[str, str]] = {("NewClass", "__init__")}
new_code: str = replace_functions_in_file(
original_code,
[function_name],
optim_code,
preexisting_functions,
contextual_functions,
new_code: str = replace_functions_and_add_imports(
source_code=original_code,
function_names=[function_name],
optimized_code=optim_code,
file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()),
module_abspath=str(Path(__file__).resolve()),
preexisting_functions=preexisting_functions,
contextual_functions=contextual_functions,
project_root_path=str(Path(__file__).resolve().parent.resolve()),
)
assert new_code == expected
@ -95,9 +94,7 @@ class NewClass:
print("Hello world")
"""
expected = """import libcst as cst
from typing import Optional
from OtherModule import other_function
expected = """from OtherModule import other_function
class NewClass:
def __init__(self, name):
@ -116,12 +113,15 @@ print("Hello world")
function_name: str = "NewClass.new_function"
preexisting_functions: list[str] = ["new_function", "other_function"]
contextual_functions: set[tuple[str, str]] = {("NewClass", "__init__")}
new_code: str = replace_functions_in_file(
original_code,
[function_name],
optim_code,
preexisting_functions,
contextual_functions,
new_code: str = replace_functions_and_add_imports(
source_code=original_code,
function_names=[function_name],
optimized_code=optim_code,
file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()),
module_abspath=str(Path(__file__).resolve()),
preexisting_functions=preexisting_functions,
contextual_functions=contextual_functions,
project_root_path=str(Path(__file__).resolve().parent.resolve()),
)
assert new_code == expected
@ -139,7 +139,7 @@ def other_function(st):
class NewClass:
def __init__(self, name):
self.name = name
def new_function(self, value):
def new_function(self, value: cst.Name):
return other_function(self.name)
def new_function2(value):
return value
@ -158,10 +158,7 @@ def other_function(st):
print("Salut monde")
"""
expected = """import libcst as cst
from typing import Optional
import libcst as cst
from typing import Mandatory
expected = """from typing import Mandatory
print("Au revoir")
@ -177,12 +174,15 @@ print("Salut monde")
function_names: list[str] = ["module.other_function"]
preexisting_functions: list[str] = []
contextual_functions: set[tuple[str, str]] = set()
new_code: str = replace_functions_in_file(
original_code,
function_names,
optim_code,
preexisting_functions,
contextual_functions,
new_code: str = replace_functions_and_add_imports(
source_code=original_code,
function_names=function_names,
optimized_code=optim_code,
file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()),
module_abspath=str(Path(__file__).resolve()),
preexisting_functions=preexisting_functions,
contextual_functions=contextual_functions,
project_root_path=str(Path(__file__).resolve().parent.resolve()),
)
assert new_code == expected
@ -194,7 +194,7 @@ from typing import Optional
def totally_new_function(value):
return value
def yet_another_function(values):
def yet_another_function(values: Optional[str]):
return len(values) + 2
def other_function(st):
@ -222,14 +222,11 @@ def other_function(st):
print("Salut monde")
"""
expected = """import libcst as cst
from typing import Optional
import libcst as cst
from typing import Mandatory
expected = """from typing import Optional, Mandatory
print("Au revoir")
def yet_another_function(values):
def yet_another_function(values: Optional[str]):
return len(values) + 2
def other_function(st):
@ -241,12 +238,15 @@ print("Salut monde")
function_names: list[str] = ["module.yet_another_function", "module.other_function"]
preexisting_functions: list[str] = []
contextual_functions: set[tuple[str, str]] = set()
new_code: str = replace_functions_in_file(
original_code,
function_names,
optim_code,
preexisting_functions,
contextual_functions,
new_code: str = replace_functions_and_add_imports(
source_code=original_code,
function_names=function_names,
optimized_code=optim_code,
file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()),
module_abspath=str(Path(__file__).resolve()),
preexisting_functions=preexisting_functions,
contextual_functions=contextual_functions,
project_root_path=str(Path(__file__).resolve().parent.resolve()),
)
assert new_code == expected
@ -291,12 +291,15 @@ def supersort(doink):
function_names: list[str] = ["sorter_deps"]
preexisting_functions: list[str] = ["sorter_deps"]
contextual_functions: set[tuple[str, str]] = set()
new_code: str = replace_functions_in_file(
original_code,
function_names,
optim_code,
preexisting_functions,
contextual_functions,
new_code: str = replace_functions_and_add_imports(
source_code=original_code,
function_names=function_names,
optimized_code=optim_code,
file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()),
module_abspath=str(Path(__file__).resolve()),
preexisting_functions=preexisting_functions,
contextual_functions=contextual_functions,
project_root_path=str(Path(__file__).resolve().parent.resolve()),
)
assert new_code == expected
@ -338,10 +341,7 @@ def blab(st):
print("Not cool")
"""
expected_main = """import libcst as cst
from typing import Optional
import libcst as cst
from typing import Mandatory
expected_main = """from typing import Mandatory
from helper import blob
print("Au revoir")
@ -355,9 +355,7 @@ def other_function(st):
print("Salut monde")
"""
expected_helper = """import libcst as cst
from typing import Optional
import numpy as np
expected_helper = """import numpy as np
print("Cool")
@ -369,21 +367,27 @@ def blab(st):
print("Not cool")
"""
new_main_code: str = replace_functions_in_file(
original_code_main,
["other_function"],
optim_code,
["other_function", "yet_another_function", "blob"],
set(),
new_main_code: str = replace_functions_and_add_imports(
source_code=original_code_main,
function_names=["other_function"],
optimized_code=optim_code,
file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()),
module_abspath=str(Path(__file__).resolve()),
preexisting_functions=["other_function", "yet_another_function", "blob"],
contextual_functions=set(),
project_root_path=str(Path(__file__).resolve().parent.resolve()),
)
assert new_main_code == expected_main
new_helper_code: str = replace_functions_in_file(
original_code_helper,
["blob"],
optim_code,
[],
set(),
new_helper_code: str = replace_functions_and_add_imports(
source_code=original_code_helper,
function_names=["blob"],
optimized_code=optim_code,
file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()),
module_abspath=str(Path(__file__).resolve()),
preexisting_functions=[],
contextual_functions=set(),
project_root_path=str(Path(__file__).resolve().parent.resolve()),
)
assert new_helper_code == expected_helper
@ -578,12 +582,15 @@ class CacheConfig(BaseConfig):
("CacheConfig", "__init__"),
("CacheInitConfig", "__init__"),
}
new_code: str = replace_functions_in_file(
original_code,
function_names,
optim_code,
preexisting_functions,
contextual_functions,
new_code: str = replace_functions_and_add_imports(
source_code=original_code,
function_names=function_names,
optimized_code=optim_code,
file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()),
module_abspath=str(Path(__file__).resolve()),
preexisting_functions=preexisting_functions,
contextual_functions=contextual_functions,
project_root_path=str(Path(__file__).resolve().parent.resolve()),
)
assert new_code == expected
@ -649,12 +656,15 @@ def test_test_libcst_code_replacement8() -> None:
"_hamming_distance",
]
contextual_functions: set[tuple[str, str]] = set()
new_code: str = replace_functions_in_file(
original_code,
function_names,
optim_code,
preexisting_functions,
contextual_functions,
new_code: str = replace_functions_and_add_imports(
source_code=original_code,
function_names=function_names,
optimized_code=optim_code,
file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()),
module_abspath=str(Path(__file__).resolve()),
preexisting_functions=preexisting_functions,
contextual_functions=contextual_functions,
project_root_path=str(Path(__file__).resolve().parent.resolve()),
)
assert new_code == expected
@ -663,7 +673,7 @@ def test_test_libcst_code_replacement9() -> None:
optim_code = """import libcst as cst
from typing import Optional
def totally_new_function(value):
def totally_new_function(value: Optional[str]):
return value
class NewClass:
@ -672,7 +682,7 @@ class NewClass:
def __call__(self, value):
return self.name
def new_function2(value):
return value
return cst.ensure_type(value, str)
"""
original_code = """class NewClass:
@ -685,15 +695,16 @@ print("Hello world")
"""
expected = """import libcst as cst
from typing import Optional
class NewClass:
def __init__(self, name):
self.name = str(name)
def __call__(self, value):
return "I am still old"
def new_function2(value):
return value
return cst.ensure_type(value, str)
def totally_new_function(value):
def totally_new_function(value: Optional[str]):
return value
print("Hello world")
@ -705,12 +716,15 @@ print("Hello world")
("NewClass", "__init__"),
("NewClass", "__call__"),
}
new_code: str = replace_functions_in_file(
original_code,
[function_name],
optim_code,
preexisting_functions,
contextual_functions,
new_code: str = replace_functions_and_add_imports(
source_code=original_code,
function_names=[function_name],
optimized_code=optim_code,
file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()),
module_abspath=str(Path(__file__).resolve()),
preexisting_functions=preexisting_functions,
contextual_functions=contextual_functions,
project_root_path=str(Path(__file__).resolve().parent.resolve()),
)
assert new_code == expected
@ -764,11 +778,16 @@ class MainClass:
experiment_id=None,
),
)
func_top_optimize = FunctionToOptimize(function_name="main_method", file_path=str(file_path),
parents=[FunctionParent("MainClass", "ClassDef")])
func_top_optimize = FunctionToOptimize(
function_name="main_method",
file_path=str(file_path),
parents=[FunctionParent("MainClass", "ClassDef")],
)
with open(file_path) as f:
original_code = f.read()
code_context = opt.get_code_optimization_context(function_to_optimize=func_top_optimize,
project_root=str(file_path.parent),
original_source_code=original_code).unwrap()
code_context = opt.get_code_optimization_context(
function_to_optimize=func_top_optimize,
project_root=str(file_path.parent),
original_source_code=original_code,
).unwrap()
assert code_context.code_to_optimize_with_helpers == get_code_output