Remove the older imports addition logic, which should prevent duplicate imports.
Also reset the original file if there is a keyboard interrupt during the optimization loop.
This commit is contained in:
parent
96a2a02e26
commit
74db766a43
3 changed files with 243 additions and 204 deletions
|
|
@ -26,7 +26,6 @@ class OptimFunctionCollector(cst.CSTVisitor):
|
||||||
self.optim_body: FunctionDef | None = None
|
self.optim_body: FunctionDef | None = None
|
||||||
self.optim_new_class_functions: list[cst.FunctionDef] = []
|
self.optim_new_class_functions: list[cst.FunctionDef] = []
|
||||||
self.optim_new_functions: list[cst.FunctionDef] = []
|
self.optim_new_functions: list[cst.FunctionDef] = []
|
||||||
self.optim_imports: list[cst.SimpleStatementLine] = []
|
|
||||||
self.preexisting_functions = preexisting_functions
|
self.preexisting_functions = preexisting_functions
|
||||||
self.contextual_functions = contextual_functions.union(
|
self.contextual_functions = contextual_functions.union(
|
||||||
{(self.class_name, self.function_name)},
|
{(self.class_name, self.function_name)},
|
||||||
|
|
@ -64,10 +63,6 @@ class OptimFunctionCollector(cst.CSTVisitor):
|
||||||
):
|
):
|
||||||
self.optim_new_class_functions.append(child_node)
|
self.optim_new_class_functions.append(child_node)
|
||||||
|
|
||||||
def leave_SimpleStatementLine(self, original_node: cst.SimpleStatementLine) -> None:
|
|
||||||
if isinstance(original_node.body[0], (cst.Import, cst.ImportFrom)):
|
|
||||||
self.optim_imports.append(original_node)
|
|
||||||
|
|
||||||
|
|
||||||
class OptimFunctionReplacer(cst.CSTTransformer):
|
class OptimFunctionReplacer(cst.CSTTransformer):
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|
@ -75,7 +70,6 @@ class OptimFunctionReplacer(cst.CSTTransformer):
|
||||||
function_name: str,
|
function_name: str,
|
||||||
optim_body: cst.FunctionDef,
|
optim_body: cst.FunctionDef,
|
||||||
optim_new_class_functions: list[cst.FunctionDef],
|
optim_new_class_functions: list[cst.FunctionDef],
|
||||||
optim_imports: list[cst.SimpleStatementLine],
|
|
||||||
optim_new_functions: list[cst.FunctionDef],
|
optim_new_functions: list[cst.FunctionDef],
|
||||||
class_name: str | None = None,
|
class_name: str | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|
@ -83,7 +77,6 @@ class OptimFunctionReplacer(cst.CSTTransformer):
|
||||||
self.function_name = function_name
|
self.function_name = function_name
|
||||||
self.optim_body = optim_body
|
self.optim_body = optim_body
|
||||||
self.optim_new_class_functions = optim_new_class_functions
|
self.optim_new_class_functions = optim_new_class_functions
|
||||||
self.optim_new_imports = optim_imports
|
|
||||||
self.optim_new_functions = optim_new_functions
|
self.optim_new_functions = optim_new_functions
|
||||||
self.class_name = class_name
|
self.class_name = class_name
|
||||||
self.depth: int = 0
|
self.depth: int = 0
|
||||||
|
|
@ -126,10 +119,7 @@ class OptimFunctionReplacer(cst.CSTTransformer):
|
||||||
return updated_node
|
return updated_node
|
||||||
|
|
||||||
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
|
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
|
||||||
if len(self.optim_new_imports) == 0:
|
node = updated_node
|
||||||
node = updated_node
|
|
||||||
else:
|
|
||||||
node = updated_node.with_changes(body=(*self.optim_new_imports, *updated_node.body))
|
|
||||||
max_function_index = None
|
max_function_index = None
|
||||||
class_index = None
|
class_index = None
|
||||||
for index, _node in enumerate(node.body):
|
for index, _node in enumerate(node.body):
|
||||||
|
|
@ -190,13 +180,11 @@ def replace_functions_in_file(
|
||||||
continue
|
continue
|
||||||
if visitor.optim_body is None:
|
if visitor.optim_body is None:
|
||||||
raise ValueError(f"Did not find the function {function_name} in the optimized code")
|
raise ValueError(f"Did not find the function {function_name} in the optimized code")
|
||||||
optim_imports: list[cst.SimpleStatementLine] = [] if i > 0 else visitor.optim_imports
|
|
||||||
|
|
||||||
transformer = OptimFunctionReplacer(
|
transformer = OptimFunctionReplacer(
|
||||||
visitor.function_name,
|
visitor.function_name,
|
||||||
visitor.optim_body,
|
visitor.optim_body,
|
||||||
visitor.optim_new_class_functions,
|
visitor.optim_new_class_functions,
|
||||||
optim_imports,
|
|
||||||
visitor.optim_new_functions,
|
visitor.optim_new_functions,
|
||||||
class_name=class_name,
|
class_name=class_name,
|
||||||
)
|
)
|
||||||
|
|
@ -207,6 +195,31 @@ def replace_functions_in_file(
|
||||||
return source_code
|
return source_code
|
||||||
|
|
||||||
|
|
||||||
|
def replace_functions_and_add_imports(
|
||||||
|
source_code: str,
|
||||||
|
function_names: list[str],
|
||||||
|
optimized_code: str,
|
||||||
|
file_path_of_module_with_function_to_optimize: str,
|
||||||
|
module_abspath: str,
|
||||||
|
preexisting_functions: list[str],
|
||||||
|
contextual_functions: set[tuple[str, str]],
|
||||||
|
project_root_path: str,
|
||||||
|
) -> str:
|
||||||
|
return add_needed_imports_from_module(
|
||||||
|
optimized_code,
|
||||||
|
replace_functions_in_file(
|
||||||
|
source_code,
|
||||||
|
function_names,
|
||||||
|
optimized_code,
|
||||||
|
preexisting_functions,
|
||||||
|
contextual_functions,
|
||||||
|
),
|
||||||
|
file_path_of_module_with_function_to_optimize,
|
||||||
|
module_abspath,
|
||||||
|
project_root_path,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def replace_function_definitions_in_module(
|
def replace_function_definitions_in_module(
|
||||||
function_names: list[str],
|
function_names: list[str],
|
||||||
optimized_code: str,
|
optimized_code: str,
|
||||||
|
|
@ -219,17 +232,14 @@ def replace_function_definitions_in_module(
|
||||||
file: IO[str]
|
file: IO[str]
|
||||||
with open(module_abspath, encoding="utf8") as file:
|
with open(module_abspath, encoding="utf8") as file:
|
||||||
source_code: str = file.read()
|
source_code: str = file.read()
|
||||||
new_code: str = add_needed_imports_from_module(
|
new_code: str = replace_functions_and_add_imports(
|
||||||
|
source_code,
|
||||||
|
function_names,
|
||||||
optimized_code,
|
optimized_code,
|
||||||
replace_functions_in_file(
|
|
||||||
source_code,
|
|
||||||
function_names,
|
|
||||||
optimized_code,
|
|
||||||
preexisting_functions,
|
|
||||||
contextual_functions,
|
|
||||||
),
|
|
||||||
file_path_of_module_with_function_to_optimize,
|
file_path_of_module_with_function_to_optimize,
|
||||||
module_abspath,
|
module_abspath,
|
||||||
|
preexisting_functions,
|
||||||
|
contextual_functions,
|
||||||
project_root_path,
|
project_root_path,
|
||||||
)
|
)
|
||||||
with open(module_abspath, "w", encoding="utf8") as file:
|
with open(module_abspath, "w", encoding="utf8") as file:
|
||||||
|
|
|
||||||
|
|
@ -173,7 +173,6 @@ class Optimizer:
|
||||||
elif self.args.all:
|
elif self.args.all:
|
||||||
logging.info("✨ All functions have been optimized! ✨")
|
logging.info("✨ All functions have been optimized! ✨")
|
||||||
finally:
|
finally:
|
||||||
# TODO @afik.cohen: Also revert the file/function being optimized if the process did not succeed
|
|
||||||
for test_file in self.instrumented_unittests_created:
|
for test_file in self.instrumented_unittests_created:
|
||||||
pathlib.Path(test_file).unlink(missing_ok=True)
|
pathlib.Path(test_file).unlink(missing_ok=True)
|
||||||
for test_file in self.test_files_created:
|
for test_file in self.test_files_created:
|
||||||
|
|
@ -384,110 +383,121 @@ class Optimizer:
|
||||||
logging.info(
|
logging.info(
|
||||||
f"Determining best optimized candidate (out of {len(candidates)}) for {function_to_optimize.qualified_name} ...",
|
f"Determining best optimized candidate (out of {len(candidates)}) for {function_to_optimize.qualified_name} ...",
|
||||||
)
|
)
|
||||||
for i, candidate in enumerate(candidates):
|
try:
|
||||||
j = i + 1
|
for i, candidate in enumerate(candidates):
|
||||||
if candidate.source_code is None:
|
j = i + 1
|
||||||
continue
|
if candidate.source_code is None:
|
||||||
# remove left overs from previous run
|
continue
|
||||||
pathlib.Path(get_run_tmp_file(f"test_return_values_{j}.bin")).unlink(
|
# remove left overs from previous run
|
||||||
missing_ok=True,
|
pathlib.Path(get_run_tmp_file(f"test_return_values_{j}.bin")).unlink(
|
||||||
)
|
missing_ok=True,
|
||||||
pathlib.Path(get_run_tmp_file(f"test_return_values_{j}.sqlite")).unlink(
|
|
||||||
missing_ok=True,
|
|
||||||
)
|
|
||||||
logging.info(f"Optimized candidate {j}/{len(candidates)}:")
|
|
||||||
logging.info(candidate.source_code)
|
|
||||||
try:
|
|
||||||
replace_function_definitions_in_module(
|
|
||||||
function_names=[function_to_optimize.qualified_name],
|
|
||||||
optimized_code=candidate.source_code,
|
|
||||||
file_path_of_module_with_function_to_optimize=function_to_optimize.file_path,
|
|
||||||
module_abspath=function_to_optimize.file_path,
|
|
||||||
preexisting_functions=code_context.preexisting_functions,
|
|
||||||
contextual_functions=code_context.contextual_dunder_methods,
|
|
||||||
project_root_path=self.args.project_root,
|
|
||||||
)
|
)
|
||||||
for (
|
pathlib.Path(get_run_tmp_file(f"test_return_values_{j}.sqlite")).unlink(
|
||||||
module_abspath,
|
missing_ok=True,
|
||||||
qualified_names,
|
)
|
||||||
) in helper_functions_by_module_abspath.items():
|
logging.info(f"Optimized candidate {j}/{len(candidates)}:")
|
||||||
|
logging.info(candidate.source_code)
|
||||||
|
try:
|
||||||
replace_function_definitions_in_module(
|
replace_function_definitions_in_module(
|
||||||
function_names=list(qualified_names),
|
function_names=[function_to_optimize.qualified_name],
|
||||||
optimized_code=candidate.source_code,
|
optimized_code=candidate.source_code,
|
||||||
file_path_of_module_with_function_to_optimize=function_to_optimize.file_path,
|
file_path_of_module_with_function_to_optimize=function_to_optimize.file_path,
|
||||||
module_abspath=module_abspath,
|
module_abspath=function_to_optimize.file_path,
|
||||||
preexisting_functions=[],
|
preexisting_functions=code_context.preexisting_functions,
|
||||||
contextual_functions=code_context.contextual_dunder_methods,
|
contextual_functions=code_context.contextual_dunder_methods,
|
||||||
project_root_path=self.args.project_root,
|
project_root_path=self.args.project_root,
|
||||||
)
|
)
|
||||||
except (
|
for (
|
||||||
ValueError,
|
module_abspath,
|
||||||
SyntaxError,
|
qualified_names,
|
||||||
cst.ParserSyntaxError,
|
) in helper_functions_by_module_abspath.items():
|
||||||
AttributeError,
|
replace_function_definitions_in_module(
|
||||||
) as e:
|
function_names=list(qualified_names),
|
||||||
logging.error(e) # noqa: TRY400
|
optimized_code=candidate.source_code,
|
||||||
|
file_path_of_module_with_function_to_optimize=function_to_optimize.file_path,
|
||||||
|
module_abspath=module_abspath,
|
||||||
|
preexisting_functions=[],
|
||||||
|
contextual_functions=code_context.contextual_dunder_methods,
|
||||||
|
project_root_path=self.args.project_root,
|
||||||
|
)
|
||||||
|
except (
|
||||||
|
ValueError,
|
||||||
|
SyntaxError,
|
||||||
|
cst.ParserSyntaxError,
|
||||||
|
AttributeError,
|
||||||
|
) as e:
|
||||||
|
logging.error(e) # noqa: TRY400
|
||||||
|
self.write_code_and_helpers(
|
||||||
|
original_code,
|
||||||
|
original_helper_code,
|
||||||
|
function_to_optimize.file_path,
|
||||||
|
helper_functions_by_module_abspath,
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Run generated tests if at least one of them passed
|
||||||
|
run_generated_tests = False
|
||||||
|
if original_code_baseline.generated_test_results:
|
||||||
|
for test_result in original_code_baseline.generated_test_results.test_results:
|
||||||
|
if test_result.did_pass:
|
||||||
|
run_generated_tests = True
|
||||||
|
break
|
||||||
|
|
||||||
|
run_results = self.run_optimized_candidate(
|
||||||
|
optimization_index=j,
|
||||||
|
instrumented_unittests_created_for_function=instrumented_unittests_created_for_function,
|
||||||
|
overall_original_test_results=original_code_baseline.overall_test_results,
|
||||||
|
original_existing_test_results=original_code_baseline.existing_test_results,
|
||||||
|
original_generated_test_results=original_code_baseline.generated_test_results,
|
||||||
|
generated_tests_path=generated_tests_path,
|
||||||
|
best_runtime_until_now=best_runtime_until_now,
|
||||||
|
tests_in_file=only_run_this_test_function,
|
||||||
|
run_generated_tests=run_generated_tests,
|
||||||
|
)
|
||||||
|
if not is_successful(run_results):
|
||||||
|
optimized_runtimes[candidate.optimization_id] = None
|
||||||
|
is_correct[candidate.optimization_id] = False
|
||||||
|
speedup_ratios[candidate.optimization_id] = None
|
||||||
|
else:
|
||||||
|
candidate_result: OptimizedCandidateResult = run_results.unwrap()
|
||||||
|
best_test_runtime = candidate_result.best_test_runtime
|
||||||
|
optimized_runtimes[candidate.optimization_id] = best_test_runtime
|
||||||
|
is_correct[candidate.optimization_id] = True
|
||||||
|
speedup_ratios[candidate.optimization_id] = (
|
||||||
|
original_code_baseline.runtime - best_test_runtime
|
||||||
|
) / best_test_runtime
|
||||||
|
logging.info(
|
||||||
|
f"Candidate runtime measured over {candidate_result.times_run} run{'s' if candidate_result.times_run > 1 else ''}: "
|
||||||
|
f"{humanize_runtime(best_test_runtime)}, speedup ratio = "
|
||||||
|
f"{((original_code_baseline.runtime - best_test_runtime) / best_test_runtime):.3f}",
|
||||||
|
)
|
||||||
|
|
||||||
|
if speedup_critic(candidate_result, original_code_baseline.runtime, best_runtime_until_now):
|
||||||
|
best_optimization = BestOptimization(
|
||||||
|
candidate=candidate,
|
||||||
|
helper_functions=code_context.helper_functions,
|
||||||
|
runtime=best_test_runtime,
|
||||||
|
winning_test_results=candidate_result.best_test_results,
|
||||||
|
)
|
||||||
|
best_runtime_until_now = best_test_runtime
|
||||||
|
|
||||||
self.write_code_and_helpers(
|
self.write_code_and_helpers(
|
||||||
original_code,
|
original_code,
|
||||||
original_helper_code,
|
original_helper_code,
|
||||||
function_to_optimize.file_path,
|
function_to_optimize.file_path,
|
||||||
helper_functions_by_module_abspath,
|
helper_functions_by_module_abspath,
|
||||||
)
|
)
|
||||||
continue
|
logging.info("----------------")
|
||||||
|
except KeyboardInterrupt as e:
|
||||||
# Run generated tests if at least one of them passed
|
|
||||||
run_generated_tests = False
|
|
||||||
if original_code_baseline.generated_test_results:
|
|
||||||
for test_result in original_code_baseline.generated_test_results.test_results:
|
|
||||||
if test_result.did_pass:
|
|
||||||
run_generated_tests = True
|
|
||||||
break
|
|
||||||
|
|
||||||
run_results = self.run_optimized_candidate(
|
|
||||||
optimization_index=j,
|
|
||||||
instrumented_unittests_created_for_function=instrumented_unittests_created_for_function,
|
|
||||||
overall_original_test_results=original_code_baseline.overall_test_results,
|
|
||||||
original_existing_test_results=original_code_baseline.existing_test_results,
|
|
||||||
original_generated_test_results=original_code_baseline.generated_test_results,
|
|
||||||
generated_tests_path=generated_tests_path,
|
|
||||||
best_runtime_until_now=best_runtime_until_now,
|
|
||||||
tests_in_file=only_run_this_test_function,
|
|
||||||
run_generated_tests=run_generated_tests,
|
|
||||||
)
|
|
||||||
if not is_successful(run_results):
|
|
||||||
optimized_runtimes[candidate.optimization_id] = None
|
|
||||||
is_correct[candidate.optimization_id] = False
|
|
||||||
speedup_ratios[candidate.optimization_id] = None
|
|
||||||
else:
|
|
||||||
candidate_result: OptimizedCandidateResult = run_results.unwrap()
|
|
||||||
best_test_runtime = candidate_result.best_test_runtime
|
|
||||||
optimized_runtimes[candidate.optimization_id] = best_test_runtime
|
|
||||||
is_correct[candidate.optimization_id] = True
|
|
||||||
speedup_ratios[candidate.optimization_id] = (
|
|
||||||
original_code_baseline.runtime - best_test_runtime
|
|
||||||
) / best_test_runtime
|
|
||||||
logging.info(
|
|
||||||
f"Candidate runtime measured over {candidate_result.times_run} run{'s' if candidate_result.times_run > 1 else ''}: "
|
|
||||||
f"{humanize_runtime(best_test_runtime)}, speedup ratio = "
|
|
||||||
f"{((original_code_baseline.runtime - best_test_runtime) / best_test_runtime):.3f}",
|
|
||||||
)
|
|
||||||
|
|
||||||
if speedup_critic(candidate_result, original_code_baseline.runtime, best_runtime_until_now):
|
|
||||||
best_optimization = BestOptimization(
|
|
||||||
candidate=candidate,
|
|
||||||
helper_functions=code_context.helper_functions,
|
|
||||||
runtime=best_test_runtime,
|
|
||||||
winning_test_results=candidate_result.best_test_results,
|
|
||||||
)
|
|
||||||
best_runtime_until_now = best_test_runtime
|
|
||||||
|
|
||||||
self.write_code_and_helpers(
|
self.write_code_and_helpers(
|
||||||
original_code,
|
original_code,
|
||||||
original_helper_code,
|
original_helper_code,
|
||||||
function_to_optimize.file_path,
|
function_to_optimize.file_path,
|
||||||
helper_functions_by_module_abspath,
|
helper_functions_by_module_abspath,
|
||||||
)
|
)
|
||||||
logging.info("----------------")
|
logging.error(f"Optimization interrupted: {e}")
|
||||||
|
raise e
|
||||||
|
|
||||||
self.aiservice_client.log_results(
|
self.aiservice_client.log_results(
|
||||||
function_trace_id=function_trace_id,
|
function_trace_id=function_trace_id,
|
||||||
speedup_ratio=speedup_ratios,
|
speedup_ratio=speedup_ratios,
|
||||||
|
|
|
||||||
|
|
@ -4,10 +4,8 @@ import os
|
||||||
from argparse import Namespace
|
from argparse import Namespace
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from returns.pipeline import is_successful
|
from codeflash.code_utils.code_replacer import replace_functions_and_add_imports
|
||||||
|
from codeflash.discovery.functions_to_optimize import FunctionParent, FunctionToOptimize
|
||||||
from codeflash.code_utils.code_replacer import replace_functions_in_file
|
|
||||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize, FunctionParent
|
|
||||||
from codeflash.optimization.optimizer import Optimizer
|
from codeflash.optimization.optimizer import Optimizer
|
||||||
|
|
||||||
os.environ["CODEFLASH_API_KEY"] = "cf-test-key"
|
os.environ["CODEFLASH_API_KEY"] = "cf-test-key"
|
||||||
|
|
@ -37,9 +35,7 @@ class NewClass:
|
||||||
|
|
||||||
print("Hello world")
|
print("Hello world")
|
||||||
"""
|
"""
|
||||||
expected = """import libcst as cst
|
expected = """class NewClass:
|
||||||
from typing import Optional
|
|
||||||
class NewClass:
|
|
||||||
def __init__(self, name):
|
def __init__(self, name):
|
||||||
self.name = name
|
self.name = name
|
||||||
def new_function(self, value):
|
def new_function(self, value):
|
||||||
|
|
@ -56,12 +52,15 @@ print("Hello world")
|
||||||
function_name: str = "NewClass.new_function"
|
function_name: str = "NewClass.new_function"
|
||||||
preexisting_functions: list[str] = ["new_function"]
|
preexisting_functions: list[str] = ["new_function"]
|
||||||
contextual_functions: set[tuple[str, str]] = {("NewClass", "__init__")}
|
contextual_functions: set[tuple[str, str]] = {("NewClass", "__init__")}
|
||||||
new_code: str = replace_functions_in_file(
|
new_code: str = replace_functions_and_add_imports(
|
||||||
original_code,
|
source_code=original_code,
|
||||||
[function_name],
|
function_names=[function_name],
|
||||||
optim_code,
|
optimized_code=optim_code,
|
||||||
preexisting_functions,
|
file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()),
|
||||||
contextual_functions,
|
module_abspath=str(Path(__file__).resolve()),
|
||||||
|
preexisting_functions=preexisting_functions,
|
||||||
|
contextual_functions=contextual_functions,
|
||||||
|
project_root_path=str(Path(__file__).resolve().parent.resolve()),
|
||||||
)
|
)
|
||||||
assert new_code == expected
|
assert new_code == expected
|
||||||
|
|
||||||
|
|
@ -95,9 +94,7 @@ class NewClass:
|
||||||
|
|
||||||
print("Hello world")
|
print("Hello world")
|
||||||
"""
|
"""
|
||||||
expected = """import libcst as cst
|
expected = """from OtherModule import other_function
|
||||||
from typing import Optional
|
|
||||||
from OtherModule import other_function
|
|
||||||
|
|
||||||
class NewClass:
|
class NewClass:
|
||||||
def __init__(self, name):
|
def __init__(self, name):
|
||||||
|
|
@ -116,12 +113,15 @@ print("Hello world")
|
||||||
function_name: str = "NewClass.new_function"
|
function_name: str = "NewClass.new_function"
|
||||||
preexisting_functions: list[str] = ["new_function", "other_function"]
|
preexisting_functions: list[str] = ["new_function", "other_function"]
|
||||||
contextual_functions: set[tuple[str, str]] = {("NewClass", "__init__")}
|
contextual_functions: set[tuple[str, str]] = {("NewClass", "__init__")}
|
||||||
new_code: str = replace_functions_in_file(
|
new_code: str = replace_functions_and_add_imports(
|
||||||
original_code,
|
source_code=original_code,
|
||||||
[function_name],
|
function_names=[function_name],
|
||||||
optim_code,
|
optimized_code=optim_code,
|
||||||
preexisting_functions,
|
file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()),
|
||||||
contextual_functions,
|
module_abspath=str(Path(__file__).resolve()),
|
||||||
|
preexisting_functions=preexisting_functions,
|
||||||
|
contextual_functions=contextual_functions,
|
||||||
|
project_root_path=str(Path(__file__).resolve().parent.resolve()),
|
||||||
)
|
)
|
||||||
assert new_code == expected
|
assert new_code == expected
|
||||||
|
|
||||||
|
|
@ -139,7 +139,7 @@ def other_function(st):
|
||||||
class NewClass:
|
class NewClass:
|
||||||
def __init__(self, name):
|
def __init__(self, name):
|
||||||
self.name = name
|
self.name = name
|
||||||
def new_function(self, value):
|
def new_function(self, value: cst.Name):
|
||||||
return other_function(self.name)
|
return other_function(self.name)
|
||||||
def new_function2(value):
|
def new_function2(value):
|
||||||
return value
|
return value
|
||||||
|
|
@ -158,10 +158,7 @@ def other_function(st):
|
||||||
|
|
||||||
print("Salut monde")
|
print("Salut monde")
|
||||||
"""
|
"""
|
||||||
expected = """import libcst as cst
|
expected = """from typing import Mandatory
|
||||||
from typing import Optional
|
|
||||||
import libcst as cst
|
|
||||||
from typing import Mandatory
|
|
||||||
|
|
||||||
print("Au revoir")
|
print("Au revoir")
|
||||||
|
|
||||||
|
|
@ -177,12 +174,15 @@ print("Salut monde")
|
||||||
function_names: list[str] = ["module.other_function"]
|
function_names: list[str] = ["module.other_function"]
|
||||||
preexisting_functions: list[str] = []
|
preexisting_functions: list[str] = []
|
||||||
contextual_functions: set[tuple[str, str]] = set()
|
contextual_functions: set[tuple[str, str]] = set()
|
||||||
new_code: str = replace_functions_in_file(
|
new_code: str = replace_functions_and_add_imports(
|
||||||
original_code,
|
source_code=original_code,
|
||||||
function_names,
|
function_names=function_names,
|
||||||
optim_code,
|
optimized_code=optim_code,
|
||||||
preexisting_functions,
|
file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()),
|
||||||
contextual_functions,
|
module_abspath=str(Path(__file__).resolve()),
|
||||||
|
preexisting_functions=preexisting_functions,
|
||||||
|
contextual_functions=contextual_functions,
|
||||||
|
project_root_path=str(Path(__file__).resolve().parent.resolve()),
|
||||||
)
|
)
|
||||||
assert new_code == expected
|
assert new_code == expected
|
||||||
|
|
||||||
|
|
@ -194,7 +194,7 @@ from typing import Optional
|
||||||
def totally_new_function(value):
|
def totally_new_function(value):
|
||||||
return value
|
return value
|
||||||
|
|
||||||
def yet_another_function(values):
|
def yet_another_function(values: Optional[str]):
|
||||||
return len(values) + 2
|
return len(values) + 2
|
||||||
|
|
||||||
def other_function(st):
|
def other_function(st):
|
||||||
|
|
@ -222,14 +222,11 @@ def other_function(st):
|
||||||
|
|
||||||
print("Salut monde")
|
print("Salut monde")
|
||||||
"""
|
"""
|
||||||
expected = """import libcst as cst
|
expected = """from typing import Optional, Mandatory
|
||||||
from typing import Optional
|
|
||||||
import libcst as cst
|
|
||||||
from typing import Mandatory
|
|
||||||
|
|
||||||
print("Au revoir")
|
print("Au revoir")
|
||||||
|
|
||||||
def yet_another_function(values):
|
def yet_another_function(values: Optional[str]):
|
||||||
return len(values) + 2
|
return len(values) + 2
|
||||||
|
|
||||||
def other_function(st):
|
def other_function(st):
|
||||||
|
|
@ -241,12 +238,15 @@ print("Salut monde")
|
||||||
function_names: list[str] = ["module.yet_another_function", "module.other_function"]
|
function_names: list[str] = ["module.yet_another_function", "module.other_function"]
|
||||||
preexisting_functions: list[str] = []
|
preexisting_functions: list[str] = []
|
||||||
contextual_functions: set[tuple[str, str]] = set()
|
contextual_functions: set[tuple[str, str]] = set()
|
||||||
new_code: str = replace_functions_in_file(
|
new_code: str = replace_functions_and_add_imports(
|
||||||
original_code,
|
source_code=original_code,
|
||||||
function_names,
|
function_names=function_names,
|
||||||
optim_code,
|
optimized_code=optim_code,
|
||||||
preexisting_functions,
|
file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()),
|
||||||
contextual_functions,
|
module_abspath=str(Path(__file__).resolve()),
|
||||||
|
preexisting_functions=preexisting_functions,
|
||||||
|
contextual_functions=contextual_functions,
|
||||||
|
project_root_path=str(Path(__file__).resolve().parent.resolve()),
|
||||||
)
|
)
|
||||||
assert new_code == expected
|
assert new_code == expected
|
||||||
|
|
||||||
|
|
@ -291,12 +291,15 @@ def supersort(doink):
|
||||||
function_names: list[str] = ["sorter_deps"]
|
function_names: list[str] = ["sorter_deps"]
|
||||||
preexisting_functions: list[str] = ["sorter_deps"]
|
preexisting_functions: list[str] = ["sorter_deps"]
|
||||||
contextual_functions: set[tuple[str, str]] = set()
|
contextual_functions: set[tuple[str, str]] = set()
|
||||||
new_code: str = replace_functions_in_file(
|
new_code: str = replace_functions_and_add_imports(
|
||||||
original_code,
|
source_code=original_code,
|
||||||
function_names,
|
function_names=function_names,
|
||||||
optim_code,
|
optimized_code=optim_code,
|
||||||
preexisting_functions,
|
file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()),
|
||||||
contextual_functions,
|
module_abspath=str(Path(__file__).resolve()),
|
||||||
|
preexisting_functions=preexisting_functions,
|
||||||
|
contextual_functions=contextual_functions,
|
||||||
|
project_root_path=str(Path(__file__).resolve().parent.resolve()),
|
||||||
)
|
)
|
||||||
assert new_code == expected
|
assert new_code == expected
|
||||||
|
|
||||||
|
|
@ -338,10 +341,7 @@ def blab(st):
|
||||||
|
|
||||||
print("Not cool")
|
print("Not cool")
|
||||||
"""
|
"""
|
||||||
expected_main = """import libcst as cst
|
expected_main = """from typing import Mandatory
|
||||||
from typing import Optional
|
|
||||||
import libcst as cst
|
|
||||||
from typing import Mandatory
|
|
||||||
from helper import blob
|
from helper import blob
|
||||||
|
|
||||||
print("Au revoir")
|
print("Au revoir")
|
||||||
|
|
@ -355,9 +355,7 @@ def other_function(st):
|
||||||
print("Salut monde")
|
print("Salut monde")
|
||||||
"""
|
"""
|
||||||
|
|
||||||
expected_helper = """import libcst as cst
|
expected_helper = """import numpy as np
|
||||||
from typing import Optional
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
print("Cool")
|
print("Cool")
|
||||||
|
|
||||||
|
|
@ -369,21 +367,27 @@ def blab(st):
|
||||||
|
|
||||||
print("Not cool")
|
print("Not cool")
|
||||||
"""
|
"""
|
||||||
new_main_code: str = replace_functions_in_file(
|
new_main_code: str = replace_functions_and_add_imports(
|
||||||
original_code_main,
|
source_code=original_code_main,
|
||||||
["other_function"],
|
function_names=["other_function"],
|
||||||
optim_code,
|
optimized_code=optim_code,
|
||||||
["other_function", "yet_another_function", "blob"],
|
file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()),
|
||||||
set(),
|
module_abspath=str(Path(__file__).resolve()),
|
||||||
|
preexisting_functions=["other_function", "yet_another_function", "blob"],
|
||||||
|
contextual_functions=set(),
|
||||||
|
project_root_path=str(Path(__file__).resolve().parent.resolve()),
|
||||||
)
|
)
|
||||||
assert new_main_code == expected_main
|
assert new_main_code == expected_main
|
||||||
|
|
||||||
new_helper_code: str = replace_functions_in_file(
|
new_helper_code: str = replace_functions_and_add_imports(
|
||||||
original_code_helper,
|
source_code=original_code_helper,
|
||||||
["blob"],
|
function_names=["blob"],
|
||||||
optim_code,
|
optimized_code=optim_code,
|
||||||
[],
|
file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()),
|
||||||
set(),
|
module_abspath=str(Path(__file__).resolve()),
|
||||||
|
preexisting_functions=[],
|
||||||
|
contextual_functions=set(),
|
||||||
|
project_root_path=str(Path(__file__).resolve().parent.resolve()),
|
||||||
)
|
)
|
||||||
assert new_helper_code == expected_helper
|
assert new_helper_code == expected_helper
|
||||||
|
|
||||||
|
|
@ -578,12 +582,15 @@ class CacheConfig(BaseConfig):
|
||||||
("CacheConfig", "__init__"),
|
("CacheConfig", "__init__"),
|
||||||
("CacheInitConfig", "__init__"),
|
("CacheInitConfig", "__init__"),
|
||||||
}
|
}
|
||||||
new_code: str = replace_functions_in_file(
|
new_code: str = replace_functions_and_add_imports(
|
||||||
original_code,
|
source_code=original_code,
|
||||||
function_names,
|
function_names=function_names,
|
||||||
optim_code,
|
optimized_code=optim_code,
|
||||||
preexisting_functions,
|
file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()),
|
||||||
contextual_functions,
|
module_abspath=str(Path(__file__).resolve()),
|
||||||
|
preexisting_functions=preexisting_functions,
|
||||||
|
contextual_functions=contextual_functions,
|
||||||
|
project_root_path=str(Path(__file__).resolve().parent.resolve()),
|
||||||
)
|
)
|
||||||
assert new_code == expected
|
assert new_code == expected
|
||||||
|
|
||||||
|
|
@ -649,12 +656,15 @@ def test_test_libcst_code_replacement8() -> None:
|
||||||
"_hamming_distance",
|
"_hamming_distance",
|
||||||
]
|
]
|
||||||
contextual_functions: set[tuple[str, str]] = set()
|
contextual_functions: set[tuple[str, str]] = set()
|
||||||
new_code: str = replace_functions_in_file(
|
new_code: str = replace_functions_and_add_imports(
|
||||||
original_code,
|
source_code=original_code,
|
||||||
function_names,
|
function_names=function_names,
|
||||||
optim_code,
|
optimized_code=optim_code,
|
||||||
preexisting_functions,
|
file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()),
|
||||||
contextual_functions,
|
module_abspath=str(Path(__file__).resolve()),
|
||||||
|
preexisting_functions=preexisting_functions,
|
||||||
|
contextual_functions=contextual_functions,
|
||||||
|
project_root_path=str(Path(__file__).resolve().parent.resolve()),
|
||||||
)
|
)
|
||||||
assert new_code == expected
|
assert new_code == expected
|
||||||
|
|
||||||
|
|
@ -663,7 +673,7 @@ def test_test_libcst_code_replacement9() -> None:
|
||||||
optim_code = """import libcst as cst
|
optim_code = """import libcst as cst
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
def totally_new_function(value):
|
def totally_new_function(value: Optional[str]):
|
||||||
return value
|
return value
|
||||||
|
|
||||||
class NewClass:
|
class NewClass:
|
||||||
|
|
@ -672,7 +682,7 @@ class NewClass:
|
||||||
def __call__(self, value):
|
def __call__(self, value):
|
||||||
return self.name
|
return self.name
|
||||||
def new_function2(value):
|
def new_function2(value):
|
||||||
return value
|
return cst.ensure_type(value, str)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
original_code = """class NewClass:
|
original_code = """class NewClass:
|
||||||
|
|
@ -685,15 +695,16 @@ print("Hello world")
|
||||||
"""
|
"""
|
||||||
expected = """import libcst as cst
|
expected = """import libcst as cst
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
class NewClass:
|
class NewClass:
|
||||||
def __init__(self, name):
|
def __init__(self, name):
|
||||||
self.name = str(name)
|
self.name = str(name)
|
||||||
def __call__(self, value):
|
def __call__(self, value):
|
||||||
return "I am still old"
|
return "I am still old"
|
||||||
def new_function2(value):
|
def new_function2(value):
|
||||||
return value
|
return cst.ensure_type(value, str)
|
||||||
|
|
||||||
def totally_new_function(value):
|
def totally_new_function(value: Optional[str]):
|
||||||
return value
|
return value
|
||||||
|
|
||||||
print("Hello world")
|
print("Hello world")
|
||||||
|
|
@ -705,12 +716,15 @@ print("Hello world")
|
||||||
("NewClass", "__init__"),
|
("NewClass", "__init__"),
|
||||||
("NewClass", "__call__"),
|
("NewClass", "__call__"),
|
||||||
}
|
}
|
||||||
new_code: str = replace_functions_in_file(
|
new_code: str = replace_functions_and_add_imports(
|
||||||
original_code,
|
source_code=original_code,
|
||||||
[function_name],
|
function_names=[function_name],
|
||||||
optim_code,
|
optimized_code=optim_code,
|
||||||
preexisting_functions,
|
file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()),
|
||||||
contextual_functions,
|
module_abspath=str(Path(__file__).resolve()),
|
||||||
|
preexisting_functions=preexisting_functions,
|
||||||
|
contextual_functions=contextual_functions,
|
||||||
|
project_root_path=str(Path(__file__).resolve().parent.resolve()),
|
||||||
)
|
)
|
||||||
assert new_code == expected
|
assert new_code == expected
|
||||||
|
|
||||||
|
|
@ -764,11 +778,16 @@ class MainClass:
|
||||||
experiment_id=None,
|
experiment_id=None,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
func_top_optimize = FunctionToOptimize(function_name="main_method", file_path=str(file_path),
|
func_top_optimize = FunctionToOptimize(
|
||||||
parents=[FunctionParent("MainClass", "ClassDef")])
|
function_name="main_method",
|
||||||
|
file_path=str(file_path),
|
||||||
|
parents=[FunctionParent("MainClass", "ClassDef")],
|
||||||
|
)
|
||||||
with open(file_path) as f:
|
with open(file_path) as f:
|
||||||
original_code = f.read()
|
original_code = f.read()
|
||||||
code_context = opt.get_code_optimization_context(function_to_optimize=func_top_optimize,
|
code_context = opt.get_code_optimization_context(
|
||||||
project_root=str(file_path.parent),
|
function_to_optimize=func_top_optimize,
|
||||||
original_source_code=original_code).unwrap()
|
project_root=str(file_path.parent),
|
||||||
|
original_source_code=original_code,
|
||||||
|
).unwrap()
|
||||||
assert code_context.code_to_optimize_with_helpers == get_code_output
|
assert code_context.code_to_optimize_with_helpers == get_code_output
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue