mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
Remove the older imports addition logic, which should prevent duplicate imports.
Also reset the original file if there is a keyboard interrupt during the optimization loop.
This commit is contained in:
parent
96a2a02e26
commit
74db766a43
3 changed files with 243 additions and 204 deletions
|
|
@ -26,7 +26,6 @@ class OptimFunctionCollector(cst.CSTVisitor):
|
|||
self.optim_body: FunctionDef | None = None
|
||||
self.optim_new_class_functions: list[cst.FunctionDef] = []
|
||||
self.optim_new_functions: list[cst.FunctionDef] = []
|
||||
self.optim_imports: list[cst.SimpleStatementLine] = []
|
||||
self.preexisting_functions = preexisting_functions
|
||||
self.contextual_functions = contextual_functions.union(
|
||||
{(self.class_name, self.function_name)},
|
||||
|
|
@ -64,10 +63,6 @@ class OptimFunctionCollector(cst.CSTVisitor):
|
|||
):
|
||||
self.optim_new_class_functions.append(child_node)
|
||||
|
||||
def leave_SimpleStatementLine(self, original_node: cst.SimpleStatementLine) -> None:
|
||||
if isinstance(original_node.body[0], (cst.Import, cst.ImportFrom)):
|
||||
self.optim_imports.append(original_node)
|
||||
|
||||
|
||||
class OptimFunctionReplacer(cst.CSTTransformer):
|
||||
def __init__(
|
||||
|
|
@ -75,7 +70,6 @@ class OptimFunctionReplacer(cst.CSTTransformer):
|
|||
function_name: str,
|
||||
optim_body: cst.FunctionDef,
|
||||
optim_new_class_functions: list[cst.FunctionDef],
|
||||
optim_imports: list[cst.SimpleStatementLine],
|
||||
optim_new_functions: list[cst.FunctionDef],
|
||||
class_name: str | None = None,
|
||||
) -> None:
|
||||
|
|
@ -83,7 +77,6 @@ class OptimFunctionReplacer(cst.CSTTransformer):
|
|||
self.function_name = function_name
|
||||
self.optim_body = optim_body
|
||||
self.optim_new_class_functions = optim_new_class_functions
|
||||
self.optim_new_imports = optim_imports
|
||||
self.optim_new_functions = optim_new_functions
|
||||
self.class_name = class_name
|
||||
self.depth: int = 0
|
||||
|
|
@ -126,10 +119,7 @@ class OptimFunctionReplacer(cst.CSTTransformer):
|
|||
return updated_node
|
||||
|
||||
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
|
||||
if len(self.optim_new_imports) == 0:
|
||||
node = updated_node
|
||||
else:
|
||||
node = updated_node.with_changes(body=(*self.optim_new_imports, *updated_node.body))
|
||||
node = updated_node
|
||||
max_function_index = None
|
||||
class_index = None
|
||||
for index, _node in enumerate(node.body):
|
||||
|
|
@ -190,13 +180,11 @@ def replace_functions_in_file(
|
|||
continue
|
||||
if visitor.optim_body is None:
|
||||
raise ValueError(f"Did not find the function {function_name} in the optimized code")
|
||||
optim_imports: list[cst.SimpleStatementLine] = [] if i > 0 else visitor.optim_imports
|
||||
|
||||
transformer = OptimFunctionReplacer(
|
||||
visitor.function_name,
|
||||
visitor.optim_body,
|
||||
visitor.optim_new_class_functions,
|
||||
optim_imports,
|
||||
visitor.optim_new_functions,
|
||||
class_name=class_name,
|
||||
)
|
||||
|
|
@ -207,6 +195,31 @@ def replace_functions_in_file(
|
|||
return source_code
|
||||
|
||||
|
||||
def replace_functions_and_add_imports(
|
||||
source_code: str,
|
||||
function_names: list[str],
|
||||
optimized_code: str,
|
||||
file_path_of_module_with_function_to_optimize: str,
|
||||
module_abspath: str,
|
||||
preexisting_functions: list[str],
|
||||
contextual_functions: set[tuple[str, str]],
|
||||
project_root_path: str,
|
||||
) -> str:
|
||||
return add_needed_imports_from_module(
|
||||
optimized_code,
|
||||
replace_functions_in_file(
|
||||
source_code,
|
||||
function_names,
|
||||
optimized_code,
|
||||
preexisting_functions,
|
||||
contextual_functions,
|
||||
),
|
||||
file_path_of_module_with_function_to_optimize,
|
||||
module_abspath,
|
||||
project_root_path,
|
||||
)
|
||||
|
||||
|
||||
def replace_function_definitions_in_module(
|
||||
function_names: list[str],
|
||||
optimized_code: str,
|
||||
|
|
@ -219,17 +232,14 @@ def replace_function_definitions_in_module(
|
|||
file: IO[str]
|
||||
with open(module_abspath, encoding="utf8") as file:
|
||||
source_code: str = file.read()
|
||||
new_code: str = add_needed_imports_from_module(
|
||||
new_code: str = replace_functions_and_add_imports(
|
||||
source_code,
|
||||
function_names,
|
||||
optimized_code,
|
||||
replace_functions_in_file(
|
||||
source_code,
|
||||
function_names,
|
||||
optimized_code,
|
||||
preexisting_functions,
|
||||
contextual_functions,
|
||||
),
|
||||
file_path_of_module_with_function_to_optimize,
|
||||
module_abspath,
|
||||
preexisting_functions,
|
||||
contextual_functions,
|
||||
project_root_path,
|
||||
)
|
||||
with open(module_abspath, "w", encoding="utf8") as file:
|
||||
|
|
|
|||
|
|
@ -173,7 +173,6 @@ class Optimizer:
|
|||
elif self.args.all:
|
||||
logging.info("✨ All functions have been optimized! ✨")
|
||||
finally:
|
||||
# TODO @afik.cohen: Also revert the file/function being optimized if the process did not succeed
|
||||
for test_file in self.instrumented_unittests_created:
|
||||
pathlib.Path(test_file).unlink(missing_ok=True)
|
||||
for test_file in self.test_files_created:
|
||||
|
|
@ -384,110 +383,121 @@ class Optimizer:
|
|||
logging.info(
|
||||
f"Determining best optimized candidate (out of {len(candidates)}) for {function_to_optimize.qualified_name} ...",
|
||||
)
|
||||
for i, candidate in enumerate(candidates):
|
||||
j = i + 1
|
||||
if candidate.source_code is None:
|
||||
continue
|
||||
# remove left overs from previous run
|
||||
pathlib.Path(get_run_tmp_file(f"test_return_values_{j}.bin")).unlink(
|
||||
missing_ok=True,
|
||||
)
|
||||
pathlib.Path(get_run_tmp_file(f"test_return_values_{j}.sqlite")).unlink(
|
||||
missing_ok=True,
|
||||
)
|
||||
logging.info(f"Optimized candidate {j}/{len(candidates)}:")
|
||||
logging.info(candidate.source_code)
|
||||
try:
|
||||
replace_function_definitions_in_module(
|
||||
function_names=[function_to_optimize.qualified_name],
|
||||
optimized_code=candidate.source_code,
|
||||
file_path_of_module_with_function_to_optimize=function_to_optimize.file_path,
|
||||
module_abspath=function_to_optimize.file_path,
|
||||
preexisting_functions=code_context.preexisting_functions,
|
||||
contextual_functions=code_context.contextual_dunder_methods,
|
||||
project_root_path=self.args.project_root,
|
||||
try:
|
||||
for i, candidate in enumerate(candidates):
|
||||
j = i + 1
|
||||
if candidate.source_code is None:
|
||||
continue
|
||||
# remove left overs from previous run
|
||||
pathlib.Path(get_run_tmp_file(f"test_return_values_{j}.bin")).unlink(
|
||||
missing_ok=True,
|
||||
)
|
||||
for (
|
||||
module_abspath,
|
||||
qualified_names,
|
||||
) in helper_functions_by_module_abspath.items():
|
||||
pathlib.Path(get_run_tmp_file(f"test_return_values_{j}.sqlite")).unlink(
|
||||
missing_ok=True,
|
||||
)
|
||||
logging.info(f"Optimized candidate {j}/{len(candidates)}:")
|
||||
logging.info(candidate.source_code)
|
||||
try:
|
||||
replace_function_definitions_in_module(
|
||||
function_names=list(qualified_names),
|
||||
function_names=[function_to_optimize.qualified_name],
|
||||
optimized_code=candidate.source_code,
|
||||
file_path_of_module_with_function_to_optimize=function_to_optimize.file_path,
|
||||
module_abspath=module_abspath,
|
||||
preexisting_functions=[],
|
||||
module_abspath=function_to_optimize.file_path,
|
||||
preexisting_functions=code_context.preexisting_functions,
|
||||
contextual_functions=code_context.contextual_dunder_methods,
|
||||
project_root_path=self.args.project_root,
|
||||
)
|
||||
except (
|
||||
ValueError,
|
||||
SyntaxError,
|
||||
cst.ParserSyntaxError,
|
||||
AttributeError,
|
||||
) as e:
|
||||
logging.error(e) # noqa: TRY400
|
||||
for (
|
||||
module_abspath,
|
||||
qualified_names,
|
||||
) in helper_functions_by_module_abspath.items():
|
||||
replace_function_definitions_in_module(
|
||||
function_names=list(qualified_names),
|
||||
optimized_code=candidate.source_code,
|
||||
file_path_of_module_with_function_to_optimize=function_to_optimize.file_path,
|
||||
module_abspath=module_abspath,
|
||||
preexisting_functions=[],
|
||||
contextual_functions=code_context.contextual_dunder_methods,
|
||||
project_root_path=self.args.project_root,
|
||||
)
|
||||
except (
|
||||
ValueError,
|
||||
SyntaxError,
|
||||
cst.ParserSyntaxError,
|
||||
AttributeError,
|
||||
) as e:
|
||||
logging.error(e) # noqa: TRY400
|
||||
self.write_code_and_helpers(
|
||||
original_code,
|
||||
original_helper_code,
|
||||
function_to_optimize.file_path,
|
||||
helper_functions_by_module_abspath,
|
||||
)
|
||||
continue
|
||||
|
||||
# Run generated tests if at least one of them passed
|
||||
run_generated_tests = False
|
||||
if original_code_baseline.generated_test_results:
|
||||
for test_result in original_code_baseline.generated_test_results.test_results:
|
||||
if test_result.did_pass:
|
||||
run_generated_tests = True
|
||||
break
|
||||
|
||||
run_results = self.run_optimized_candidate(
|
||||
optimization_index=j,
|
||||
instrumented_unittests_created_for_function=instrumented_unittests_created_for_function,
|
||||
overall_original_test_results=original_code_baseline.overall_test_results,
|
||||
original_existing_test_results=original_code_baseline.existing_test_results,
|
||||
original_generated_test_results=original_code_baseline.generated_test_results,
|
||||
generated_tests_path=generated_tests_path,
|
||||
best_runtime_until_now=best_runtime_until_now,
|
||||
tests_in_file=only_run_this_test_function,
|
||||
run_generated_tests=run_generated_tests,
|
||||
)
|
||||
if not is_successful(run_results):
|
||||
optimized_runtimes[candidate.optimization_id] = None
|
||||
is_correct[candidate.optimization_id] = False
|
||||
speedup_ratios[candidate.optimization_id] = None
|
||||
else:
|
||||
candidate_result: OptimizedCandidateResult = run_results.unwrap()
|
||||
best_test_runtime = candidate_result.best_test_runtime
|
||||
optimized_runtimes[candidate.optimization_id] = best_test_runtime
|
||||
is_correct[candidate.optimization_id] = True
|
||||
speedup_ratios[candidate.optimization_id] = (
|
||||
original_code_baseline.runtime - best_test_runtime
|
||||
) / best_test_runtime
|
||||
logging.info(
|
||||
f"Candidate runtime measured over {candidate_result.times_run} run{'s' if candidate_result.times_run > 1 else ''}: "
|
||||
f"{humanize_runtime(best_test_runtime)}, speedup ratio = "
|
||||
f"{((original_code_baseline.runtime - best_test_runtime) / best_test_runtime):.3f}",
|
||||
)
|
||||
|
||||
if speedup_critic(candidate_result, original_code_baseline.runtime, best_runtime_until_now):
|
||||
best_optimization = BestOptimization(
|
||||
candidate=candidate,
|
||||
helper_functions=code_context.helper_functions,
|
||||
runtime=best_test_runtime,
|
||||
winning_test_results=candidate_result.best_test_results,
|
||||
)
|
||||
best_runtime_until_now = best_test_runtime
|
||||
|
||||
self.write_code_and_helpers(
|
||||
original_code,
|
||||
original_helper_code,
|
||||
function_to_optimize.file_path,
|
||||
helper_functions_by_module_abspath,
|
||||
)
|
||||
continue
|
||||
|
||||
# Run generated tests if at least one of them passed
|
||||
run_generated_tests = False
|
||||
if original_code_baseline.generated_test_results:
|
||||
for test_result in original_code_baseline.generated_test_results.test_results:
|
||||
if test_result.did_pass:
|
||||
run_generated_tests = True
|
||||
break
|
||||
|
||||
run_results = self.run_optimized_candidate(
|
||||
optimization_index=j,
|
||||
instrumented_unittests_created_for_function=instrumented_unittests_created_for_function,
|
||||
overall_original_test_results=original_code_baseline.overall_test_results,
|
||||
original_existing_test_results=original_code_baseline.existing_test_results,
|
||||
original_generated_test_results=original_code_baseline.generated_test_results,
|
||||
generated_tests_path=generated_tests_path,
|
||||
best_runtime_until_now=best_runtime_until_now,
|
||||
tests_in_file=only_run_this_test_function,
|
||||
run_generated_tests=run_generated_tests,
|
||||
)
|
||||
if not is_successful(run_results):
|
||||
optimized_runtimes[candidate.optimization_id] = None
|
||||
is_correct[candidate.optimization_id] = False
|
||||
speedup_ratios[candidate.optimization_id] = None
|
||||
else:
|
||||
candidate_result: OptimizedCandidateResult = run_results.unwrap()
|
||||
best_test_runtime = candidate_result.best_test_runtime
|
||||
optimized_runtimes[candidate.optimization_id] = best_test_runtime
|
||||
is_correct[candidate.optimization_id] = True
|
||||
speedup_ratios[candidate.optimization_id] = (
|
||||
original_code_baseline.runtime - best_test_runtime
|
||||
) / best_test_runtime
|
||||
logging.info(
|
||||
f"Candidate runtime measured over {candidate_result.times_run} run{'s' if candidate_result.times_run > 1 else ''}: "
|
||||
f"{humanize_runtime(best_test_runtime)}, speedup ratio = "
|
||||
f"{((original_code_baseline.runtime - best_test_runtime) / best_test_runtime):.3f}",
|
||||
)
|
||||
|
||||
if speedup_critic(candidate_result, original_code_baseline.runtime, best_runtime_until_now):
|
||||
best_optimization = BestOptimization(
|
||||
candidate=candidate,
|
||||
helper_functions=code_context.helper_functions,
|
||||
runtime=best_test_runtime,
|
||||
winning_test_results=candidate_result.best_test_results,
|
||||
)
|
||||
best_runtime_until_now = best_test_runtime
|
||||
|
||||
logging.info("----------------")
|
||||
except KeyboardInterrupt as e:
|
||||
self.write_code_and_helpers(
|
||||
original_code,
|
||||
original_helper_code,
|
||||
function_to_optimize.file_path,
|
||||
helper_functions_by_module_abspath,
|
||||
)
|
||||
logging.info("----------------")
|
||||
logging.error(f"Optimization interrupted: {e}")
|
||||
raise e
|
||||
|
||||
self.aiservice_client.log_results(
|
||||
function_trace_id=function_trace_id,
|
||||
speedup_ratio=speedup_ratios,
|
||||
|
|
|
|||
|
|
@ -4,10 +4,8 @@ import os
|
|||
from argparse import Namespace
|
||||
from pathlib import Path
|
||||
|
||||
from returns.pipeline import is_successful
|
||||
|
||||
from codeflash.code_utils.code_replacer import replace_functions_in_file
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize, FunctionParent
|
||||
from codeflash.code_utils.code_replacer import replace_functions_and_add_imports
|
||||
from codeflash.discovery.functions_to_optimize import FunctionParent, FunctionToOptimize
|
||||
from codeflash.optimization.optimizer import Optimizer
|
||||
|
||||
os.environ["CODEFLASH_API_KEY"] = "cf-test-key"
|
||||
|
|
@ -37,9 +35,7 @@ class NewClass:
|
|||
|
||||
print("Hello world")
|
||||
"""
|
||||
expected = """import libcst as cst
|
||||
from typing import Optional
|
||||
class NewClass:
|
||||
expected = """class NewClass:
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
def new_function(self, value):
|
||||
|
|
@ -56,12 +52,15 @@ print("Hello world")
|
|||
function_name: str = "NewClass.new_function"
|
||||
preexisting_functions: list[str] = ["new_function"]
|
||||
contextual_functions: set[tuple[str, str]] = {("NewClass", "__init__")}
|
||||
new_code: str = replace_functions_in_file(
|
||||
original_code,
|
||||
[function_name],
|
||||
optim_code,
|
||||
preexisting_functions,
|
||||
contextual_functions,
|
||||
new_code: str = replace_functions_and_add_imports(
|
||||
source_code=original_code,
|
||||
function_names=[function_name],
|
||||
optimized_code=optim_code,
|
||||
file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()),
|
||||
module_abspath=str(Path(__file__).resolve()),
|
||||
preexisting_functions=preexisting_functions,
|
||||
contextual_functions=contextual_functions,
|
||||
project_root_path=str(Path(__file__).resolve().parent.resolve()),
|
||||
)
|
||||
assert new_code == expected
|
||||
|
||||
|
|
@ -95,9 +94,7 @@ class NewClass:
|
|||
|
||||
print("Hello world")
|
||||
"""
|
||||
expected = """import libcst as cst
|
||||
from typing import Optional
|
||||
from OtherModule import other_function
|
||||
expected = """from OtherModule import other_function
|
||||
|
||||
class NewClass:
|
||||
def __init__(self, name):
|
||||
|
|
@ -116,12 +113,15 @@ print("Hello world")
|
|||
function_name: str = "NewClass.new_function"
|
||||
preexisting_functions: list[str] = ["new_function", "other_function"]
|
||||
contextual_functions: set[tuple[str, str]] = {("NewClass", "__init__")}
|
||||
new_code: str = replace_functions_in_file(
|
||||
original_code,
|
||||
[function_name],
|
||||
optim_code,
|
||||
preexisting_functions,
|
||||
contextual_functions,
|
||||
new_code: str = replace_functions_and_add_imports(
|
||||
source_code=original_code,
|
||||
function_names=[function_name],
|
||||
optimized_code=optim_code,
|
||||
file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()),
|
||||
module_abspath=str(Path(__file__).resolve()),
|
||||
preexisting_functions=preexisting_functions,
|
||||
contextual_functions=contextual_functions,
|
||||
project_root_path=str(Path(__file__).resolve().parent.resolve()),
|
||||
)
|
||||
assert new_code == expected
|
||||
|
||||
|
|
@ -139,7 +139,7 @@ def other_function(st):
|
|||
class NewClass:
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
def new_function(self, value):
|
||||
def new_function(self, value: cst.Name):
|
||||
return other_function(self.name)
|
||||
def new_function2(value):
|
||||
return value
|
||||
|
|
@ -158,10 +158,7 @@ def other_function(st):
|
|||
|
||||
print("Salut monde")
|
||||
"""
|
||||
expected = """import libcst as cst
|
||||
from typing import Optional
|
||||
import libcst as cst
|
||||
from typing import Mandatory
|
||||
expected = """from typing import Mandatory
|
||||
|
||||
print("Au revoir")
|
||||
|
||||
|
|
@ -177,12 +174,15 @@ print("Salut monde")
|
|||
function_names: list[str] = ["module.other_function"]
|
||||
preexisting_functions: list[str] = []
|
||||
contextual_functions: set[tuple[str, str]] = set()
|
||||
new_code: str = replace_functions_in_file(
|
||||
original_code,
|
||||
function_names,
|
||||
optim_code,
|
||||
preexisting_functions,
|
||||
contextual_functions,
|
||||
new_code: str = replace_functions_and_add_imports(
|
||||
source_code=original_code,
|
||||
function_names=function_names,
|
||||
optimized_code=optim_code,
|
||||
file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()),
|
||||
module_abspath=str(Path(__file__).resolve()),
|
||||
preexisting_functions=preexisting_functions,
|
||||
contextual_functions=contextual_functions,
|
||||
project_root_path=str(Path(__file__).resolve().parent.resolve()),
|
||||
)
|
||||
assert new_code == expected
|
||||
|
||||
|
|
@ -194,7 +194,7 @@ from typing import Optional
|
|||
def totally_new_function(value):
|
||||
return value
|
||||
|
||||
def yet_another_function(values):
|
||||
def yet_another_function(values: Optional[str]):
|
||||
return len(values) + 2
|
||||
|
||||
def other_function(st):
|
||||
|
|
@ -222,14 +222,11 @@ def other_function(st):
|
|||
|
||||
print("Salut monde")
|
||||
"""
|
||||
expected = """import libcst as cst
|
||||
from typing import Optional
|
||||
import libcst as cst
|
||||
from typing import Mandatory
|
||||
expected = """from typing import Optional, Mandatory
|
||||
|
||||
print("Au revoir")
|
||||
|
||||
def yet_another_function(values):
|
||||
def yet_another_function(values: Optional[str]):
|
||||
return len(values) + 2
|
||||
|
||||
def other_function(st):
|
||||
|
|
@ -241,12 +238,15 @@ print("Salut monde")
|
|||
function_names: list[str] = ["module.yet_another_function", "module.other_function"]
|
||||
preexisting_functions: list[str] = []
|
||||
contextual_functions: set[tuple[str, str]] = set()
|
||||
new_code: str = replace_functions_in_file(
|
||||
original_code,
|
||||
function_names,
|
||||
optim_code,
|
||||
preexisting_functions,
|
||||
contextual_functions,
|
||||
new_code: str = replace_functions_and_add_imports(
|
||||
source_code=original_code,
|
||||
function_names=function_names,
|
||||
optimized_code=optim_code,
|
||||
file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()),
|
||||
module_abspath=str(Path(__file__).resolve()),
|
||||
preexisting_functions=preexisting_functions,
|
||||
contextual_functions=contextual_functions,
|
||||
project_root_path=str(Path(__file__).resolve().parent.resolve()),
|
||||
)
|
||||
assert new_code == expected
|
||||
|
||||
|
|
@ -291,12 +291,15 @@ def supersort(doink):
|
|||
function_names: list[str] = ["sorter_deps"]
|
||||
preexisting_functions: list[str] = ["sorter_deps"]
|
||||
contextual_functions: set[tuple[str, str]] = set()
|
||||
new_code: str = replace_functions_in_file(
|
||||
original_code,
|
||||
function_names,
|
||||
optim_code,
|
||||
preexisting_functions,
|
||||
contextual_functions,
|
||||
new_code: str = replace_functions_and_add_imports(
|
||||
source_code=original_code,
|
||||
function_names=function_names,
|
||||
optimized_code=optim_code,
|
||||
file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()),
|
||||
module_abspath=str(Path(__file__).resolve()),
|
||||
preexisting_functions=preexisting_functions,
|
||||
contextual_functions=contextual_functions,
|
||||
project_root_path=str(Path(__file__).resolve().parent.resolve()),
|
||||
)
|
||||
assert new_code == expected
|
||||
|
||||
|
|
@ -338,10 +341,7 @@ def blab(st):
|
|||
|
||||
print("Not cool")
|
||||
"""
|
||||
expected_main = """import libcst as cst
|
||||
from typing import Optional
|
||||
import libcst as cst
|
||||
from typing import Mandatory
|
||||
expected_main = """from typing import Mandatory
|
||||
from helper import blob
|
||||
|
||||
print("Au revoir")
|
||||
|
|
@ -355,9 +355,7 @@ def other_function(st):
|
|||
print("Salut monde")
|
||||
"""
|
||||
|
||||
expected_helper = """import libcst as cst
|
||||
from typing import Optional
|
||||
import numpy as np
|
||||
expected_helper = """import numpy as np
|
||||
|
||||
print("Cool")
|
||||
|
||||
|
|
@ -369,21 +367,27 @@ def blab(st):
|
|||
|
||||
print("Not cool")
|
||||
"""
|
||||
new_main_code: str = replace_functions_in_file(
|
||||
original_code_main,
|
||||
["other_function"],
|
||||
optim_code,
|
||||
["other_function", "yet_another_function", "blob"],
|
||||
set(),
|
||||
new_main_code: str = replace_functions_and_add_imports(
|
||||
source_code=original_code_main,
|
||||
function_names=["other_function"],
|
||||
optimized_code=optim_code,
|
||||
file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()),
|
||||
module_abspath=str(Path(__file__).resolve()),
|
||||
preexisting_functions=["other_function", "yet_another_function", "blob"],
|
||||
contextual_functions=set(),
|
||||
project_root_path=str(Path(__file__).resolve().parent.resolve()),
|
||||
)
|
||||
assert new_main_code == expected_main
|
||||
|
||||
new_helper_code: str = replace_functions_in_file(
|
||||
original_code_helper,
|
||||
["blob"],
|
||||
optim_code,
|
||||
[],
|
||||
set(),
|
||||
new_helper_code: str = replace_functions_and_add_imports(
|
||||
source_code=original_code_helper,
|
||||
function_names=["blob"],
|
||||
optimized_code=optim_code,
|
||||
file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()),
|
||||
module_abspath=str(Path(__file__).resolve()),
|
||||
preexisting_functions=[],
|
||||
contextual_functions=set(),
|
||||
project_root_path=str(Path(__file__).resolve().parent.resolve()),
|
||||
)
|
||||
assert new_helper_code == expected_helper
|
||||
|
||||
|
|
@ -578,12 +582,15 @@ class CacheConfig(BaseConfig):
|
|||
("CacheConfig", "__init__"),
|
||||
("CacheInitConfig", "__init__"),
|
||||
}
|
||||
new_code: str = replace_functions_in_file(
|
||||
original_code,
|
||||
function_names,
|
||||
optim_code,
|
||||
preexisting_functions,
|
||||
contextual_functions,
|
||||
new_code: str = replace_functions_and_add_imports(
|
||||
source_code=original_code,
|
||||
function_names=function_names,
|
||||
optimized_code=optim_code,
|
||||
file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()),
|
||||
module_abspath=str(Path(__file__).resolve()),
|
||||
preexisting_functions=preexisting_functions,
|
||||
contextual_functions=contextual_functions,
|
||||
project_root_path=str(Path(__file__).resolve().parent.resolve()),
|
||||
)
|
||||
assert new_code == expected
|
||||
|
||||
|
|
@ -649,12 +656,15 @@ def test_test_libcst_code_replacement8() -> None:
|
|||
"_hamming_distance",
|
||||
]
|
||||
contextual_functions: set[tuple[str, str]] = set()
|
||||
new_code: str = replace_functions_in_file(
|
||||
original_code,
|
||||
function_names,
|
||||
optim_code,
|
||||
preexisting_functions,
|
||||
contextual_functions,
|
||||
new_code: str = replace_functions_and_add_imports(
|
||||
source_code=original_code,
|
||||
function_names=function_names,
|
||||
optimized_code=optim_code,
|
||||
file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()),
|
||||
module_abspath=str(Path(__file__).resolve()),
|
||||
preexisting_functions=preexisting_functions,
|
||||
contextual_functions=contextual_functions,
|
||||
project_root_path=str(Path(__file__).resolve().parent.resolve()),
|
||||
)
|
||||
assert new_code == expected
|
||||
|
||||
|
|
@ -663,7 +673,7 @@ def test_test_libcst_code_replacement9() -> None:
|
|||
optim_code = """import libcst as cst
|
||||
from typing import Optional
|
||||
|
||||
def totally_new_function(value):
|
||||
def totally_new_function(value: Optional[str]):
|
||||
return value
|
||||
|
||||
class NewClass:
|
||||
|
|
@ -672,7 +682,7 @@ class NewClass:
|
|||
def __call__(self, value):
|
||||
return self.name
|
||||
def new_function2(value):
|
||||
return value
|
||||
return cst.ensure_type(value, str)
|
||||
"""
|
||||
|
||||
original_code = """class NewClass:
|
||||
|
|
@ -685,15 +695,16 @@ print("Hello world")
|
|||
"""
|
||||
expected = """import libcst as cst
|
||||
from typing import Optional
|
||||
|
||||
class NewClass:
|
||||
def __init__(self, name):
|
||||
self.name = str(name)
|
||||
def __call__(self, value):
|
||||
return "I am still old"
|
||||
def new_function2(value):
|
||||
return value
|
||||
return cst.ensure_type(value, str)
|
||||
|
||||
def totally_new_function(value):
|
||||
def totally_new_function(value: Optional[str]):
|
||||
return value
|
||||
|
||||
print("Hello world")
|
||||
|
|
@ -705,12 +716,15 @@ print("Hello world")
|
|||
("NewClass", "__init__"),
|
||||
("NewClass", "__call__"),
|
||||
}
|
||||
new_code: str = replace_functions_in_file(
|
||||
original_code,
|
||||
[function_name],
|
||||
optim_code,
|
||||
preexisting_functions,
|
||||
contextual_functions,
|
||||
new_code: str = replace_functions_and_add_imports(
|
||||
source_code=original_code,
|
||||
function_names=[function_name],
|
||||
optimized_code=optim_code,
|
||||
file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()),
|
||||
module_abspath=str(Path(__file__).resolve()),
|
||||
preexisting_functions=preexisting_functions,
|
||||
contextual_functions=contextual_functions,
|
||||
project_root_path=str(Path(__file__).resolve().parent.resolve()),
|
||||
)
|
||||
assert new_code == expected
|
||||
|
||||
|
|
@ -764,11 +778,16 @@ class MainClass:
|
|||
experiment_id=None,
|
||||
),
|
||||
)
|
||||
func_top_optimize = FunctionToOptimize(function_name="main_method", file_path=str(file_path),
|
||||
parents=[FunctionParent("MainClass", "ClassDef")])
|
||||
func_top_optimize = FunctionToOptimize(
|
||||
function_name="main_method",
|
||||
file_path=str(file_path),
|
||||
parents=[FunctionParent("MainClass", "ClassDef")],
|
||||
)
|
||||
with open(file_path) as f:
|
||||
original_code = f.read()
|
||||
code_context = opt.get_code_optimization_context(function_to_optimize=func_top_optimize,
|
||||
project_root=str(file_path.parent),
|
||||
original_source_code=original_code).unwrap()
|
||||
code_context = opt.get_code_optimization_context(
|
||||
function_to_optimize=func_top_optimize,
|
||||
project_root=str(file_path.parent),
|
||||
original_source_code=original_code,
|
||||
).unwrap()
|
||||
assert code_context.code_to_optimize_with_helpers == get_code_output
|
||||
|
|
|
|||
Loading…
Reference in a new issue