WIP: blocked by the debugger from fixing the last failing test (bubble sort in class).

This commit is contained in:
renaud 2024-04-09 05:36:43 -07:00
parent 0baa7f0b8a
commit b092d1b84f
8 changed files with 194 additions and 94 deletions

View file

@ -3,26 +3,32 @@ from __future__ import annotations
import ast
import logging
from collections import deque
from typing import Optional
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
def get_code(functions_to_optimize: list[FunctionToOptimize]) -> Optional[str]:
"""Returns the code for a class or functions in a file."""
def get_code(
functions_to_optimize: list[FunctionToOptimize],
) -> tuple[str | None, set[tuple[str, str]]]:
"""Return the code for a class or functions in a file."""
file_path: str = functions_to_optimize[0].file_path
class_skeleton: set[tuple[int, int]] = set()
class_skeleton: set[tuple[int, int | None]] = set()
contextual_dunder_methods: set[tuple[str, str]] = set()
target_code: str = ""
def find_target(
node_list: list[ast.stmt],
name_parts: tuple[str, str] | tuple[str],
) -> Optional[ast.AST]:
target: Optional[
ast.FunctionDef | ast.AsyncFunctionDef | ast.ClassDef | ast.Assign | ast.AnnAssign
] = None
node: ast.FunctionDef | ast.AsyncFunctionDef | ast.ClassDef | ast.Assign | ast.AnnAssign
) -> ast.AST | None:
target: (
ast.FunctionDef
| ast.AsyncFunctionDef
| ast.ClassDef
| ast.Assign
| ast.AnnAssign
| None
) = None
node: ast.stmt
for node in node_list:
if (
# The many mypy issues will be fixed once this code moves to the backend,
@ -57,17 +63,19 @@ def get_code(functions_to_optimize: list[FunctionToOptimize]) -> Optional[str]:
if isinstance(cbody[0], ast.expr): # Is a docstring
class_skeleton.add((cbody[0].lineno, cbody[0].end_lineno))
cbody = cbody[1:]
cnode: ast.FunctionDef | ast.AsyncFunctionDef
cnode: ast.stmt
for cnode in cbody:
# Collect all dunder methods.
cnode_name: str
if (
isinstance(cnode, (ast.FunctionDef, ast.AsyncFunctionDef))
and len(cnode_name := cnode.name) > 4
and cnode_name != name_parts[1]
and cnode_name.isascii()
and cnode_name.startswith("__")
and cnode_name.endswith("__")
):
contextual_dunder_methods.add((target.name, cnode_name))
class_skeleton.add((cnode.lineno, cnode.end_lineno))
return find_target(target.body, name_parts[1:])
@ -80,7 +88,7 @@ def get_code(functions_to_optimize: list[FunctionToOptimize]) -> Optional[str]:
module_node = ast.parse(source_code)
except SyntaxError as e:
logging.exception(f"get_code - Syntax error in code: {e}")
return None
return None, set()
# Get the source code lines for the target node
lines = source_code.splitlines(keepends=True)
if len(functions_to_optimize[0].parents) == 1:
@ -95,7 +103,7 @@ def get_code(functions_to_optimize: list[FunctionToOptimize]) -> Optional[str]:
logging.error(
f"Error: get_code does not support nesting function in functions: {functions_to_optimize[0].parents}",
)
return None
return None, set()
elif len(functions_to_optimize[0].parents) == 0:
qualified_name_parts_list = [(functions_to_optimize[0].function_name,)]
else:
@ -103,9 +111,9 @@ def get_code(functions_to_optimize: list[FunctionToOptimize]) -> Optional[str]:
"Error: get_code does not support more than one level of nesting for now. "
f"Parents: {functions_to_optimize[0].parents}",
)
return None
return None, set()
for qualified_name_parts in qualified_name_parts_list:
target_node: Optional[ast.AST] = find_target(module_node.body, qualified_name_parts)
target_node: ast.AST | None = find_target(module_node.body, qualified_name_parts)
if target_node is None:
continue
@ -118,15 +126,15 @@ def get_code(functions_to_optimize: list[FunctionToOptimize]) -> Optional[str]:
)
else:
target_code += "".join(lines[target_node.lineno - 1 : target_node.end_lineno])
class_list: list[tuple[int, int]] = sorted(list(class_skeleton))
class_list: list[tuple[int, int | None]] = sorted(class_skeleton)
class_code = "".join(
["".join(lines[s_lineno - 1 : e_lineno]) for (s_lineno, e_lineno) in class_list],
)
return class_code + target_code
return class_code + target_code, contextual_dunder_methods
def get_code_no_skeleton(file_path: str, target_name: str) -> Optional[str]:
"""Returns the code for a function in a file. Irrespective of class skeleton."""
def get_code_no_skeleton(file_path: str, target_name: str) -> str | None:
"""Return the code for a function in a file, irrespective of class skeleton."""
with open(file_path, encoding="utf8") as file:
source_code = file.read()
@ -165,15 +173,17 @@ def get_code_no_skeleton(file_path: str, target_name: str) -> Optional[str]:
return target_code
def extract_code(
    functions_to_optimize: list[FunctionToOptimize],
) -> tuple[str | None, set[tuple[str, str]]]:
    """Fetch the code for the given functions and confirm that it compiles.

    Args:
        functions_to_optimize: Passed unchanged to ``get_code``.

    Returns:
        The ``(code, dunder_methods)`` pair produced by ``get_code`` when the
        code compiles. ``(None, set())`` is returned instead when ``get_code``
        yields no code, or when compiling the code raises ``SyntaxError``
        (the error is logged with ``logging.exception``).
    """
    candidate, dunder_methods = get_code(functions_to_optimize)
    if candidate is not None:
        try:
            # Only the SyntaxError check matters; the code object is discarded.
            compile(candidate, "edited_code", "exec")
        except SyntaxError as e:
            logging.exception(
                f"extract_code - Syntax error in extracted optimization candidate code: {e}",
            )
        else:
            return candidate, dunder_methods
    return None, set()

View file

@ -1,24 +1,34 @@
from typing import List, Union, Optional, IO
from __future__ import annotations
from typing import IO
import libcst as cst
from libcst import SimpleStatementLine, FunctionDef
from libcst import FunctionDef
class OptimFunctionCollector(cst.CSTVisitor):
METADATA_DEPENDENCIES = (cst.metadata.ParentNodeProvider,)
def __init__(
    self,
    function_name: str,
    class_name: str | None,
    immutable_methods: set[tuple[str, str]],
    preexisting_functions: list[str] | None = None,
) -> None:
    """Set up the collector's target names and its empty result containers.

    Args:
        function_name: Stored as ``self.function_name``.
        class_name: Stored as ``self.class_name``; may be ``None``.
        immutable_methods: ``(class name, method name)`` pairs. The stored
            set is this set plus the ``(class_name, function_name)`` pair;
            the argument itself is not mutated.
        preexisting_functions: Stored as ``self.preexisting_functions``;
            ``None`` is replaced by a new empty list.
    """
    super().__init__()

    # What the collector is looking for.
    self.function_name = function_name
    self.class_name = class_name
    self.preexisting_functions = [] if preexisting_functions is None else preexisting_functions

    # Results, all empty until the visitor has run.
    self.optim_body: FunctionDef | None = None
    self.optim_new_class_functions: list[cst.FunctionDef] = []
    self.optim_new_functions: list[cst.FunctionDef] = []
    self.optim_imports: list[cst.SimpleStatementLine] = []

    # union() builds a new set, so the caller's set is left untouched.
    target_key = (self.class_name, self.function_name)
    self.immutable_methods = immutable_methods.union({target_key})
def visit_FunctionDef(self, node: cst.FunctionDef):
def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
parent = self.get_metadata(cst.metadata.ParentNodeProvider, node)
parent2 = None
try:
@ -39,17 +49,19 @@ class OptimFunctionCollector(cst.CSTVisitor):
self.optim_new_functions.append(node)
def visit_ClassDef_body(self, node: cst.ClassDef) -> None:
    """Collect the methods of ``node`` that are not listed as immutable.

    Every ``cst.FunctionDef`` found directly in the class body whose
    ``(class name, method name)`` pair is absent from
    ``self.immutable_methods`` is appended, in body order, to
    ``self.optim_new_class_functions``.
    """
    # NOTE(review): the visited class name is not compared with
    # self.class_name here, so methods of any visited class can be
    # collected -- confirm this is intended.
    owner = node.name.value
    self.optim_new_class_functions.extend(
        member
        for member in node.body.body
        if isinstance(member, cst.FunctionDef)
        and (owner, member.name.value) not in self.immutable_methods
    )
