Remove the older imports addition logic, which should prevent duplicate imports.

Also reset the original file if there is a keyboard interrupt during the optimization loop.
This commit is contained in:
Saurabh Misra 2024-06-16 19:17:45 -07:00
parent 96a2a02e26
commit 74db766a43
3 changed files with 243 additions and 204 deletions

View file

@ -26,7 +26,6 @@ class OptimFunctionCollector(cst.CSTVisitor):
self.optim_body: FunctionDef | None = None self.optim_body: FunctionDef | None = None
self.optim_new_class_functions: list[cst.FunctionDef] = [] self.optim_new_class_functions: list[cst.FunctionDef] = []
self.optim_new_functions: list[cst.FunctionDef] = [] self.optim_new_functions: list[cst.FunctionDef] = []
self.optim_imports: list[cst.SimpleStatementLine] = []
self.preexisting_functions = preexisting_functions self.preexisting_functions = preexisting_functions
self.contextual_functions = contextual_functions.union( self.contextual_functions = contextual_functions.union(
{(self.class_name, self.function_name)}, {(self.class_name, self.function_name)},
@ -64,10 +63,6 @@ class OptimFunctionCollector(cst.CSTVisitor):
): ):
self.optim_new_class_functions.append(child_node) self.optim_new_class_functions.append(child_node)
def leave_SimpleStatementLine(self, original_node: cst.SimpleStatementLine) -> None:
if isinstance(original_node.body[0], (cst.Import, cst.ImportFrom)):
self.optim_imports.append(original_node)
class OptimFunctionReplacer(cst.CSTTransformer): class OptimFunctionReplacer(cst.CSTTransformer):
def __init__( def __init__(
@ -75,7 +70,6 @@ class OptimFunctionReplacer(cst.CSTTransformer):
function_name: str, function_name: str,
optim_body: cst.FunctionDef, optim_body: cst.FunctionDef,
optim_new_class_functions: list[cst.FunctionDef], optim_new_class_functions: list[cst.FunctionDef],
optim_imports: list[cst.SimpleStatementLine],
optim_new_functions: list[cst.FunctionDef], optim_new_functions: list[cst.FunctionDef],
class_name: str | None = None, class_name: str | None = None,
) -> None: ) -> None:
@ -83,7 +77,6 @@ class OptimFunctionReplacer(cst.CSTTransformer):
self.function_name = function_name self.function_name = function_name
self.optim_body = optim_body self.optim_body = optim_body
self.optim_new_class_functions = optim_new_class_functions self.optim_new_class_functions = optim_new_class_functions
self.optim_new_imports = optim_imports
self.optim_new_functions = optim_new_functions self.optim_new_functions = optim_new_functions
self.class_name = class_name self.class_name = class_name
self.depth: int = 0 self.depth: int = 0
@ -126,10 +119,7 @@ class OptimFunctionReplacer(cst.CSTTransformer):
return updated_node return updated_node
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
if len(self.optim_new_imports) == 0: node = updated_node
node = updated_node
else:
node = updated_node.with_changes(body=(*self.optim_new_imports, *updated_node.body))
max_function_index = None max_function_index = None
class_index = None class_index = None
for index, _node in enumerate(node.body): for index, _node in enumerate(node.body):
@ -190,13 +180,11 @@ def replace_functions_in_file(
continue continue
if visitor.optim_body is None: if visitor.optim_body is None:
raise ValueError(f"Did not find the function {function_name} in the optimized code") raise ValueError(f"Did not find the function {function_name} in the optimized code")
optim_imports: list[cst.SimpleStatementLine] = [] if i > 0 else visitor.optim_imports
transformer = OptimFunctionReplacer( transformer = OptimFunctionReplacer(
visitor.function_name, visitor.function_name,
visitor.optim_body, visitor.optim_body,
visitor.optim_new_class_functions, visitor.optim_new_class_functions,
optim_imports,
visitor.optim_new_functions, visitor.optim_new_functions,
class_name=class_name, class_name=class_name,
) )
@ -207,6 +195,31 @@ def replace_functions_in_file(
return source_code return source_code
def replace_functions_and_add_imports(
source_code: str,
function_names: list[str],
optimized_code: str,
file_path_of_module_with_function_to_optimize: str,
module_abspath: str,
preexisting_functions: list[str],
contextual_functions: set[tuple[str, str]],
project_root_path: str,
) -> str:
return add_needed_imports_from_module(
optimized_code,
replace_functions_in_file(
source_code,
function_names,
optimized_code,
preexisting_functions,
contextual_functions,
),
file_path_of_module_with_function_to_optimize,
module_abspath,
project_root_path,
)
def replace_function_definitions_in_module( def replace_function_definitions_in_module(
function_names: list[str], function_names: list[str],
optimized_code: str, optimized_code: str,
@ -219,17 +232,14 @@ def replace_function_definitions_in_module(
file: IO[str] file: IO[str]
with open(module_abspath, encoding="utf8") as file: with open(module_abspath, encoding="utf8") as file:
source_code: str = file.read() source_code: str = file.read()
new_code: str = add_needed_imports_from_module( new_code: str = replace_functions_and_add_imports(
source_code,
function_names,
optimized_code, optimized_code,
replace_functions_in_file(
source_code,
function_names,
optimized_code,
preexisting_functions,
contextual_functions,
),
file_path_of_module_with_function_to_optimize, file_path_of_module_with_function_to_optimize,
module_abspath, module_abspath,
preexisting_functions,
contextual_functions,
project_root_path, project_root_path,
) )
with open(module_abspath, "w", encoding="utf8") as file: with open(module_abspath, "w", encoding="utf8") as file:

View file

@ -173,7 +173,6 @@ class Optimizer:
elif self.args.all: elif self.args.all:
logging.info("✨ All functions have been optimized! ✨") logging.info("✨ All functions have been optimized! ✨")
finally: finally:
# TODO @afik.cohen: Also revert the file/function being optimized if the process did not succeed
for test_file in self.instrumented_unittests_created: for test_file in self.instrumented_unittests_created:
pathlib.Path(test_file).unlink(missing_ok=True) pathlib.Path(test_file).unlink(missing_ok=True)
for test_file in self.test_files_created: for test_file in self.test_files_created:
@ -384,110 +383,121 @@ class Optimizer:
logging.info( logging.info(
f"Determining best optimized candidate (out of {len(candidates)}) for {function_to_optimize.qualified_name} ...", f"Determining best optimized candidate (out of {len(candidates)}) for {function_to_optimize.qualified_name} ...",
) )
for i, candidate in enumerate(candidates): try:
j = i + 1 for i, candidate in enumerate(candidates):
if candidate.source_code is None: j = i + 1
continue if candidate.source_code is None:
# remove left overs from previous run continue
pathlib.Path(get_run_tmp_file(f"test_return_values_{j}.bin")).unlink( # remove left overs from previous run
missing_ok=True, pathlib.Path(get_run_tmp_file(f"test_return_values_{j}.bin")).unlink(
) missing_ok=True,
pathlib.Path(get_run_tmp_file(f"test_return_values_{j}.sqlite")).unlink(
missing_ok=True,
)
logging.info(f"Optimized candidate {j}/{len(candidates)}:")
logging.info(candidate.source_code)
try:
replace_function_definitions_in_module(
function_names=[function_to_optimize.qualified_name],
optimized_code=candidate.source_code,
file_path_of_module_with_function_to_optimize=function_to_optimize.file_path,
module_abspath=function_to_optimize.file_path,
preexisting_functions=code_context.preexisting_functions,
contextual_functions=code_context.contextual_dunder_methods,
project_root_path=self.args.project_root,
) )
for ( pathlib.Path(get_run_tmp_file(f"test_return_values_{j}.sqlite")).unlink(
module_abspath, missing_ok=True,
qualified_names, )
) in helper_functions_by_module_abspath.items(): logging.info(f"Optimized candidate {j}/{len(candidates)}:")
logging.info(candidate.source_code)
try:
replace_function_definitions_in_module( replace_function_definitions_in_module(
function_names=list(qualified_names), function_names=[function_to_optimize.qualified_name],
optimized_code=candidate.source_code, optimized_code=candidate.source_code,
file_path_of_module_with_function_to_optimize=function_to_optimize.file_path, file_path_of_module_with_function_to_optimize=function_to_optimize.file_path,
module_abspath=module_abspath, module_abspath=function_to_optimize.file_path,
preexisting_functions=[], preexisting_functions=code_context.preexisting_functions,
contextual_functions=code_context.contextual_dunder_methods, contextual_functions=code_context.contextual_dunder_methods,
project_root_path=self.args.project_root, project_root_path=self.args.project_root,
) )
except ( for (
ValueError, module_abspath,
SyntaxError, qualified_names,
cst.ParserSyntaxError, ) in helper_functions_by_module_abspath.items():
AttributeError, replace_function_definitions_in_module(
) as e: function_names=list(qualified_names),
logging.error(e) # noqa: TRY400 optimized_code=candidate.source_code,
file_path_of_module_with_function_to_optimize=function_to_optimize.file_path,
module_abspath=module_abspath,
preexisting_functions=[],
contextual_functions=code_context.contextual_dunder_methods,
project_root_path=self.args.project_root,
)
except (
ValueError,
SyntaxError,
cst.ParserSyntaxError,
AttributeError,
) as e:
logging.error(e) # noqa: TRY400
self.write_code_and_helpers(
original_code,
original_helper_code,
function_to_optimize.file_path,
helper_functions_by_module_abspath,
)
continue
# Run generated tests if at least one of them passed
run_generated_tests = False
if original_code_baseline.generated_test_results:
for test_result in original_code_baseline.generated_test_results.test_results:
if test_result.did_pass:
run_generated_tests = True
break
run_results = self.run_optimized_candidate(
optimization_index=j,
instrumented_unittests_created_for_function=instrumented_unittests_created_for_function,
overall_original_test_results=original_code_baseline.overall_test_results,
original_existing_test_results=original_code_baseline.existing_test_results,
original_generated_test_results=original_code_baseline.generated_test_results,
generated_tests_path=generated_tests_path,
best_runtime_until_now=best_runtime_until_now,
tests_in_file=only_run_this_test_function,
run_generated_tests=run_generated_tests,
)
if not is_successful(run_results):
optimized_runtimes[candidate.optimization_id] = None
is_correct[candidate.optimization_id] = False
speedup_ratios[candidate.optimization_id] = None
else:
candidate_result: OptimizedCandidateResult = run_results.unwrap()
best_test_runtime = candidate_result.best_test_runtime
optimized_runtimes[candidate.optimization_id] = best_test_runtime
is_correct[candidate.optimization_id] = True
speedup_ratios[candidate.optimization_id] = (
original_code_baseline.runtime - best_test_runtime
) / best_test_runtime
logging.info(
f"Candidate runtime measured over {candidate_result.times_run} run{'s' if candidate_result.times_run > 1 else ''}: "
f"{humanize_runtime(best_test_runtime)}, speedup ratio = "
f"{((original_code_baseline.runtime - best_test_runtime) / best_test_runtime):.3f}",
)
if speedup_critic(candidate_result, original_code_baseline.runtime, best_runtime_until_now):
best_optimization = BestOptimization(
candidate=candidate,
helper_functions=code_context.helper_functions,
runtime=best_test_runtime,
winning_test_results=candidate_result.best_test_results,
)
best_runtime_until_now = best_test_runtime
self.write_code_and_helpers( self.write_code_and_helpers(
original_code, original_code,
original_helper_code, original_helper_code,
function_to_optimize.file_path, function_to_optimize.file_path,
helper_functions_by_module_abspath, helper_functions_by_module_abspath,
) )
continue logging.info("----------------")
except KeyboardInterrupt as e:
# Run generated tests if at least one of them passed
run_generated_tests = False
if original_code_baseline.generated_test_results:
for test_result in original_code_baseline.generated_test_results.test_results:
if test_result.did_pass:
run_generated_tests = True
break
run_results = self.run_optimized_candidate(
optimization_index=j,
instrumented_unittests_created_for_function=instrumented_unittests_created_for_function,
overall_original_test_results=original_code_baseline.overall_test_results,
original_existing_test_results=original_code_baseline.existing_test_results,
original_generated_test_results=original_code_baseline.generated_test_results,
generated_tests_path=generated_tests_path,
best_runtime_until_now=best_runtime_until_now,
tests_in_file=only_run_this_test_function,
run_generated_tests=run_generated_tests,
)
if not is_successful(run_results):
optimized_runtimes[candidate.optimization_id] = None
is_correct[candidate.optimization_id] = False
speedup_ratios[candidate.optimization_id] = None
else:
candidate_result: OptimizedCandidateResult = run_results.unwrap()
best_test_runtime = candidate_result.best_test_runtime
optimized_runtimes[candidate.optimization_id] = best_test_runtime
is_correct[candidate.optimization_id] = True
speedup_ratios[candidate.optimization_id] = (
original_code_baseline.runtime - best_test_runtime
) / best_test_runtime
logging.info(
f"Candidate runtime measured over {candidate_result.times_run} run{'s' if candidate_result.times_run > 1 else ''}: "
f"{humanize_runtime(best_test_runtime)}, speedup ratio = "
f"{((original_code_baseline.runtime - best_test_runtime) / best_test_runtime):.3f}",
)
if speedup_critic(candidate_result, original_code_baseline.runtime, best_runtime_until_now):
best_optimization = BestOptimization(
candidate=candidate,
helper_functions=code_context.helper_functions,
runtime=best_test_runtime,
winning_test_results=candidate_result.best_test_results,
)
best_runtime_until_now = best_test_runtime
self.write_code_and_helpers( self.write_code_and_helpers(
original_code, original_code,
original_helper_code, original_helper_code,
function_to_optimize.file_path, function_to_optimize.file_path,
helper_functions_by_module_abspath, helper_functions_by_module_abspath,
) )
logging.info("----------------") logging.error(f"Optimization interrupted: {e}")
raise e
self.aiservice_client.log_results( self.aiservice_client.log_results(
function_trace_id=function_trace_id, function_trace_id=function_trace_id,
speedup_ratio=speedup_ratios, speedup_ratio=speedup_ratios,

View file

@ -4,10 +4,8 @@ import os
from argparse import Namespace from argparse import Namespace
from pathlib import Path from pathlib import Path
from returns.pipeline import is_successful from codeflash.code_utils.code_replacer import replace_functions_and_add_imports
from codeflash.discovery.functions_to_optimize import FunctionParent, FunctionToOptimize
from codeflash.code_utils.code_replacer import replace_functions_in_file
from codeflash.discovery.functions_to_optimize import FunctionToOptimize, FunctionParent
from codeflash.optimization.optimizer import Optimizer from codeflash.optimization.optimizer import Optimizer
os.environ["CODEFLASH_API_KEY"] = "cf-test-key" os.environ["CODEFLASH_API_KEY"] = "cf-test-key"
@ -37,9 +35,7 @@ class NewClass:
print("Hello world") print("Hello world")
""" """
expected = """import libcst as cst expected = """class NewClass:
from typing import Optional
class NewClass:
def __init__(self, name): def __init__(self, name):
self.name = name self.name = name
def new_function(self, value): def new_function(self, value):
@ -56,12 +52,15 @@ print("Hello world")
function_name: str = "NewClass.new_function" function_name: str = "NewClass.new_function"
preexisting_functions: list[str] = ["new_function"] preexisting_functions: list[str] = ["new_function"]
contextual_functions: set[tuple[str, str]] = {("NewClass", "__init__")} contextual_functions: set[tuple[str, str]] = {("NewClass", "__init__")}
new_code: str = replace_functions_in_file( new_code: str = replace_functions_and_add_imports(
original_code, source_code=original_code,
[function_name], function_names=[function_name],
optim_code, optimized_code=optim_code,
preexisting_functions, file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()),
contextual_functions, module_abspath=str(Path(__file__).resolve()),
preexisting_functions=preexisting_functions,
contextual_functions=contextual_functions,
project_root_path=str(Path(__file__).resolve().parent.resolve()),
) )
assert new_code == expected assert new_code == expected
@ -95,9 +94,7 @@ class NewClass:
print("Hello world") print("Hello world")
""" """
expected = """import libcst as cst expected = """from OtherModule import other_function
from typing import Optional
from OtherModule import other_function
class NewClass: class NewClass:
def __init__(self, name): def __init__(self, name):
@ -116,12 +113,15 @@ print("Hello world")
function_name: str = "NewClass.new_function" function_name: str = "NewClass.new_function"
preexisting_functions: list[str] = ["new_function", "other_function"] preexisting_functions: list[str] = ["new_function", "other_function"]
contextual_functions: set[tuple[str, str]] = {("NewClass", "__init__")} contextual_functions: set[tuple[str, str]] = {("NewClass", "__init__")}
new_code: str = replace_functions_in_file( new_code: str = replace_functions_and_add_imports(
original_code, source_code=original_code,
[function_name], function_names=[function_name],
optim_code, optimized_code=optim_code,
preexisting_functions, file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()),
contextual_functions, module_abspath=str(Path(__file__).resolve()),
preexisting_functions=preexisting_functions,
contextual_functions=contextual_functions,
project_root_path=str(Path(__file__).resolve().parent.resolve()),
) )
assert new_code == expected assert new_code == expected
@ -139,7 +139,7 @@ def other_function(st):
class NewClass: class NewClass:
def __init__(self, name): def __init__(self, name):
self.name = name self.name = name
def new_function(self, value): def new_function(self, value: cst.Name):
return other_function(self.name) return other_function(self.name)
def new_function2(value): def new_function2(value):
return value return value
@ -158,10 +158,7 @@ def other_function(st):
print("Salut monde") print("Salut monde")
""" """
expected = """import libcst as cst expected = """from typing import Mandatory
from typing import Optional
import libcst as cst
from typing import Mandatory
print("Au revoir") print("Au revoir")
@ -177,12 +174,15 @@ print("Salut monde")
function_names: list[str] = ["module.other_function"] function_names: list[str] = ["module.other_function"]
preexisting_functions: list[str] = [] preexisting_functions: list[str] = []
contextual_functions: set[tuple[str, str]] = set() contextual_functions: set[tuple[str, str]] = set()
new_code: str = replace_functions_in_file( new_code: str = replace_functions_and_add_imports(
original_code, source_code=original_code,
function_names, function_names=function_names,
optim_code, optimized_code=optim_code,
preexisting_functions, file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()),
contextual_functions, module_abspath=str(Path(__file__).resolve()),
preexisting_functions=preexisting_functions,
contextual_functions=contextual_functions,
project_root_path=str(Path(__file__).resolve().parent.resolve()),
) )
assert new_code == expected assert new_code == expected
@ -194,7 +194,7 @@ from typing import Optional
def totally_new_function(value): def totally_new_function(value):
return value return value
def yet_another_function(values): def yet_another_function(values: Optional[str]):
return len(values) + 2 return len(values) + 2
def other_function(st): def other_function(st):
@ -222,14 +222,11 @@ def other_function(st):
print("Salut monde") print("Salut monde")
""" """
expected = """import libcst as cst expected = """from typing import Optional, Mandatory
from typing import Optional
import libcst as cst
from typing import Mandatory
print("Au revoir") print("Au revoir")
def yet_another_function(values): def yet_another_function(values: Optional[str]):
return len(values) + 2 return len(values) + 2
def other_function(st): def other_function(st):
@ -241,12 +238,15 @@ print("Salut monde")
function_names: list[str] = ["module.yet_another_function", "module.other_function"] function_names: list[str] = ["module.yet_another_function", "module.other_function"]
preexisting_functions: list[str] = [] preexisting_functions: list[str] = []
contextual_functions: set[tuple[str, str]] = set() contextual_functions: set[tuple[str, str]] = set()
new_code: str = replace_functions_in_file( new_code: str = replace_functions_and_add_imports(
original_code, source_code=original_code,
function_names, function_names=function_names,
optim_code, optimized_code=optim_code,
preexisting_functions, file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()),
contextual_functions, module_abspath=str(Path(__file__).resolve()),
preexisting_functions=preexisting_functions,
contextual_functions=contextual_functions,
project_root_path=str(Path(__file__).resolve().parent.resolve()),
) )
assert new_code == expected assert new_code == expected
@ -291,12 +291,15 @@ def supersort(doink):
function_names: list[str] = ["sorter_deps"] function_names: list[str] = ["sorter_deps"]
preexisting_functions: list[str] = ["sorter_deps"] preexisting_functions: list[str] = ["sorter_deps"]
contextual_functions: set[tuple[str, str]] = set() contextual_functions: set[tuple[str, str]] = set()
new_code: str = replace_functions_in_file( new_code: str = replace_functions_and_add_imports(
original_code, source_code=original_code,
function_names, function_names=function_names,
optim_code, optimized_code=optim_code,
preexisting_functions, file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()),
contextual_functions, module_abspath=str(Path(__file__).resolve()),
preexisting_functions=preexisting_functions,
contextual_functions=contextual_functions,
project_root_path=str(Path(__file__).resolve().parent.resolve()),
) )
assert new_code == expected assert new_code == expected
@ -338,10 +341,7 @@ def blab(st):
print("Not cool") print("Not cool")
""" """
expected_main = """import libcst as cst expected_main = """from typing import Mandatory
from typing import Optional
import libcst as cst
from typing import Mandatory
from helper import blob from helper import blob
print("Au revoir") print("Au revoir")
@ -355,9 +355,7 @@ def other_function(st):
print("Salut monde") print("Salut monde")
""" """
expected_helper = """import libcst as cst expected_helper = """import numpy as np
from typing import Optional
import numpy as np
print("Cool") print("Cool")
@ -369,21 +367,27 @@ def blab(st):
print("Not cool") print("Not cool")
""" """
new_main_code: str = replace_functions_in_file( new_main_code: str = replace_functions_and_add_imports(
original_code_main, source_code=original_code_main,
["other_function"], function_names=["other_function"],
optim_code, optimized_code=optim_code,
["other_function", "yet_another_function", "blob"], file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()),
set(), module_abspath=str(Path(__file__).resolve()),
preexisting_functions=["other_function", "yet_another_function", "blob"],
contextual_functions=set(),
project_root_path=str(Path(__file__).resolve().parent.resolve()),
) )
assert new_main_code == expected_main assert new_main_code == expected_main
new_helper_code: str = replace_functions_in_file( new_helper_code: str = replace_functions_and_add_imports(
original_code_helper, source_code=original_code_helper,
["blob"], function_names=["blob"],
optim_code, optimized_code=optim_code,
[], file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()),
set(), module_abspath=str(Path(__file__).resolve()),
preexisting_functions=[],
contextual_functions=set(),
project_root_path=str(Path(__file__).resolve().parent.resolve()),
) )
assert new_helper_code == expected_helper assert new_helper_code == expected_helper
@ -578,12 +582,15 @@ class CacheConfig(BaseConfig):
("CacheConfig", "__init__"), ("CacheConfig", "__init__"),
("CacheInitConfig", "__init__"), ("CacheInitConfig", "__init__"),
} }
new_code: str = replace_functions_in_file( new_code: str = replace_functions_and_add_imports(
original_code, source_code=original_code,
function_names, function_names=function_names,
optim_code, optimized_code=optim_code,
preexisting_functions, file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()),
contextual_functions, module_abspath=str(Path(__file__).resolve()),
preexisting_functions=preexisting_functions,
contextual_functions=contextual_functions,
project_root_path=str(Path(__file__).resolve().parent.resolve()),
) )
assert new_code == expected assert new_code == expected
@ -649,12 +656,15 @@ def test_test_libcst_code_replacement8() -> None:
"_hamming_distance", "_hamming_distance",
] ]
contextual_functions: set[tuple[str, str]] = set() contextual_functions: set[tuple[str, str]] = set()
new_code: str = replace_functions_in_file( new_code: str = replace_functions_and_add_imports(
original_code, source_code=original_code,
function_names, function_names=function_names,
optim_code, optimized_code=optim_code,
preexisting_functions, file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()),
contextual_functions, module_abspath=str(Path(__file__).resolve()),
preexisting_functions=preexisting_functions,
contextual_functions=contextual_functions,
project_root_path=str(Path(__file__).resolve().parent.resolve()),
) )
assert new_code == expected assert new_code == expected
@ -663,7 +673,7 @@ def test_test_libcst_code_replacement9() -> None:
optim_code = """import libcst as cst optim_code = """import libcst as cst
from typing import Optional from typing import Optional
def totally_new_function(value): def totally_new_function(value: Optional[str]):
return value return value
class NewClass: class NewClass:
@ -672,7 +682,7 @@ class NewClass:
def __call__(self, value): def __call__(self, value):
return self.name return self.name
def new_function2(value): def new_function2(value):
return value return cst.ensure_type(value, str)
""" """
original_code = """class NewClass: original_code = """class NewClass:
@ -685,15 +695,16 @@ print("Hello world")
""" """
expected = """import libcst as cst expected = """import libcst as cst
from typing import Optional from typing import Optional
class NewClass: class NewClass:
def __init__(self, name): def __init__(self, name):
self.name = str(name) self.name = str(name)
def __call__(self, value): def __call__(self, value):
return "I am still old" return "I am still old"
def new_function2(value): def new_function2(value):
return value return cst.ensure_type(value, str)
def totally_new_function(value): def totally_new_function(value: Optional[str]):
return value return value
print("Hello world") print("Hello world")
@ -705,12 +716,15 @@ print("Hello world")
("NewClass", "__init__"), ("NewClass", "__init__"),
("NewClass", "__call__"), ("NewClass", "__call__"),
} }
new_code: str = replace_functions_in_file( new_code: str = replace_functions_and_add_imports(
original_code, source_code=original_code,
[function_name], function_names=[function_name],
optim_code, optimized_code=optim_code,
preexisting_functions, file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()),
contextual_functions, module_abspath=str(Path(__file__).resolve()),
preexisting_functions=preexisting_functions,
contextual_functions=contextual_functions,
project_root_path=str(Path(__file__).resolve().parent.resolve()),
) )
assert new_code == expected assert new_code == expected
@ -764,11 +778,16 @@ class MainClass:
experiment_id=None, experiment_id=None,
), ),
) )
func_top_optimize = FunctionToOptimize(function_name="main_method", file_path=str(file_path), func_top_optimize = FunctionToOptimize(
parents=[FunctionParent("MainClass", "ClassDef")]) function_name="main_method",
file_path=str(file_path),
parents=[FunctionParent("MainClass", "ClassDef")],
)
with open(file_path) as f: with open(file_path) as f:
original_code = f.read() original_code = f.read()
code_context = opt.get_code_optimization_context(function_to_optimize=func_top_optimize, code_context = opt.get_code_optimization_context(
project_root=str(file_path.parent), function_to_optimize=func_top_optimize,
original_source_code=original_code).unwrap() project_root=str(file_path.parent),
original_source_code=original_code,
).unwrap()
assert code_context.code_to_optimize_with_helpers == get_code_output assert code_context.code_to_optimize_with_helpers == get_code_output