Fix the logic to get pre-existing objects from the code to optimize
Fixes the bug ivan ran into
This commit is contained in:
parent
25e6bf86a7
commit
be38b46e1e
3 changed files with 107 additions and 17 deletions
|
|
@ -9,6 +9,8 @@ from libcst.codemod import CodemodContext
|
||||||
from libcst.codemod.visitors import AddImportsVisitor, GatherImportsVisitor, RemoveImportsVisitor
|
from libcst.codemod.visitors import AddImportsVisitor, GatherImportsVisitor, RemoveImportsVisitor
|
||||||
from libcst.helpers import calculate_module_and_package
|
from libcst.helpers import calculate_module_and_package
|
||||||
|
|
||||||
|
from codeflash.discovery.functions_to_optimize import FunctionParent
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from libcst.helpers import ModuleNameAndPackage
|
from libcst.helpers import ModuleNameAndPackage
|
||||||
|
|
||||||
|
|
@ -237,3 +239,22 @@ def extract_code(
|
||||||
)
|
)
|
||||||
return None, set()
|
return None, set()
|
||||||
return edited_code, contextual_dunder_methods
|
return edited_code, contextual_dunder_methods
|
||||||
|
|
||||||
|
|
||||||
|
def find_preexisting_objects(source_code: str):
|
||||||
|
"""Find all preexisting functions, classes or class methods in the source code"""
|
||||||
|
preexisting_objects: list[tuple[str, list[FunctionParent]]] = []
|
||||||
|
try:
|
||||||
|
module_node: ast.Module = ast.parse(source_code)
|
||||||
|
except SyntaxError:
|
||||||
|
logging.exception("find_preexisting_objects - Syntax error while parsing code")
|
||||||
|
return preexisting_objects
|
||||||
|
for node in module_node.body:
|
||||||
|
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
|
||||||
|
preexisting_objects.append((node.name, []))
|
||||||
|
elif isinstance(node, ast.ClassDef):
|
||||||
|
preexisting_objects.append((node.name, []))
|
||||||
|
for cnode in node.body:
|
||||||
|
if isinstance(cnode, (ast.FunctionDef, ast.AsyncFunctionDef)):
|
||||||
|
preexisting_objects.append((cnode.name, [FunctionParent(node.name, "ClassDef")]))
|
||||||
|
return preexisting_objects
|
||||||
|
|
|
||||||
|
|
@ -20,7 +20,11 @@ from codeflash.api.aiservice import (
|
||||||
OptimizedCandidate,
|
OptimizedCandidate,
|
||||||
)
|
)
|
||||||
from codeflash.code_utils import env_utils
|
from codeflash.code_utils import env_utils
|
||||||
from codeflash.code_utils.code_extractor import add_needed_imports_from_module, extract_code
|
from codeflash.code_utils.code_extractor import (
|
||||||
|
add_needed_imports_from_module,
|
||||||
|
extract_code,
|
||||||
|
find_preexisting_objects,
|
||||||
|
)
|
||||||
from codeflash.code_utils.code_replacer import replace_function_definitions_in_module
|
from codeflash.code_utils.code_replacer import replace_function_definitions_in_module
|
||||||
from codeflash.code_utils.code_utils import (
|
from codeflash.code_utils.code_utils import (
|
||||||
get_run_tmp_file,
|
get_run_tmp_file,
|
||||||
|
|
@ -606,11 +610,6 @@ class Optimizer:
|
||||||
)
|
)
|
||||||
if code_to_optimize is None:
|
if code_to_optimize is None:
|
||||||
return Failure("Could not find function to optimize.")
|
return Failure("Could not find function to optimize.")
|
||||||
preexisting_objects: list[tuple[str, list[FunctionParent]]] = [
|
|
||||||
(name, [FunctionParent(name=class_name, type="ClassDef")])
|
|
||||||
for class_name, name in contextual_dunder_methods
|
|
||||||
]
|
|
||||||
preexisting_objects.append((function_to_optimize.function_name, function_to_optimize.parents))
|
|
||||||
(
|
(
|
||||||
helper_code,
|
helper_code,
|
||||||
helper_functions,
|
helper_functions,
|
||||||
|
|
@ -653,14 +652,7 @@ class Optimizer:
|
||||||
project_root,
|
project_root,
|
||||||
helper_functions,
|
helper_functions,
|
||||||
)
|
)
|
||||||
preexisting_objects.extend(
|
preexisting_objects = find_preexisting_objects(code_to_optimize_with_helpers)
|
||||||
[
|
|
||||||
(qualified_name_list[-1], ([FunctionParent(name=qualified_name_list[-2], type="ClassDef")]))
|
|
||||||
if len(qualified_name_list := fn.qualified_name.split(".")) > 1
|
|
||||||
else (qualified_name_list[-1], [])
|
|
||||||
for fn in helper_functions
|
|
||||||
],
|
|
||||||
)
|
|
||||||
contextual_dunder_methods.update(helper_dunder_methods)
|
contextual_dunder_methods.update(helper_dunder_methods)
|
||||||
return Success(
|
return Success(
|
||||||
CodeOptimizationContext(
|
CodeOptimizationContext(
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,6 @@ from argparse import Namespace
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import pytest
|
|
||||||
from codeflash.code_utils.code_replacer import replace_functions_and_add_imports, replace_functions_in_file
|
from codeflash.code_utils.code_replacer import replace_functions_and_add_imports, replace_functions_in_file
|
||||||
from codeflash.discovery.functions_to_optimize import FunctionParent, FunctionToOptimize
|
from codeflash.discovery.functions_to_optimize import FunctionParent, FunctionToOptimize
|
||||||
from codeflash.optimization.optimizer import Optimizer
|
from codeflash.optimization.optimizer import Optimizer
|
||||||
|
|
@ -946,7 +945,6 @@ def test_test_libcst_code_replacement13() -> None:
|
||||||
assert new_code == original_code
|
assert new_code == original_code
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip()
|
|
||||||
def test_different_class_code_replacement():
|
def test_different_class_code_replacement():
|
||||||
original_code = """from __future__ import annotations
|
original_code = """from __future__ import annotations
|
||||||
import sys
|
import sys
|
||||||
|
|
@ -1118,6 +1116,8 @@ class TestResults(BaseModel):
|
||||||
("__getitem__", [FunctionParent(name="TestResults", type="ClassDef")]),
|
("__getitem__", [FunctionParent(name="TestResults", type="ClassDef")]),
|
||||||
("get_test_pass_fail_report_by_type", [FunctionParent(name="TestResults", type="ClassDef")]),
|
("get_test_pass_fail_report_by_type", [FunctionParent(name="TestResults", type="ClassDef")]),
|
||||||
("TestType", []),
|
("TestType", []),
|
||||||
|
("TestResults", []),
|
||||||
|
("to_name", [FunctionParent(name="TestType", type="ClassDef")]),
|
||||||
]
|
]
|
||||||
|
|
||||||
contextual_functions = {
|
contextual_functions = {
|
||||||
|
|
@ -1152,6 +1152,7 @@ class TestResults(BaseModel):
|
||||||
contextual_functions=contextual_functions,
|
contextual_functions=contextual_functions,
|
||||||
project_root_path=str(Path(__file__).parent.resolve()),
|
project_root_path=str(Path(__file__).parent.resolve()),
|
||||||
)
|
)
|
||||||
|
|
||||||
helper_functions_by_module_abspath = defaultdict(set)
|
helper_functions_by_module_abspath = defaultdict(set)
|
||||||
for helper_function in helper_functions:
|
for helper_function in helper_functions:
|
||||||
if helper_function.jedi_definition.type != "class":
|
if helper_function.jedi_definition.type != "class":
|
||||||
|
|
@ -1173,7 +1174,83 @@ class TestResults(BaseModel):
|
||||||
project_root_path=str(Path(__file__).parent.resolve()),
|
project_root_path=str(Path(__file__).parent.resolve()),
|
||||||
)
|
)
|
||||||
|
|
||||||
print("hi")
|
assert (
|
||||||
|
new_code
|
||||||
|
== """from __future__ import annotations
|
||||||
|
import sys
|
||||||
|
from codeflash.verification.comparator import comparator
|
||||||
|
from enum import Enum
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from typing import Iterator
|
||||||
|
|
||||||
|
class TestType(Enum):
|
||||||
|
EXISTING_UNIT_TEST = 1
|
||||||
|
INSPIRED_REGRESSION = 2
|
||||||
|
GENERATED_REGRESSION = 3
|
||||||
|
REPLAY_TEST = 4
|
||||||
|
|
||||||
|
def to_name(self) -> str:
|
||||||
|
names = {
|
||||||
|
TestType.EXISTING_UNIT_TEST: "⚙️ Existing Unit Tests",
|
||||||
|
TestType.INSPIRED_REGRESSION: "🎨 Inspired Regression Tests",
|
||||||
|
TestType.GENERATED_REGRESSION: "🌀 Generated Regression Tests",
|
||||||
|
TestType.REPLAY_TEST: "⏪ Replay Tests",
|
||||||
|
}
|
||||||
|
return names[self]
|
||||||
|
|
||||||
|
class TestResults(BaseModel):
|
||||||
|
def __iter__(self) -> Iterator[FunctionTestInvocation]:
|
||||||
|
return iter(self.test_results)
|
||||||
|
def __len__(self) -> int:
|
||||||
|
return len(self.test_results)
|
||||||
|
def __getitem__(self, index: int) -> FunctionTestInvocation:
|
||||||
|
return self.test_results[index]
|
||||||
|
def __setitem__(self, index: int, value: FunctionTestInvocation) -> None:
|
||||||
|
self.test_results[index] = value
|
||||||
|
def __delitem__(self, index: int) -> None:
|
||||||
|
del self.test_results[index]
|
||||||
|
def __contains__(self, value: FunctionTestInvocation) -> bool:
|
||||||
|
return value in self.test_results
|
||||||
|
def __bool__(self) -> bool:
|
||||||
|
return bool(self.test_results)
|
||||||
|
def __eq__(self, other: object) -> bool:
|
||||||
|
# Unordered comparison
|
||||||
|
if type(self) != type(other):
|
||||||
|
return False
|
||||||
|
if len(self) != len(other):
|
||||||
|
return False
|
||||||
|
original_recursion_limit = sys.getrecursionlimit()
|
||||||
|
for test_result in self:
|
||||||
|
other_test_result = other.get_by_id(test_result.id)
|
||||||
|
if other_test_result is None:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if original_recursion_limit < 5000:
|
||||||
|
sys.setrecursionlimit(5000)
|
||||||
|
if (
|
||||||
|
test_result.file_name != other_test_result.file_name
|
||||||
|
or test_result.did_pass != other_test_result.did_pass
|
||||||
|
or test_result.runtime != other_test_result.runtime
|
||||||
|
or test_result.test_framework != other_test_result.test_framework
|
||||||
|
or test_result.test_type != other_test_result.test_type
|
||||||
|
or not comparator(
|
||||||
|
test_result.return_value,
|
||||||
|
other_test_result.return_value,
|
||||||
|
)
|
||||||
|
):
|
||||||
|
sys.setrecursionlimit(original_recursion_limit)
|
||||||
|
return False
|
||||||
|
sys.setrecursionlimit(original_recursion_limit)
|
||||||
|
return True
|
||||||
|
|
||||||
|
def get_test_pass_fail_report_by_type(self) -> dict[TestType, dict[str, int]]:
|
||||||
|
report = {test_type: {"passed": 0, "failed": 0} for test_type in TestType}
|
||||||
|
for test_result in self.test_results:
|
||||||
|
if test_result.test_type != TestType.EXISTING_UNIT_TEST or test_result.id.function_getting_tested:
|
||||||
|
key = "passed" if test_result.did_pass else "failed"
|
||||||
|
report[test_result.test_type][key] += 1
|
||||||
|
return report"""
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_code_replacement_type_annotation():
|
def test_code_replacement_type_annotation():
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue