mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
WIP, blocked by debugger from fixing last failing test (bubble sort in class).
This commit is contained in:
parent
0baa7f0b8a
commit
b092d1b84f
8 changed files with 194 additions and 94 deletions
|
|
@ -3,26 +3,32 @@ from __future__ import annotations
|
|||
import ast
|
||||
import logging
|
||||
from collections import deque
|
||||
from typing import Optional
|
||||
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
|
||||
|
||||
def get_code(functions_to_optimize: list[FunctionToOptimize]) -> Optional[str]:
|
||||
"""Returns the code for a class or functions in a file."""
|
||||
def get_code(
|
||||
functions_to_optimize: list[FunctionToOptimize],
|
||||
) -> tuple[str | None, set[tuple[str, str]]]:
|
||||
"""Return the code for a class or functions in a file."""
|
||||
file_path: str = functions_to_optimize[0].file_path
|
||||
class_skeleton: set[tuple[int, int]] = set()
|
||||
class_skeleton: set[tuple[int, int | None]] = set()
|
||||
contextual_dunder_methods: set[tuple[str, str]] = set()
|
||||
target_code: str = ""
|
||||
|
||||
def find_target(
|
||||
node_list: list[ast.stmt],
|
||||
name_parts: tuple[str, str] | tuple[str],
|
||||
) -> Optional[ast.AST]:
|
||||
target: Optional[
|
||||
ast.FunctionDef | ast.AsyncFunctionDef | ast.ClassDef | ast.Assign | ast.AnnAssign
|
||||
] = None
|
||||
node: ast.FunctionDef | ast.AsyncFunctionDef | ast.ClassDef | ast.Assign | ast.AnnAssign
|
||||
|
||||
) -> ast.AST | None:
|
||||
target: (
|
||||
ast.FunctionDef
|
||||
| ast.AsyncFunctionDef
|
||||
| ast.ClassDef
|
||||
| ast.Assign
|
||||
| ast.AnnAssign
|
||||
| None
|
||||
) = None
|
||||
node: ast.stmt
|
||||
for node in node_list:
|
||||
if (
|
||||
# The many mypy issues will be fixed once this code moves to the backend,
|
||||
|
|
@ -57,17 +63,19 @@ def get_code(functions_to_optimize: list[FunctionToOptimize]) -> Optional[str]:
|
|||
if isinstance(cbody[0], ast.expr): # Is a docstring
|
||||
class_skeleton.add((cbody[0].lineno, cbody[0].end_lineno))
|
||||
cbody = cbody[1:]
|
||||
cnode: ast.FunctionDef | ast.AsyncFunctionDef
|
||||
cnode: ast.stmt
|
||||
for cnode in cbody:
|
||||
# Collect all dunder methods.
|
||||
cnode_name: str
|
||||
if (
|
||||
isinstance(cnode, (ast.FunctionDef, ast.AsyncFunctionDef))
|
||||
and len(cnode_name := cnode.name) > 4
|
||||
and cnode_name != name_parts[1]
|
||||
and cnode_name.isascii()
|
||||
and cnode_name.startswith("__")
|
||||
and cnode_name.endswith("__")
|
||||
):
|
||||
contextual_dunder_methods.add((target.name, cnode_name))
|
||||
class_skeleton.add((cnode.lineno, cnode.end_lineno))
|
||||
|
||||
return find_target(target.body, name_parts[1:])
|
||||
|
|
@ -80,7 +88,7 @@ def get_code(functions_to_optimize: list[FunctionToOptimize]) -> Optional[str]:
|
|||
module_node = ast.parse(source_code)
|
||||
except SyntaxError as e:
|
||||
logging.exception(f"get_code - Syntax error in code: {e}")
|
||||
return None
|
||||
return None, set()
|
||||
# Get the source code lines for the target node
|
||||
lines = source_code.splitlines(keepends=True)
|
||||
if len(functions_to_optimize[0].parents) == 1:
|
||||
|
|
@ -95,7 +103,7 @@ def get_code(functions_to_optimize: list[FunctionToOptimize]) -> Optional[str]:
|
|||
logging.error(
|
||||
f"Error: get_code does not support nesting function in functions: {functions_to_optimize[0].parents}",
|
||||
)
|
||||
return None
|
||||
return None, set()
|
||||
elif len(functions_to_optimize[0].parents) == 0:
|
||||
qualified_name_parts_list = [(functions_to_optimize[0].function_name,)]
|
||||
else:
|
||||
|
|
@ -103,9 +111,9 @@ def get_code(functions_to_optimize: list[FunctionToOptimize]) -> Optional[str]:
|
|||
"Error: get_code does not support more than one level of nesting for now. "
|
||||
f"Parents: {functions_to_optimize[0].parents}",
|
||||
)
|
||||
return None
|
||||
return None, set()
|
||||
for qualified_name_parts in qualified_name_parts_list:
|
||||
target_node: Optional[ast.AST] = find_target(module_node.body, qualified_name_parts)
|
||||
target_node: ast.AST | None = find_target(module_node.body, qualified_name_parts)
|
||||
if target_node is None:
|
||||
continue
|
||||
|
||||
|
|
@ -118,15 +126,15 @@ def get_code(functions_to_optimize: list[FunctionToOptimize]) -> Optional[str]:
|
|||
)
|
||||
else:
|
||||
target_code += "".join(lines[target_node.lineno - 1 : target_node.end_lineno])
|
||||
class_list: list[tuple[int, int]] = sorted(list(class_skeleton))
|
||||
class_list: list[tuple[int, int | None]] = sorted(class_skeleton)
|
||||
class_code = "".join(
|
||||
["".join(lines[s_lineno - 1 : e_lineno]) for (s_lineno, e_lineno) in class_list],
|
||||
)
|
||||
return class_code + target_code
|
||||
return class_code + target_code, contextual_dunder_methods
|
||||
|
||||
|
||||
def get_code_no_skeleton(file_path: str, target_name: str) -> Optional[str]:
|
||||
"""Returns the code for a function in a file. Irrespective of class skeleton."""
|
||||
def get_code_no_skeleton(file_path: str, target_name: str) -> str | None:
|
||||
"""Return the code for a function in a file, irrespective of class skeleton."""
|
||||
with open(file_path, encoding="utf8") as file:
|
||||
source_code = file.read()
|
||||
|
||||
|
|
@ -165,15 +173,17 @@ def get_code_no_skeleton(file_path: str, target_name: str) -> Optional[str]:
|
|||
return target_code
|
||||
|
||||
|
||||
def extract_code(functions_to_optimize: list[FunctionToOptimize]) -> Optional[str]:
|
||||
edited_code: Optional[str] = get_code(functions_to_optimize)
|
||||
def extract_code(
|
||||
functions_to_optimize: list[FunctionToOptimize],
|
||||
) -> tuple[str | None, set[tuple[str, str]]]:
|
||||
edited_code, contextual_dunder_methods = get_code(functions_to_optimize)
|
||||
if edited_code is None:
|
||||
return None
|
||||
return None, set()
|
||||
try:
|
||||
compile(edited_code, "edited_code", "exec")
|
||||
except SyntaxError as e:
|
||||
logging.exception(
|
||||
f"extract_code - Syntax error in extracted optimization candidate code: {e}",
|
||||
)
|
||||
return None
|
||||
return edited_code
|
||||
return None, set()
|
||||
return edited_code, contextual_dunder_methods
|
||||
|
|
|
|||
|
|
@ -1,24 +1,34 @@
|
|||
from typing import List, Union, Optional, IO
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import IO
|
||||
|
||||
import libcst as cst
|
||||
from libcst import SimpleStatementLine, FunctionDef
|
||||
from libcst import FunctionDef
|
||||
|
||||
|
||||
class OptimFunctionCollector(cst.CSTVisitor):
|
||||
METADATA_DEPENDENCIES = (cst.metadata.ParentNodeProvider,)
|
||||
|
||||
def __init__(self, function_name: str, preexisting_functions: Optional[List[str]] = None):
|
||||
def __init__(
|
||||
self,
|
||||
function_name: str,
|
||||
class_name: str | None,
|
||||
immutable_methods: set[tuple[str, str]],
|
||||
preexisting_functions: list[str] | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
if preexisting_functions is None:
|
||||
preexisting_functions = []
|
||||
self.function_name = function_name
|
||||
self.optim_body: Union[FunctionDef, None] = None
|
||||
self.class_name = class_name
|
||||
self.optim_body: FunctionDef | None = None
|
||||
self.optim_new_class_functions: list[cst.FunctionDef] = []
|
||||
self.optim_new_functions: list[cst.FunctionDef] = []
|
||||
self.optim_imports: List[Union[cst.Import, cst.ImportFrom]] = []
|
||||
self.optim_imports: list[cst.SimpleStatementLine] = []
|
||||
self.preexisting_functions = preexisting_functions
|
||||
self.immutable_methods = immutable_methods.union({(self.class_name, self.function_name)})
|
||||
|
||||
def visit_FunctionDef(self, node: cst.FunctionDef):
|
||||
def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
|
||||
parent = self.get_metadata(cst.metadata.ParentNodeProvider, node)
|
||||
parent2 = None
|
||||
try:
|
||||
|
|
@ -39,17 +49,19 @@ class OptimFunctionCollector(cst.CSTVisitor):
|
|||
self.optim_new_functions.append(node)
|
||||
|
||||
def visit_ClassDef_body(self, node: cst.ClassDef) -> None:
|
||||
for class_node in node.body.body:
|
||||
if isinstance(class_node, cst.FunctionDef) and class_node.name.value not in [
|
||||
"__init__",
|
||||
self.function_name,
|
||||
]:
|
||||
self.optim_new_class_functions.append(class_node)
|
||||
for child_node in node.body.body:
|
||||
if (
|
||||
isinstance(child_node, cst.FunctionDef)
|
||||
and (
|
||||
node.name.value,
|
||||
child_node.name.value,
|
||||
)
|
||||
not in self.immutable_methods
|
||||
):
|
||||
self.optim_new_class_functions.append(child_node)
|
||||
|
||||
def leave_SimpleStatementLine(self, original_node: "SimpleStatementLine") -> None:
|
||||
if isinstance(original_node.body[0], cst.Import):
|
||||
self.optim_imports.append(original_node)
|
||||
elif isinstance(original_node.body[0], cst.ImportFrom):
|
||||
def leave_SimpleStatementLine(self, original_node: cst.SimpleStatementLine) -> None:
|
||||
if isinstance(original_node.body[0], (cst.Import, cst.ImportFrom)):
|
||||
self.optim_imports.append(original_node)
|
||||
|
||||
|
||||
|
|
@ -58,11 +70,11 @@ class OptimFunctionReplacer(cst.CSTTransformer):
|
|||
self,
|
||||
function_name: str,
|
||||
optim_body: cst.FunctionDef,
|
||||
optim_new_class_functions: List[cst.FunctionDef],
|
||||
optim_imports: List[Union[cst.Import, cst.ImportFrom]],
|
||||
optim_new_functions,
|
||||
class_name=None,
|
||||
):
|
||||
optim_new_class_functions: list[cst.FunctionDef],
|
||||
optim_imports: list[cst.SimpleStatementLine],
|
||||
optim_new_functions: list[cst.FunctionDef],
|
||||
class_name: str | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.function_name = function_name
|
||||
self.optim_body = optim_body
|
||||
|
|
@ -77,7 +89,9 @@ class OptimFunctionReplacer(cst.CSTTransformer):
|
|||
return False
|
||||
|
||||
def leave_FunctionDef(
|
||||
self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef
|
||||
self,
|
||||
original_node: cst.FunctionDef,
|
||||
updated_node: cst.FunctionDef,
|
||||
) -> cst.FunctionDef:
|
||||
if original_node.name.value == self.function_name and (
|
||||
self.depth == 0 or (self.depth == 1 and self.in_class)
|
||||
|
|
@ -93,7 +107,9 @@ class OptimFunctionReplacer(cst.CSTTransformer):
|
|||
return self.in_class
|
||||
|
||||
def leave_ClassDef(
|
||||
self, original_node: cst.ClassDef, updated_node: cst.ClassDef
|
||||
self,
|
||||
original_node: cst.ClassDef,
|
||||
updated_node: cst.ClassDef,
|
||||
) -> cst.ClassDef:
|
||||
self.depth -= 1
|
||||
if self.in_class and (self.depth == 0) and (original_node.name.value == self.class_name):
|
||||
|
|
@ -101,7 +117,7 @@ class OptimFunctionReplacer(cst.CSTTransformer):
|
|||
return updated_node.with_changes(
|
||||
body=updated_node.body.with_changes(
|
||||
body=(list(updated_node.body.body) + self.optim_new_class_functions),
|
||||
)
|
||||
),
|
||||
)
|
||||
return updated_node
|
||||
|
||||
|
|
@ -123,7 +139,7 @@ class OptimFunctionReplacer(cst.CSTTransformer):
|
|||
*node.body[: max_function_index + 1],
|
||||
*self.optim_new_functions,
|
||||
*node.body[max_function_index + 1 :],
|
||||
)
|
||||
),
|
||||
)
|
||||
elif class_index is not None:
|
||||
node = node.with_changes(
|
||||
|
|
@ -131,7 +147,7 @@ class OptimFunctionReplacer(cst.CSTTransformer):
|
|||
*node.body[: class_index + 1],
|
||||
*self.optim_new_functions,
|
||||
*node.body[class_index + 1 :],
|
||||
)
|
||||
),
|
||||
)
|
||||
else:
|
||||
node = node.with_changes(body=(*self.optim_new_functions, *node.body))
|
||||
|
|
@ -165,6 +181,7 @@ def replace_functions_in_file(
|
|||
original_function_names: list[str],
|
||||
optimized_code: str,
|
||||
preexisting_functions: list[str],
|
||||
immutable_methods: set[tuple[str, str]],
|
||||
) -> str:
|
||||
parsed_function_names = []
|
||||
for original_function_name in original_function_names:
|
||||
|
|
@ -179,14 +196,19 @@ def replace_functions_in_file(
|
|||
module = cst.metadata.MetadataWrapper(cst.parse_module(optimized_code))
|
||||
|
||||
for i, (function_name, class_name) in enumerate(parsed_function_names):
|
||||
visitor = OptimFunctionCollector(function_name, preexisting_functions)
|
||||
visited = module.visit(visitor)
|
||||
visitor = OptimFunctionCollector(
|
||||
function_name,
|
||||
class_name,
|
||||
immutable_methods,
|
||||
preexisting_functions,
|
||||
)
|
||||
module.visit(visitor)
|
||||
|
||||
if visitor.optim_body is None and not preexisting_functions:
|
||||
continue
|
||||
elif visitor.optim_body is None:
|
||||
if visitor.optim_body is None:
|
||||
raise ValueError(f"Did not find the function {function_name} in the optimized code")
|
||||
optim_imports = [] if i > 0 else visitor.optim_imports
|
||||
optim_imports: list[cst.SimpleStatementLine] = [] if i > 0 else visitor.optim_imports
|
||||
|
||||
transformer = OptimFunctionReplacer(
|
||||
visitor.function_name,
|
||||
|
|
@ -208,15 +230,17 @@ def replace_function_definitions_in_module(
|
|||
optimized_code: str,
|
||||
module_abspath: str,
|
||||
preexisting_functions: list[str],
|
||||
immutable_methods: set[tuple[str, str]],
|
||||
) -> None:
|
||||
file: IO[str]
|
||||
with open(module_abspath, "r", encoding="utf8") as file:
|
||||
with open(module_abspath, encoding="utf8") as file:
|
||||
source_code: str = file.read()
|
||||
new_code: str = replace_functions_in_file(
|
||||
source_code,
|
||||
function_names,
|
||||
optimized_code,
|
||||
preexisting_functions,
|
||||
immutable_methods,
|
||||
)
|
||||
with open(module_abspath, "w", encoding="utf8") as file:
|
||||
file.write(new_code)
|
||||
|
|
|
|||
|
|
@ -2,6 +2,8 @@
|
|||
solved problem, please reach out to us at careers@codeflash.ai. We're hiring!
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import concurrent.futures
|
||||
import logging
|
||||
import os
|
||||
|
|
@ -9,7 +11,7 @@ import pathlib
|
|||
import uuid
|
||||
from argparse import SUPPRESS, ArgumentParser, Namespace
|
||||
from collections import defaultdict
|
||||
from typing import IO, Tuple, Union
|
||||
from typing import IO, Dict, Tuple, Union
|
||||
|
||||
import libcst as cst
|
||||
|
||||
|
|
@ -212,7 +214,9 @@ class Optimizer:
|
|||
pathlib.Path(get_run_tmp_file("test_return_values_0.sqlite")).unlink(
|
||||
missing_ok=True,
|
||||
)
|
||||
code_to_optimize = extract_code([function_to_optimize])
|
||||
code_to_optimize, contextual_dunder_methods = extract_code(
|
||||
[function_to_optimize],
|
||||
)
|
||||
if code_to_optimize is None:
|
||||
logging.error("Could not find function to optimize.")
|
||||
continue
|
||||
|
|
@ -247,18 +251,13 @@ class Optimizer:
|
|||
for df in dependent_methods
|
||||
]
|
||||
if len(optimizable_methods) > 1:
|
||||
code_to_optimize_with_dependents = (
|
||||
dependent_code + "\n" + extract_code(optimizable_methods)
|
||||
code_to_optimize, contextual_dunder_methods = extract_code(
|
||||
optimizable_methods,
|
||||
)
|
||||
else:
|
||||
code_to_optimize_with_dependents = (
|
||||
dependent_code + "\n" + code_to_optimize
|
||||
)
|
||||
else:
|
||||
code_to_optimize_with_dependents = dependent_code + "\n" + code_to_optimize
|
||||
if code_to_optimize_with_dependents is None:
|
||||
logging.error("Could not find function with dependents to optimize.")
|
||||
continue
|
||||
if code_to_optimize is None:
|
||||
logging.error("Could not find function to optimize.")
|
||||
continue
|
||||
code_to_optimize_with_dependents = dependent_code + "\n" + code_to_optimize
|
||||
preexisting_functions.extend(
|
||||
[fn[0].full_name.split(".")[-1] for fn in dependent_functions],
|
||||
)
|
||||
|
|
@ -323,7 +322,7 @@ class Optimizer:
|
|||
# TODO: Postprocess the optimized function to include the original docstring and such
|
||||
|
||||
best_optimization = []
|
||||
speedup_ratios = dict()
|
||||
speedup_ratios: Dict[str, float | None] = dict()
|
||||
optimized_runtimes = dict()
|
||||
is_correct = dict()
|
||||
|
||||
|
|
@ -346,6 +345,7 @@ class Optimizer:
|
|||
optimization.source_code,
|
||||
path,
|
||||
preexisting_functions,
|
||||
contextual_dunder_methods,
|
||||
)
|
||||
for (
|
||||
module_abspath,
|
||||
|
|
@ -356,6 +356,7 @@ class Optimizer:
|
|||
optimization.source_code,
|
||||
module_abspath,
|
||||
[],
|
||||
contextual_dunder_methods,
|
||||
)
|
||||
except (
|
||||
ValueError,
|
||||
|
|
@ -445,6 +446,7 @@ class Optimizer:
|
|||
optimized_code,
|
||||
path,
|
||||
preexisting_functions,
|
||||
contextual_dunder_methods,
|
||||
)
|
||||
for (
|
||||
module_abspath,
|
||||
|
|
@ -455,6 +457,7 @@ class Optimizer:
|
|||
optimized_code,
|
||||
module_abspath,
|
||||
[],
|
||||
contextual_dunder_methods,
|
||||
)
|
||||
explanation_final = Explanation(
|
||||
raw_explanation_message=best_optimization[1],
|
||||
|
|
|
|||
|
|
@ -107,7 +107,7 @@ def get_type_annotation_context(
|
|||
node_parents[:-1],
|
||||
),
|
||||
],
|
||||
)
|
||||
)[0]
|
||||
if source_code:
|
||||
sources.append(
|
||||
(
|
||||
|
|
|
|||
|
|
@ -76,6 +76,7 @@ warn_required_dynamic_aliases = true
|
|||
[tool.ruff.lint]
|
||||
select = ["ALL"]
|
||||
unfixable = ["F401"]
|
||||
ignore = ["ANN101"]
|
||||
|
||||
[tool.ruff.lint.flake8-type-checking]
|
||||
strict = true
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
|
||||
from codeflash.code_utils.code_replacer import replace_functions_in_file
|
||||
|
|
@ -47,11 +49,13 @@ print("Hello world")
|
|||
|
||||
function_name: str = "NewClass.new_function"
|
||||
preexisting_functions: list[str] = ["new_function"]
|
||||
immutable_methods: set[tuple[str, str]] = {("NewClass", "__init__")}
|
||||
new_code: str = replace_functions_in_file(
|
||||
original_code,
|
||||
[function_name],
|
||||
optim_code,
|
||||
preexisting_functions,
|
||||
immutable_methods,
|
||||
)
|
||||
assert new_code == expected
|
||||
|
||||
|
|
@ -65,7 +69,7 @@ def totally_new_function(value):
|
|||
|
||||
def other_function(st):
|
||||
return(st * 2)
|
||||
|
||||
|
||||
class NewClass:
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
|
|
@ -105,8 +109,13 @@ print("Hello world")
|
|||
|
||||
function_name: str = "NewClass.new_function"
|
||||
preexisting_functions: list[str] = ["new_function", "other_function"]
|
||||
immutable_methods: set[tuple[str, str]] = {("NewClass", "__init__")}
|
||||
new_code: str = replace_functions_in_file(
|
||||
original_code, [function_name], optim_code, preexisting_functions
|
||||
original_code,
|
||||
[function_name],
|
||||
optim_code,
|
||||
preexisting_functions,
|
||||
immutable_methods,
|
||||
)
|
||||
assert new_code == expected
|
||||
|
||||
|
|
@ -161,8 +170,13 @@ print("Salut monde")
|
|||
|
||||
function_names: list[str] = ["module.other_function"]
|
||||
preexisting_functions: list[str] = []
|
||||
immutable_methods: set[tuple[str, str]] = set()
|
||||
new_code: str = replace_functions_in_file(
|
||||
original_code, function_names, optim_code, preexisting_functions
|
||||
original_code,
|
||||
function_names,
|
||||
optim_code,
|
||||
preexisting_functions,
|
||||
immutable_methods,
|
||||
)
|
||||
assert new_code == expected
|
||||
|
||||
|
|
@ -173,7 +187,7 @@ from typing import Optional
|
|||
|
||||
def totally_new_function(value):
|
||||
return value
|
||||
|
||||
|
||||
def yet_another_function(values):
|
||||
return len(values) + 2
|
||||
|
||||
|
|
@ -208,6 +222,7 @@ import libcst as cst
|
|||
from typing import Mandatory
|
||||
|
||||
print("Au revoir")
|
||||
|
||||
def yet_another_function(values):
|
||||
return len(values) + 2
|
||||
|
||||
|
|
@ -219,8 +234,13 @@ print("Salut monde")
|
|||
|
||||
function_names: list[str] = ["module.yet_another_function", "module.other_function"]
|
||||
preexisting_functions: list[str] = []
|
||||
immutable_methods: set[tuple[str, str]] = set()
|
||||
new_code: str = replace_functions_in_file(
|
||||
original_code, function_names, optim_code, preexisting_functions
|
||||
original_code,
|
||||
function_names,
|
||||
optim_code,
|
||||
preexisting_functions,
|
||||
immutable_methods,
|
||||
)
|
||||
assert new_code == expected
|
||||
|
||||
|
|
@ -232,7 +252,7 @@ def test_test_libcst_code_replacement5() -> None:
|
|||
|
||||
def badsort(ploc):
|
||||
donothing(ploc)
|
||||
|
||||
|
||||
def supersort(doink):
|
||||
for i in range(len(doink)):
|
||||
fix(doink, i)
|
||||
|
|
@ -256,7 +276,7 @@ def sorter_deps(arr):
|
|||
|
||||
def badsort(ploc):
|
||||
donothing(ploc)
|
||||
|
||||
|
||||
def supersort(doink):
|
||||
for i in range(len(doink)):
|
||||
fix(doink, i)
|
||||
|
|
@ -264,8 +284,13 @@ def supersort(doink):
|
|||
|
||||
function_names: list[str] = ["sorter_deps"]
|
||||
preexisting_functions: list[str] = ["sorter_deps"]
|
||||
immutable_methods: set[tuple[str, str]] = set()
|
||||
new_code: str = replace_functions_in_file(
|
||||
original_code, function_names, optim_code, preexisting_functions
|
||||
original_code,
|
||||
function_names,
|
||||
optim_code,
|
||||
preexisting_functions,
|
||||
immutable_methods,
|
||||
)
|
||||
assert new_code == expected
|
||||
|
||||
|
|
@ -343,11 +368,16 @@ print("Not cool")
|
|||
["other_function"],
|
||||
optim_code,
|
||||
["other_function", "yet_another_function", "blob"],
|
||||
set(),
|
||||
)
|
||||
assert new_main_code == expected_main
|
||||
|
||||
new_dependent_code: str = replace_functions_in_file(
|
||||
original_code_dependent, ["blob"], optim_code, []
|
||||
original_code_dependent,
|
||||
["blob"],
|
||||
optim_code,
|
||||
[],
|
||||
set(),
|
||||
)
|
||||
assert new_dependent_code == expected_dependent
|
||||
|
||||
|
|
@ -537,8 +567,17 @@ class CacheConfig(BaseConfig):
|
|||
"__init__",
|
||||
"from_config",
|
||||
]
|
||||
immutable_methods: set[tuple[str, str]] = {
|
||||
("CacheSimilarityEvalConfig", "__init__"),
|
||||
("CacheConfig", "__init__"),
|
||||
("CacheInitConfig", "__init__"),
|
||||
}
|
||||
new_code: str = replace_functions_in_file(
|
||||
original_code, function_names, optim_code, preexisting_functions
|
||||
original_code,
|
||||
function_names,
|
||||
optim_code,
|
||||
preexisting_functions,
|
||||
immutable_methods,
|
||||
)
|
||||
assert new_code == expected
|
||||
|
||||
|
|
@ -603,7 +642,12 @@ def test_test_libcst_code_replacement8() -> None:
|
|||
preexisting_functions: list[str] = [
|
||||
"_hamming_distance",
|
||||
]
|
||||
immutable_methods: set[tuple[str, str]] = set()
|
||||
new_code: str = replace_functions_in_file(
|
||||
original_code, function_names, optim_code, preexisting_functions
|
||||
original_code,
|
||||
function_names,
|
||||
optim_code,
|
||||
preexisting_functions,
|
||||
immutable_methods,
|
||||
)
|
||||
assert new_code == expected
|
||||
|
|
|
|||
|
|
@ -12,8 +12,9 @@ def test_get_code_function() -> None:
|
|||
f.write(code)
|
||||
f.flush()
|
||||
|
||||
new_code = get_code([FunctionToOptimize("test", f.name, [])])
|
||||
new_code, contextual_dunder_methods = get_code([FunctionToOptimize("test", f.name, [])])
|
||||
assert new_code == code
|
||||
assert contextual_dunder_methods == set()
|
||||
|
||||
|
||||
def test_get_code_property() -> None:
|
||||
|
|
@ -27,12 +28,13 @@ def test_get_code_property() -> None:
|
|||
f.write(code)
|
||||
f.flush()
|
||||
|
||||
new_code = get_code(
|
||||
new_code, contextual_dunder_methods = get_code(
|
||||
[
|
||||
FunctionToOptimize("test", f.name, [FunctionParent("TestClass", "ClassDef")]),
|
||||
],
|
||||
)
|
||||
assert new_code == code
|
||||
assert contextual_dunder_methods == {("TestClass", "__init__")}
|
||||
|
||||
|
||||
def test_get_code_class() -> None:
|
||||
|
|
@ -57,10 +59,11 @@ class TestClass:
|
|||
f.write(code)
|
||||
f.flush()
|
||||
|
||||
new_code = get_code(
|
||||
new_code, contextual_dunder_methods = get_code(
|
||||
[FunctionToOptimize("test", f.name, [FunctionParent("TestClass", "ClassDef")])],
|
||||
)
|
||||
assert new_code == expected
|
||||
assert contextual_dunder_methods == {("TestClass", "__init__")}
|
||||
|
||||
|
||||
def test_get_code_bubble_sort_class() -> None:
|
||||
|
|
@ -107,10 +110,14 @@ class BubbleSortClass:
|
|||
f.write(code)
|
||||
f.flush()
|
||||
|
||||
new_code = get_code(
|
||||
new_code, contextual_dunder_methods = get_code(
|
||||
[FunctionToOptimize("sorter", f.name, [FunctionParent("BubbleSortClass", "ClassDef")])],
|
||||
)
|
||||
assert new_code == expected
|
||||
assert contextual_dunder_methods == {
|
||||
("BubbleSortClass", "__init__"),
|
||||
("BubbleSortClass", "__call__"),
|
||||
}
|
||||
|
||||
|
||||
def test_get_code_indent() -> None:
|
||||
|
|
@ -168,7 +175,7 @@ def non():
|
|||
with tempfile.NamedTemporaryFile("w") as f:
|
||||
f.write(code)
|
||||
f.flush()
|
||||
new_code = get_code(
|
||||
new_code, contextual_dunder_methods = get_code(
|
||||
[
|
||||
FunctionToOptimize(
|
||||
"sorter",
|
||||
|
|
@ -183,6 +190,10 @@ def non():
|
|||
],
|
||||
)
|
||||
assert new_code == expected
|
||||
assert contextual_dunder_methods == {
|
||||
("BubbleSortClass", "__init__"),
|
||||
("BubbleSortClass", "__call__"),
|
||||
}
|
||||
|
||||
expected2 = """class BubbleSortClass:
|
||||
def __init__(self):
|
||||
|
|
@ -205,7 +216,7 @@ def non():
|
|||
with tempfile.NamedTemporaryFile("w") as f:
|
||||
f.write(code)
|
||||
f.flush()
|
||||
new_code = get_code(
|
||||
new_code, contextual_dunder_methods = get_code(
|
||||
[
|
||||
FunctionToOptimize(
|
||||
"sorter",
|
||||
|
|
@ -225,3 +236,7 @@ def non():
|
|||
],
|
||||
)
|
||||
assert new_code == expected2
|
||||
assert contextual_dunder_methods == {
|
||||
("BubbleSortClass", "__init__"),
|
||||
("BubbleSortClass", "__call__"),
|
||||
}
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ from codeflash.optimization.function_context import (
|
|||
|
||||
|
||||
class CustomType:
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
self.name = None
|
||||
self.data: List[int] = []
|
||||
|
||||
|
|
@ -21,19 +21,19 @@ class CustomDataClass:
|
|||
data: List[int] = field(default_factory=list)
|
||||
|
||||
|
||||
def function_to_optimize(data: CustomType):
|
||||
def function_to_optimize(data: CustomType) -> CustomType:
|
||||
name = data.name
|
||||
data.data.sort()
|
||||
return data
|
||||
|
||||
|
||||
def function_to_optimize2(data: CustomDataClass):
|
||||
def function_to_optimize2(data: CustomDataClass) -> CustomDataClass:
|
||||
name = data.name
|
||||
data.data.sort()
|
||||
return data
|
||||
|
||||
|
||||
def test_function_context_includes_type_annotation():
|
||||
def test_function_context_includes_type_annotation() -> None:
|
||||
file_path = pathlib.Path(__file__).resolve()
|
||||
a, dependent_functions = get_constrained_function_context_and_dependent_functions(
|
||||
FunctionToOptimize("function_to_optimize", str(file_path), []),
|
||||
|
|
@ -49,7 +49,7 @@ def test_function_context_includes_type_annotation():
|
|||
assert dependent_functions[0][0].full_name == "CustomType"
|
||||
|
||||
|
||||
def test_function_context_includes_type_annotation_dataclass():
|
||||
def test_function_context_includes_type_annotation_dataclass() -> None:
|
||||
file_path = pathlib.Path(__file__).resolve()
|
||||
a, dependent_functions = get_constrained_function_context_and_dependent_functions(
|
||||
FunctionToOptimize("function_to_optimize2", str(file_path), []),
|
||||
|
|
@ -65,11 +65,14 @@ def test_function_context_includes_type_annotation_dataclass():
|
|||
assert dependent_functions[0][0].full_name == "CustomDataClass"
|
||||
|
||||
|
||||
def test_function_context_custom_datatype():
|
||||
def test_function_context_custom_datatype() -> None:
|
||||
project_path = pathlib.Path(__file__).parent.parent.resolve() / "code_to_optimize"
|
||||
file_path = project_path / "math_utils.py"
|
||||
code = get_code([FunctionToOptimize("cosine_similarity", str(file_path), [])])
|
||||
code, contextual_dunder_methods = get_code(
|
||||
[FunctionToOptimize("cosine_similarity", str(file_path), [])],
|
||||
)
|
||||
assert code is not None
|
||||
assert contextual_dunder_methods == set()
|
||||
a, dependent_functions = get_constrained_function_context_and_dependent_functions(
|
||||
FunctionToOptimize("cosine_similarity", str(file_path), []),
|
||||
str(project_path),
|
||||
|
|
|
|||
Loading…
Reference in a new issue