def leave_SimpleStatementLine(self, original_node: cst.SimpleStatementLine) -> None:
    """Record statement lines whose first statement is an import.

    When ``original_node.body[0]`` is a ``cst.Import`` or ``cst.ImportFrom``,
    the whole line node is appended to ``self.optim_imports``; any other
    line is ignored.
    """
    first_statement = original_node.body[0]
    if not isinstance(first_statement, (cst.Import, cst.ImportFrom)):
        return
    self.optim_imports.append(original_node)
@ -58,11 +70,11 @@ class OptimFunctionReplacer(cst.CSTTransformer):
self,
function_name: str,
optim_body: cst.FunctionDef,
optim_new_class_functions: List[cst.FunctionDef],
optim_imports: List[Union[cst.Import, cst.ImportFrom]],
optim_new_functions,
class_name=None,
):
optim_new_class_functions: list[cst.FunctionDef],
optim_imports: list[cst.SimpleStatementLine],
optim_new_functions: list[cst.FunctionDef],
class_name: str | None = None,
) -> None:
super().__init__()
self.function_name = function_name
self.optim_body = optim_body
@ -77,7 +89,9 @@ class OptimFunctionReplacer(cst.CSTTransformer):
return False
def leave_FunctionDef(
self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef
self,
original_node: cst.FunctionDef,
updated_node: cst.FunctionDef,
) -> cst.FunctionDef:
if original_node.name.value == self.function_name and (
self.depth == 0 or (self.depth == 1 and self.in_class)
@ -93,7 +107,9 @@ class OptimFunctionReplacer(cst.CSTTransformer):
return self.in_class
def leave_ClassDef(
self, original_node: cst.ClassDef, updated_node: cst.ClassDef
self,
original_node: cst.ClassDef,
updated_node: cst.ClassDef,
) -> cst.ClassDef:
self.depth -= 1
if self.in_class and (self.depth == 0) and (original_node.name.value == self.class_name):
@ -101,7 +117,7 @@ class OptimFunctionReplacer(cst.CSTTransformer):
return updated_node.with_changes(
body=updated_node.body.with_changes(
body=(list(updated_node.body.body) + self.optim_new_class_functions),
)
),
)
return updated_node
@ -123,7 +139,7 @@ class OptimFunctionReplacer(cst.CSTTransformer):
*node.body[: max_function_index + 1],
*self.optim_new_functions,
*node.body[max_function_index + 1 :],
)
),
)
elif class_index is not None:
node = node.with_changes(
@ -131,7 +147,7 @@ class OptimFunctionReplacer(cst.CSTTransformer):
*node.body[: class_index + 1],
*self.optim_new_functions,
*node.body[class_index + 1 :],
)
),
)
else:
node = node.with_changes(body=(*self.optim_new_functions, *node.body))
@ -165,6 +181,7 @@ def replace_functions_in_file(
original_function_names: list[str],
optimized_code: str,
preexisting_functions: list[str],
immutable_methods: set[tuple[str, str]],
) -> str:
parsed_function_names = []
for original_function_name in original_function_names:
@ -179,14 +196,19 @@ def replace_functions_in_file(
module = cst.metadata.MetadataWrapper(cst.parse_module(optimized_code))
for i, (function_name, class_name) in enumerate(parsed_function_names):
visitor = OptimFunctionCollector(function_name, preexisting_functions)
visited = module.visit(visitor)
visitor = OptimFunctionCollector(
function_name,
class_name,
immutable_methods,
preexisting_functions,
)
module.visit(visitor)
if visitor.optim_body is None and not preexisting_functions:
continue
elif visitor.optim_body is None:
if visitor.optim_body is None:
raise ValueError(f"Did not find the function {function_name} in the optimized code")
optim_imports = [] if i > 0 else visitor.optim_imports
optim_imports: list[cst.SimpleStatementLine] = [] if i > 0 else visitor.optim_imports
transformer = OptimFunctionReplacer(
visitor.function_name,
@ -208,15 +230,17 @@ def replace_function_definitions_in_module(
optimized_code: str,
module_abspath: str,
preexisting_functions: list[str],
immutable_methods: set[tuple[str, str]],
) -> None:
file: IO[str]
with open(module_abspath, "r", encoding="utf8") as file:
with open(module_abspath, encoding="utf8") as file:
source_code: str = file.read()
new_code: str = replace_functions_in_file(
source_code,
function_names,
optimized_code,
preexisting_functions,
immutable_methods,
)
with open(module_abspath, "w", encoding="utf8") as file:
file.write(new_code)

View file

@ -2,6 +2,8 @@
solved problem, please reach out to us at careers@codeflash.ai. We're hiring!
"""
from __future__ import annotations
import concurrent.futures
import logging
import os
@ -9,7 +11,7 @@ import pathlib
import uuid
from argparse import SUPPRESS, ArgumentParser, Namespace
from collections import defaultdict
from typing import IO, Tuple, Union
from typing import IO, Dict, Tuple, Union
import libcst as cst
@ -212,7 +214,9 @@ class Optimizer:
pathlib.Path(get_run_tmp_file("test_return_values_0.sqlite")).unlink(
missing_ok=True,
)
code_to_optimize = extract_code([function_to_optimize])
code_to_optimize, contextual_dunder_methods = extract_code(
[function_to_optimize],
)
if code_to_optimize is None:
logging.error("Could not find function to optimize.")
continue
@ -247,18 +251,13 @@ class Optimizer:
for df in dependent_methods
]
if len(optimizable_methods) > 1:
code_to_optimize_with_dependents = (
dependent_code + "\n" + extract_code(optimizable_methods)
code_to_optimize, contextual_dunder_methods = extract_code(
optimizable_methods,
)
else:
code_to_optimize_with_dependents = (
dependent_code + "\n" + code_to_optimize
)
else:
code_to_optimize_with_dependents = dependent_code + "\n" + code_to_optimize
if code_to_optimize_with_dependents is None:
logging.error("Could not find function with dependents to optimize.")
continue
if code_to_optimize is None:
logging.error("Could not find function to optimize.")
continue
code_to_optimize_with_dependents = dependent_code + "\n" + code_to_optimize
preexisting_functions.extend(
[fn[0].full_name.split(".")[-1] for fn in dependent_functions],
)
@ -323,7 +322,7 @@ class Optimizer:
# TODO: Postprocess the optimized function to include the original docstring and such
best_optimization = []
speedup_ratios = dict()
speedup_ratios: Dict[str, float | None] = dict()
optimized_runtimes = dict()
is_correct = dict()
@ -346,6 +345,7 @@ class Optimizer:
optimization.source_code,
path,
preexisting_functions,
contextual_dunder_methods,
)
for (
module_abspath,
@ -356,6 +356,7 @@ class Optimizer:
optimization.source_code,
module_abspath,
[],
contextual_dunder_methods,
)
except (
ValueError,
@ -445,6 +446,7 @@ class Optimizer:
optimized_code,
path,
preexisting_functions,
contextual_dunder_methods,
)
for (
module_abspath,
@ -455,6 +457,7 @@ class Optimizer:
optimized_code,
module_abspath,
[],
contextual_dunder_methods,
)
explanation_final = Explanation(
raw_explanation_message=best_optimization[1],

View file

@ -107,7 +107,7 @@ def get_type_annotation_context(
node_parents[:-1],
),
],
)
)[0]
if source_code:
sources.append(
(

View file

@ -76,6 +76,7 @@ warn_required_dynamic_aliases = true
[tool.ruff.lint]
select = ["ALL"]
unfixable = ["F401"]
ignore = ["ANN101"]
[tool.ruff.lint.flake8-type-checking]
strict = true

View file

@ -1,3 +1,5 @@
from __future__ import annotations
import os
from codeflash.code_utils.code_replacer import replace_functions_in_file
@ -47,11 +49,13 @@ print("Hello world")
function_name: str = "NewClass.new_function"
preexisting_functions: list[str] = ["new_function"]
immutable_methods: set[tuple[str, str]] = {("NewClass", "__init__")}
new_code: str = replace_functions_in_file(
original_code,
[function_name],
optim_code,
preexisting_functions,
immutable_methods,
)
assert new_code == expected
@ -65,7 +69,7 @@ def totally_new_function(value):
def other_function(st):
return(st * 2)
class NewClass:
def __init__(self, name):
self.name = name
@ -105,8 +109,13 @@ print("Hello world")
function_name: str = "NewClass.new_function"
preexisting_functions: list[str] = ["new_function", "other_function"]
immutable_methods: set[tuple[str, str]] = {("NewClass", "__init__")}
new_code: str = replace_functions_in_file(
original_code, [function_name], optim_code, preexisting_functions
original_code,
[function_name],
optim_code,
preexisting_functions,
immutable_methods,
)
assert new_code == expected
@ -161,8 +170,13 @@ print("Salut monde")
function_names: list[str] = ["module.other_function"]
preexisting_functions: list[str] = []
immutable_methods: set[tuple[str, str]] = set()
new_code: str = replace_functions_in_file(
original_code, function_names, optim_code, preexisting_functions
original_code,
function_names,
optim_code,
preexisting_functions,
immutable_methods,
)
assert new_code == expected
@ -173,7 +187,7 @@ from typing import Optional
def totally_new_function(value):
return value
def yet_another_function(values):
return len(values) + 2
@ -208,6 +222,7 @@ import libcst as cst
from typing import Mandatory
print("Au revoir")
def yet_another_function(values):
return len(values) + 2
@ -219,8 +234,13 @@ print("Salut monde")
function_names: list[str] = ["module.yet_another_function", "module.other_function"]
preexisting_functions: list[str] = []
immutable_methods: set[tuple[str, str]] = set()
new_code: str = replace_functions_in_file(
original_code, function_names, optim_code, preexisting_functions
original_code,
function_names,
optim_code,
preexisting_functions,
immutable_methods,
)
assert new_code == expected
@ -232,7 +252,7 @@ def test_test_libcst_code_replacement5() -> None:
def badsort(ploc):
donothing(ploc)
def supersort(doink):
for i in range(len(doink)):
fix(doink, i)
@ -256,7 +276,7 @@ def sorter_deps(arr):
def badsort(ploc):
donothing(ploc)
def supersort(doink):
for i in range(len(doink)):
fix(doink, i)
@ -264,8 +284,13 @@ def supersort(doink):
function_names: list[str] = ["sorter_deps"]
preexisting_functions: list[str] = ["sorter_deps"]
immutable_methods: set[tuple[str, str]] = set()
new_code: str = replace_functions_in_file(
original_code, function_names, optim_code, preexisting_functions
original_code,
function_names,
optim_code,
preexisting_functions,
immutable_methods,
)
assert new_code == expected
@ -343,11 +368,16 @@ print("Not cool")
["other_function"],
optim_code,
["other_function", "yet_another_function", "blob"],
set(),
)
assert new_main_code == expected_main
new_dependent_code: str = replace_functions_in_file(
original_code_dependent, ["blob"], optim_code, []
original_code_dependent,
["blob"],
optim_code,
[],
set(),
)
assert new_dependent_code == expected_dependent
@ -537,8 +567,17 @@ class CacheConfig(BaseConfig):
"__init__",
"from_config",
]
immutable_methods: set[tuple[str, str]] = {
("CacheSimilarityEvalConfig", "__init__"),
("CacheConfig", "__init__"),
("CacheInitConfig", "__init__"),
}
new_code: str = replace_functions_in_file(
original_code, function_names, optim_code, preexisting_functions
original_code,
function_names,
optim_code,
preexisting_functions,
immutable_methods,
)
assert new_code == expected
@ -603,7 +642,12 @@ def test_test_libcst_code_replacement8() -> None:
preexisting_functions: list[str] = [
"_hamming_distance",
]
immutable_methods: set[tuple[str, str]] = set()
new_code: str = replace_functions_in_file(
original_code, function_names, optim_code, preexisting_functions
original_code,
function_names,
optim_code,
preexisting_functions,
immutable_methods,
)
assert new_code == expected

View file

@ -12,8 +12,9 @@ def test_get_code_function() -> None:
f.write(code)
f.flush()
new_code = get_code([FunctionToOptimize("test", f.name, [])])
new_code, contextual_dunder_methods = get_code([FunctionToOptimize("test", f.name, [])])
assert new_code == code
assert contextual_dunder_methods == set()
def test_get_code_property() -> None:
@ -27,12 +28,13 @@ def test_get_code_property() -> None:
f.write(code)
f.flush()
new_code = get_code(
new_code, contextual_dunder_methods = get_code(
[
FunctionToOptimize("test", f.name, [FunctionParent("TestClass", "ClassDef")]),
],
)
assert new_code == code
assert contextual_dunder_methods == {("TestClass", "__init__")}
def test_get_code_class() -> None:
@ -57,10 +59,11 @@ class TestClass:
f.write(code)
f.flush()
new_code = get_code(
new_code, contextual_dunder_methods = get_code(
[FunctionToOptimize("test", f.name, [FunctionParent("TestClass", "ClassDef")])],
)
assert new_code == expected
assert contextual_dunder_methods == {("TestClass", "__init__")}
def test_get_code_bubble_sort_class() -> None:
@ -107,10 +110,14 @@ class BubbleSortClass:
f.write(code)
f.flush()
new_code = get_code(
new_code, contextual_dunder_methods = get_code(
[FunctionToOptimize("sorter", f.name, [FunctionParent("BubbleSortClass", "ClassDef")])],
)
assert new_code == expected
assert contextual_dunder_methods == {
("BubbleSortClass", "__init__"),
("BubbleSortClass", "__call__"),
}
def test_get_code_indent() -> None:
@ -168,7 +175,7 @@ def non():
with tempfile.NamedTemporaryFile("w") as f:
f.write(code)
f.flush()
new_code = get_code(
new_code, contextual_dunder_methods = get_code(
[
FunctionToOptimize(
"sorter",
@ -183,6 +190,10 @@ def non():
],
)
assert new_code == expected
assert contextual_dunder_methods == {
("BubbleSortClass", "__init__"),
("BubbleSortClass", "__call__"),
}
expected2 = """class BubbleSortClass:
def __init__(self):
@ -205,7 +216,7 @@ def non():
with tempfile.NamedTemporaryFile("w") as f:
f.write(code)
f.flush()
new_code = get_code(
new_code, contextual_dunder_methods = get_code(
[
FunctionToOptimize(
"sorter",
@ -225,3 +236,7 @@ def non():
],
)
assert new_code == expected2
assert contextual_dunder_methods == {
("BubbleSortClass", "__init__"),
("BubbleSortClass", "__call__"),
}

View file

@ -10,7 +10,7 @@ from codeflash.optimization.function_context import (
class CustomType:
def __init__(self):
def __init__(self) -> None:
self.name = None
self.data: List[int] = []
@ -21,19 +21,19 @@ class CustomDataClass:
data: List[int] = field(default_factory=list)
def function_to_optimize(data: CustomType):
def function_to_optimize(data: CustomType) -> CustomType:
name = data.name
data.data.sort()
return data
def function_to_optimize2(data: CustomDataClass):
def function_to_optimize2(data: CustomDataClass) -> CustomDataClass:
name = data.name
data.data.sort()
return data
def test_function_context_includes_type_annotation():
def test_function_context_includes_type_annotation() -> None:
file_path = pathlib.Path(__file__).resolve()
a, dependent_functions = get_constrained_function_context_and_dependent_functions(
FunctionToOptimize("function_to_optimize", str(file_path), []),
@ -49,7 +49,7 @@ def test_function_context_includes_type_annotation():
assert dependent_functions[0][0].full_name == "CustomType"
def test_function_context_includes_type_annotation_dataclass():
def test_function_context_includes_type_annotation_dataclass() -> None:
file_path = pathlib.Path(__file__).resolve()
a, dependent_functions = get_constrained_function_context_and_dependent_functions(
FunctionToOptimize("function_to_optimize2", str(file_path), []),
@ -65,11 +65,14 @@ def test_function_context_includes_type_annotation_dataclass():
assert dependent_functions[0][0].full_name == "CustomDataClass"
def test_function_context_custom_datatype():
def test_function_context_custom_datatype() -> None:
project_path = pathlib.Path(__file__).parent.parent.resolve() / "code_to_optimize"
file_path = project_path / "math_utils.py"
code = get_code([FunctionToOptimize("cosine_similarity", str(file_path), [])])
code, contextual_dunder_methods = get_code(
[FunctionToOptimize("cosine_similarity", str(file_path), [])],
)
assert code is not None
assert contextual_dunder_methods == set()
a, dependent_functions = get_constrained_function_context_and_dependent_functions(
FunctionToOptimize("cosine_similarity", str(file_path), []),
str(project_path),