mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
Integration and testing of multi-file diffs for every optimization case. Multi-file PR's TBD.
This commit is contained in:
parent
32c09ff6f6
commit
c3f77597e6
3 changed files with 260 additions and 47 deletions
|
|
@ -30,9 +30,13 @@ class OptimFunctionCollector(cst.CSTVisitor):
|
|||
pass
|
||||
if node.name.value == self.function_name:
|
||||
self.optim_body = node
|
||||
elif node.name.value not in self.preexisting_functions and (
|
||||
isinstance(parent, cst.Module)
|
||||
or (parent2 is not None and not isinstance(parent2, cst.ClassDef))
|
||||
elif (
|
||||
self.preexisting_functions
|
||||
and node.name.value not in self.preexisting_functions
|
||||
and (
|
||||
isinstance(parent, cst.Module)
|
||||
or (parent2 is not None and not isinstance(parent2, cst.ClassDef))
|
||||
)
|
||||
):
|
||||
self.optim_new_functions.append(node)
|
||||
|
||||
|
|
@ -146,48 +150,51 @@ class OptimFunctionReplacer(cst.CSTTransformer):
|
|||
# print(import_node)
|
||||
# for name in import_node.names:
|
||||
# asname = name.asname.name.value if name.asname else None
|
||||
# AddImportsVisitor.add_needed_import(self.context, module =import_node.module.value, obj=name.name.value, asname=asname)
|
||||
# AddImportsVisitor.add_needed_import(
|
||||
# self.context, module =import_node.module.value, obj=name.name.value, asname=asname)
|
||||
# #print(updated_node)
|
||||
|
||||
|
||||
def replace_function_in_file(
|
||||
def replace_functions_in_file(
|
||||
source_code: str,
|
||||
original_function_name: str,
|
||||
original_function_names: list[str],
|
||||
optimized_function: str,
|
||||
preexisting_functions: list[str],
|
||||
) -> str:
|
||||
class_name = None
|
||||
if original_function_name.count(".") == 0:
|
||||
function_name = original_function_name
|
||||
elif original_function_name.count(".") == 1:
|
||||
class_name, function_name = original_function_name.split(".")
|
||||
else:
|
||||
raise ValueError(f"Don't know how to find {original_function_name} yet!")
|
||||
visitor = OptimFunctionCollector(function_name, preexisting_functions)
|
||||
module = cst.metadata.MetadataWrapper(cst.parse_module(optimized_function))
|
||||
visited = module.visit(visitor)
|
||||
for i, original_function_name in enumerate(original_function_names):
|
||||
if original_function_name.count(".") == 0:
|
||||
function_name = original_function_name
|
||||
elif original_function_name.count(".") == 1:
|
||||
class_name, function_name = original_function_name.split(".")
|
||||
else:
|
||||
raise ValueError(f"Don't know how to find {original_function_name} yet!")
|
||||
visitor = OptimFunctionCollector(function_name, preexisting_functions)
|
||||
module = cst.metadata.MetadataWrapper(cst.parse_module(optimized_function))
|
||||
visited = module.visit(visitor)
|
||||
|
||||
if visitor.optim_body is None:
|
||||
raise ValueError(
|
||||
f"Did not find the function {function_name} in the optimized code"
|
||||
if visitor.optim_body is None:
|
||||
raise ValueError(
|
||||
f"Did not find the function {function_name} in the optimized code"
|
||||
)
|
||||
optim_imports = [] if i > 0 else visitor.optim_imports
|
||||
|
||||
transformer = OptimFunctionReplacer(
|
||||
visitor.function_name,
|
||||
visitor.optim_body,
|
||||
visitor.optim_new_class_functions,
|
||||
optim_imports,
|
||||
visitor.optim_new_functions,
|
||||
class_name=class_name,
|
||||
)
|
||||
# TODO: If a dependency has been optimized and that dependency is imported in the file,
|
||||
# then we need to update the original file where the dependency is imported from
|
||||
transformer = OptimFunctionReplacer(
|
||||
visitor.function_name,
|
||||
visitor.optim_body,
|
||||
visitor.optim_new_class_functions,
|
||||
visitor.optim_imports,
|
||||
visitor.optim_new_functions,
|
||||
class_name=class_name,
|
||||
)
|
||||
original_module = cst.parse_module(source_code)
|
||||
modified_tree = original_module.visit(transformer)
|
||||
return modified_tree.code
|
||||
original_module = cst.parse_module(source_code)
|
||||
modified_tree = original_module.visit(transformer)
|
||||
source_code = modified_tree.code
|
||||
return source_code
|
||||
|
||||
|
||||
def replace_function_definition_in_module(
|
||||
function_name: str,
|
||||
def replace_function_definitions_in_module(
|
||||
function_names: list[str],
|
||||
optimized_code: str,
|
||||
module_abspath: str,
|
||||
preexisting_functions: list[str],
|
||||
|
|
@ -195,9 +202,9 @@ def replace_function_definition_in_module(
|
|||
file: IO[str]
|
||||
with open(module_abspath, "r") as file:
|
||||
source_code: str = file.read()
|
||||
new_code: str = replace_function_in_file(
|
||||
new_code: str = replace_functions_in_file(
|
||||
source_code,
|
||||
function_name,
|
||||
function_names,
|
||||
optimized_code,
|
||||
preexisting_functions,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ import logging
|
|||
import os
|
||||
import pathlib
|
||||
from argparse import ArgumentParser, SUPPRESS, Namespace
|
||||
from collections import defaultdict
|
||||
from typing import Tuple, Union
|
||||
|
||||
import libcst as cst
|
||||
|
|
@ -13,7 +14,7 @@ from codeflash.cli_cmds.cli import process_cmd_args
|
|||
from codeflash.cli_cmds.cmd_init import CODEFLASH_LOGO
|
||||
from codeflash.code_utils import env_utils
|
||||
from codeflash.code_utils.code_extractor import get_code
|
||||
from codeflash.code_utils.code_replacer import replace_function_definition_in_module
|
||||
from codeflash.code_utils.code_replacer import replace_function_definitions_in_module
|
||||
from codeflash.code_utils.code_utils import (
|
||||
module_name_from_file_path,
|
||||
get_all_function_names,
|
||||
|
|
@ -188,6 +189,19 @@ class Optimizer:
|
|||
) = get_constrained_function_context_and_dependent_functions(
|
||||
function_to_optimize, self.args.project_root, code_to_optimize
|
||||
)
|
||||
preexisting_functions.extend(
|
||||
[fn[0].full_name for fn in dependent_functions]
|
||||
)
|
||||
dependent_functions_by_module_abspath = defaultdict(set)
|
||||
for dependent_fn, module_abspath in dependent_functions:
|
||||
dependent_functions_by_module_abspath[module_abspath].add(
|
||||
dependent_fn.full_name
|
||||
)
|
||||
original_dependent_code = {}
|
||||
for module_abspath in dependent_functions_by_module_abspath.keys():
|
||||
with open(module_abspath, "r") as f:
|
||||
dependent_code = f.read()
|
||||
original_dependent_code[module_abspath] = dependent_code
|
||||
logging.info(
|
||||
f"Code to be optimized:\n{code_to_optimize_with_dependents}"
|
||||
)
|
||||
|
|
@ -260,12 +274,22 @@ class Optimizer:
|
|||
logging.info("Optimized Candidate:")
|
||||
logging.info(optimized_code)
|
||||
try:
|
||||
replace_function_definition_in_module(
|
||||
function_name,
|
||||
replace_function_definitions_in_module(
|
||||
[function_name],
|
||||
optimized_code,
|
||||
path,
|
||||
preexisting_functions,
|
||||
)
|
||||
for (
|
||||
module_abspath,
|
||||
dependent_functions,
|
||||
) in dependent_functions_by_module_abspath.items():
|
||||
replace_function_definitions_in_module(
|
||||
list(dependent_functions),
|
||||
optimized_code,
|
||||
module_abspath,
|
||||
[],
|
||||
)
|
||||
except (
|
||||
ValueError,
|
||||
SyntaxError,
|
||||
|
|
@ -274,6 +298,7 @@ class Optimizer:
|
|||
) as e:
|
||||
logging.error(e)
|
||||
continue
|
||||
|
||||
(
|
||||
success,
|
||||
times_run,
|
||||
|
|
@ -322,12 +347,23 @@ class Optimizer:
|
|||
found_atleast_one_optimization = True
|
||||
logging.info(f"BEST OPTIMIZED CODE\n{best_optimization[0]}")
|
||||
|
||||
replace_function_definition_in_module(
|
||||
function_name,
|
||||
best_optimization[0],
|
||||
optimized_code = best_optimization[0]
|
||||
replace_function_definitions_in_module(
|
||||
[function_name],
|
||||
optimized_code,
|
||||
path,
|
||||
preexisting_functions,
|
||||
)
|
||||
for (
|
||||
module_abspath,
|
||||
dependent_functions,
|
||||
) in dependent_functions_by_module_abspath.items():
|
||||
replace_function_definitions_in_module(
|
||||
list(dependent_functions),
|
||||
optimized_code,
|
||||
module_abspath,
|
||||
[],
|
||||
)
|
||||
explanation_final = Explanation(
|
||||
raw_explanation_message=best_optimization[1],
|
||||
winning_test_results=winning_test_results,
|
||||
|
|
@ -341,6 +377,10 @@ class Optimizer:
|
|||
)
|
||||
|
||||
new_code = lint_code(path)
|
||||
new_dependent_code = [
|
||||
lint_code(module_abspath)
|
||||
for module_abspath in dependent_functions_by_module_abspath.keys()
|
||||
]
|
||||
|
||||
logging.info(
|
||||
f"Optimization was validated for correctness by running the following test - "
|
||||
|
|
@ -352,9 +392,6 @@ class Optimizer:
|
|||
)
|
||||
logging.info(f"📈 {explanation_final.perf_improvement_line}")
|
||||
|
||||
# TODO: Create multi-file PR for dependent functions, extract dependent functions from new code
|
||||
# and overwrite original definitions, with replace_function_in_file. Also need to lint edited
|
||||
# files.
|
||||
check_create_pr(
|
||||
optimize_all=self.args.all,
|
||||
path=path,
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import os
|
||||
|
||||
from codeflash.code_utils.code_replacer import replace_function_in_file
|
||||
from codeflash.code_utils.code_replacer import replace_functions_in_file
|
||||
|
||||
os.environ["CODEFLASH_API_KEY"] = "cf-test-key"
|
||||
|
||||
|
|
@ -47,7 +47,176 @@ print("Hello world")
|
|||
|
||||
function_name: str = "NewClass.new_function"
|
||||
preexisting_functions: list[str] = ["NewClass.new_function"]
|
||||
new_code: str = replace_function_in_file(
|
||||
original_code, function_name, optim_code, preexisting_functions
|
||||
new_code: str = replace_functions_in_file(
|
||||
original_code, [function_name], optim_code, preexisting_functions
|
||||
)
|
||||
assert new_code == expected
|
||||
|
||||
|
||||
def test_test_libcst_code_replacement2():
|
||||
optim_code = """import libcst as cst
|
||||
from typing import Optional
|
||||
|
||||
def totally_new_function(value):
|
||||
return value
|
||||
|
||||
def other_function(st):
|
||||
return(st * 2)
|
||||
|
||||
class NewClass:
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
def new_function(self, value):
|
||||
return other_function(self.name)
|
||||
def new_function2(value):
|
||||
return value
|
||||
"""
|
||||
|
||||
original_code = """from OtherModule import other_function
|
||||
|
||||
class NewClass:
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
def new_function(self, value):
|
||||
return other_function("I am still old")
|
||||
|
||||
print("Hello world")
|
||||
"""
|
||||
expected = """import libcst as cst
|
||||
from typing import Optional
|
||||
from OtherModule import other_function
|
||||
|
||||
class NewClass:
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
def new_function(self, value):
|
||||
return other_function(self.name)
|
||||
def new_function2(value):
|
||||
return value
|
||||
|
||||
def totally_new_function(value):
|
||||
return value
|
||||
|
||||
print("Hello world")
|
||||
"""
|
||||
|
||||
function_name: str = "NewClass.new_function"
|
||||
preexisting_functions: list[str] = ["NewClass.new_function", "other_function"]
|
||||
new_code: str = replace_functions_in_file(
|
||||
original_code, [function_name], optim_code, preexisting_functions
|
||||
)
|
||||
assert new_code == expected
|
||||
|
||||
|
||||
def test_test_libcst_code_replacement3():
|
||||
optim_code = """import libcst as cst
|
||||
from typing import Optional
|
||||
|
||||
def totally_new_function(value):
|
||||
return value
|
||||
|
||||
def other_function(st):
|
||||
return(st * 2)
|
||||
|
||||
class NewClass:
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
def new_function(self, value):
|
||||
return other_function(self.name)
|
||||
def new_function2(value):
|
||||
return value
|
||||
"""
|
||||
|
||||
original_code = """import libcst as cst
|
||||
from typing import Mandatory
|
||||
|
||||
print("Au revoir")
|
||||
|
||||
def yet_another_function(values):
|
||||
return len(values)
|
||||
|
||||
def other_function(st):
|
||||
return(st + st)
|
||||
|
||||
print("Salut monde")
|
||||
"""
|
||||
expected = """import libcst as cst
|
||||
from typing import Optional
|
||||
import libcst as cst
|
||||
from typing import Mandatory
|
||||
|
||||
print("Au revoir")
|
||||
|
||||
def yet_another_function(values):
|
||||
return len(values)
|
||||
|
||||
def other_function(st):
|
||||
return(st * 2)
|
||||
|
||||
print("Salut monde")
|
||||
"""
|
||||
|
||||
function_names: list[str] = ["other_function"]
|
||||
preexisting_functions: list[str] = []
|
||||
new_code: str = replace_functions_in_file(
|
||||
original_code, function_names, optim_code, preexisting_functions
|
||||
)
|
||||
assert new_code == expected
|
||||
|
||||
|
||||
def test_test_libcst_code_replacement4():
|
||||
optim_code = """import libcst as cst
|
||||
from typing import Optional
|
||||
|
||||
def totally_new_function(value):
|
||||
return value
|
||||
|
||||
def yet_another_function(values):
|
||||
return len(values) + 2
|
||||
|
||||
def other_function(st):
|
||||
return(st * 2)
|
||||
|
||||
class NewClass:
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
def new_function(self, value):
|
||||
return other_function(self.name)
|
||||
def new_function2(value):
|
||||
return value
|
||||
"""
|
||||
|
||||
original_code = """import libcst as cst
|
||||
from typing import Mandatory
|
||||
|
||||
print("Au revoir")
|
||||
|
||||
def yet_another_function(values):
|
||||
return len(values)
|
||||
|
||||
def other_function(st):
|
||||
return(st + st)
|
||||
|
||||
print("Salut monde")
|
||||
"""
|
||||
expected = """import libcst as cst
|
||||
from typing import Optional
|
||||
import libcst as cst
|
||||
from typing import Mandatory
|
||||
|
||||
print("Au revoir")
|
||||
def yet_another_function(values):
|
||||
return len(values) + 2
|
||||
|
||||
def other_function(st):
|
||||
return(st * 2)
|
||||
|
||||
print("Salut monde")
|
||||
"""
|
||||
|
||||
function_names: list[str] = ["yet_another_function", "other_function"]
|
||||
preexisting_functions: list[str] = []
|
||||
new_code: str = replace_functions_in_file(
|
||||
original_code, function_names, optim_code, preexisting_functions
|
||||
)
|
||||
assert new_code == expected
|
||||
|
|
|
|||
Loading…
Reference in a new issue