Integration and testing of multi-file diffs for every optimization case. Multi-file PRs are still TBD.

This commit is contained in:
renaud 2024-02-07 07:26:45 -08:00
parent 32c09ff6f6
commit c3f77597e6
3 changed files with 260 additions and 47 deletions

View file

@ -30,9 +30,13 @@ class OptimFunctionCollector(cst.CSTVisitor):
pass
if node.name.value == self.function_name:
self.optim_body = node
elif node.name.value not in self.preexisting_functions and (
isinstance(parent, cst.Module)
or (parent2 is not None and not isinstance(parent2, cst.ClassDef))
elif (
self.preexisting_functions
and node.name.value not in self.preexisting_functions
and (
isinstance(parent, cst.Module)
or (parent2 is not None and not isinstance(parent2, cst.ClassDef))
)
):
self.optim_new_functions.append(node)
@ -146,48 +150,51 @@ class OptimFunctionReplacer(cst.CSTTransformer):
# print(import_node)
# for name in import_node.names:
# asname = name.asname.name.value if name.asname else None
# AddImportsVisitor.add_needed_import(self.context, module =import_node.module.value, obj=name.name.value, asname=asname)
# AddImportsVisitor.add_needed_import(
# self.context, module =import_node.module.value, obj=name.name.value, asname=asname)
# #print(updated_node)
def replace_function_in_file(
def replace_functions_in_file(
source_code: str,
original_function_name: str,
original_function_names: list[str],
optimized_function: str,
preexisting_functions: list[str],
) -> str:
class_name = None
if original_function_name.count(".") == 0:
function_name = original_function_name
elif original_function_name.count(".") == 1:
class_name, function_name = original_function_name.split(".")
else:
raise ValueError(f"Don't know how to find {original_function_name} yet!")
visitor = OptimFunctionCollector(function_name, preexisting_functions)
module = cst.metadata.MetadataWrapper(cst.parse_module(optimized_function))
visited = module.visit(visitor)
for i, original_function_name in enumerate(original_function_names):
if original_function_name.count(".") == 0:
function_name = original_function_name
elif original_function_name.count(".") == 1:
class_name, function_name = original_function_name.split(".")
else:
raise ValueError(f"Don't know how to find {original_function_name} yet!")
visitor = OptimFunctionCollector(function_name, preexisting_functions)
module = cst.metadata.MetadataWrapper(cst.parse_module(optimized_function))
visited = module.visit(visitor)
if visitor.optim_body is None:
raise ValueError(
f"Did not find the function {function_name} in the optimized code"
if visitor.optim_body is None:
raise ValueError(
f"Did not find the function {function_name} in the optimized code"
)
optim_imports = [] if i > 0 else visitor.optim_imports
transformer = OptimFunctionReplacer(
visitor.function_name,
visitor.optim_body,
visitor.optim_new_class_functions,
optim_imports,
visitor.optim_new_functions,
class_name=class_name,
)
# TODO: If a dependency has been optimized and that dependency is imported in the file,
# then we need to update the original file where the dependency is imported from
transformer = OptimFunctionReplacer(
visitor.function_name,
visitor.optim_body,
visitor.optim_new_class_functions,
visitor.optim_imports,
visitor.optim_new_functions,
class_name=class_name,
)
original_module = cst.parse_module(source_code)
modified_tree = original_module.visit(transformer)
return modified_tree.code
original_module = cst.parse_module(source_code)
modified_tree = original_module.visit(transformer)
source_code = modified_tree.code
return source_code
def replace_function_definition_in_module(
function_name: str,
def replace_function_definitions_in_module(
function_names: list[str],
optimized_code: str,
module_abspath: str,
preexisting_functions: list[str],
@ -195,9 +202,9 @@ def replace_function_definition_in_module(
file: IO[str]
with open(module_abspath, "r") as file:
source_code: str = file.read()
new_code: str = replace_function_in_file(
new_code: str = replace_functions_in_file(
source_code,
function_name,
function_names,
optimized_code,
preexisting_functions,
)

View file

@ -3,6 +3,7 @@ import logging
import os
import pathlib
from argparse import ArgumentParser, SUPPRESS, Namespace
from collections import defaultdict
from typing import Tuple, Union
import libcst as cst
@ -13,7 +14,7 @@ from codeflash.cli_cmds.cli import process_cmd_args
from codeflash.cli_cmds.cmd_init import CODEFLASH_LOGO
from codeflash.code_utils import env_utils
from codeflash.code_utils.code_extractor import get_code
from codeflash.code_utils.code_replacer import replace_function_definition_in_module
from codeflash.code_utils.code_replacer import replace_function_definitions_in_module
from codeflash.code_utils.code_utils import (
module_name_from_file_path,
get_all_function_names,
@ -188,6 +189,19 @@ class Optimizer:
) = get_constrained_function_context_and_dependent_functions(
function_to_optimize, self.args.project_root, code_to_optimize
)
preexisting_functions.extend(
[fn[0].full_name for fn in dependent_functions]
)
dependent_functions_by_module_abspath = defaultdict(set)
for dependent_fn, module_abspath in dependent_functions:
dependent_functions_by_module_abspath[module_abspath].add(
dependent_fn.full_name
)
original_dependent_code = {}
for module_abspath in dependent_functions_by_module_abspath.keys():
with open(module_abspath, "r") as f:
dependent_code = f.read()
original_dependent_code[module_abspath] = dependent_code
logging.info(
f"Code to be optimized:\n{code_to_optimize_with_dependents}"
)
@ -260,12 +274,22 @@ class Optimizer:
logging.info("Optimized Candidate:")
logging.info(optimized_code)
try:
replace_function_definition_in_module(
function_name,
replace_function_definitions_in_module(
[function_name],
optimized_code,
path,
preexisting_functions,
)
for (
module_abspath,
dependent_functions,
) in dependent_functions_by_module_abspath.items():
replace_function_definitions_in_module(
list(dependent_functions),
optimized_code,
module_abspath,
[],
)
except (
ValueError,
SyntaxError,
@ -274,6 +298,7 @@ class Optimizer:
) as e:
logging.error(e)
continue
(
success,
times_run,
@ -322,12 +347,23 @@ class Optimizer:
found_atleast_one_optimization = True
logging.info(f"BEST OPTIMIZED CODE\n{best_optimization[0]}")
replace_function_definition_in_module(
function_name,
best_optimization[0],
optimized_code = best_optimization[0]
replace_function_definitions_in_module(
[function_name],
optimized_code,
path,
preexisting_functions,
)
for (
module_abspath,
dependent_functions,
) in dependent_functions_by_module_abspath.items():
replace_function_definitions_in_module(
list(dependent_functions),
optimized_code,
module_abspath,
[],
)
explanation_final = Explanation(
raw_explanation_message=best_optimization[1],
winning_test_results=winning_test_results,
@ -341,6 +377,10 @@ class Optimizer:
)
new_code = lint_code(path)
new_dependent_code = [
lint_code(module_abspath)
for module_abspath in dependent_functions_by_module_abspath.keys()
]
logging.info(
f"Optimization was validated for correctness by running the following test - "
@ -352,9 +392,6 @@ class Optimizer:
)
logging.info(f"📈 {explanation_final.perf_improvement_line}")
# TODO: Create multi-file PR for dependent functions, extract dependent functions from new code
# and overwrite original definitions, with replace_function_in_file. Also need to lint edited
# files.
check_create_pr(
optimize_all=self.args.all,
path=path,

View file

@ -1,6 +1,6 @@
import os
from codeflash.code_utils.code_replacer import replace_function_in_file
from codeflash.code_utils.code_replacer import replace_functions_in_file
os.environ["CODEFLASH_API_KEY"] = "cf-test-key"
@ -47,7 +47,176 @@ print("Hello world")
function_name: str = "NewClass.new_function"
preexisting_functions: list[str] = ["NewClass.new_function"]
new_code: str = replace_function_in_file(
original_code, function_name, optim_code, preexisting_functions
new_code: str = replace_functions_in_file(
original_code, [function_name], optim_code, preexisting_functions
)
assert new_code == expected
# Checks replace_functions_in_file for a class method ("NewClass.new_function") when
# the optimized code also defines a helper, other_function, that is listed in
# preexisting_functions. The expected output takes the optimized body of new_function,
# the optimized code's imports, and the definitions of new_function2 and
# totally_new_function, but contains no definition of other_function.
def test_test_libcst_code_replacement2():
optim_code = """import libcst as cst
from typing import Optional
def totally_new_function(value):
return value
def other_function(st):
return(st * 2)
class NewClass:
def __init__(self, name):
self.name = name
def new_function(self, value):
return other_function(self.name)
def new_function2(value):
return value
"""
original_code = """from OtherModule import other_function
class NewClass:
def __init__(self, name):
self.name = name
def new_function(self, value):
return other_function("I am still old")
print("Hello world")
"""
expected = """import libcst as cst
from typing import Optional
from OtherModule import other_function
class NewClass:
def __init__(self, name):
self.name = name
def new_function(self, value):
return other_function(self.name)
def new_function2(value):
return value
def totally_new_function(value):
return value
print("Hello world")
"""
function_name: str = "NewClass.new_function"
preexisting_functions: list[str] = ["NewClass.new_function", "other_function"]
new_code: str = replace_functions_in_file(
original_code, [function_name], optim_code, preexisting_functions
)
assert new_code == expected
# Checks replacing a single module-level function ("other_function") with an empty
# preexisting_functions list. The expected output takes the optimized body of
# other_function and the optimized code's imports, placed ahead of the original imports
# (so `import libcst as cst` appears twice), and none of the other definitions from the
# optimized code.
def test_test_libcst_code_replacement3():
optim_code = """import libcst as cst
from typing import Optional
def totally_new_function(value):
return value
def other_function(st):
return(st * 2)
class NewClass:
def __init__(self, name):
self.name = name
def new_function(self, value):
return other_function(self.name)
def new_function2(value):
return value
"""
original_code = """import libcst as cst
from typing import Mandatory
print("Au revoir")
def yet_another_function(values):
return len(values)
def other_function(st):
return(st + st)
print("Salut monde")
"""
expected = """import libcst as cst
from typing import Optional
import libcst as cst
from typing import Mandatory
print("Au revoir")
def yet_another_function(values):
return len(values)
def other_function(st):
return(st * 2)
print("Salut monde")
"""
function_names: list[str] = ["other_function"]
preexisting_functions: list[str] = []
new_code: str = replace_functions_in_file(
original_code, function_names, optim_code, preexisting_functions
)
assert new_code == expected
# Checks replacing two module-level functions ("yet_another_function" and
# "other_function") in a single call with an empty preexisting_functions list. Both
# bodies are updated in the expected output, and the optimized code's imports are added
# only once even though two functions are processed.
def test_test_libcst_code_replacement4():
optim_code = """import libcst as cst
from typing import Optional
def totally_new_function(value):
return value
def yet_another_function(values):
return len(values) + 2
def other_function(st):
return(st * 2)
class NewClass:
def __init__(self, name):
self.name = name
def new_function(self, value):
return other_function(self.name)
def new_function2(value):
return value
"""
original_code = """import libcst as cst
from typing import Mandatory
print("Au revoir")
def yet_another_function(values):
return len(values)
def other_function(st):
return(st + st)
print("Salut monde")
"""
expected = """import libcst as cst
from typing import Optional
import libcst as cst
from typing import Mandatory
print("Au revoir")
def yet_another_function(values):
return len(values) + 2
def other_function(st):
return(st * 2)
print("Salut monde")
"""
function_names: list[str] = ["yet_another_function", "other_function"]
preexisting_functions: list[str] = []
new_code: str = replace_functions_in_file(
original_code, function_names, optim_code, preexisting_functions
)
assert new_code == expected