Fix the logic to get pre-existing objects from the code to optimize

Fixes the bug Ivan ran into.
This commit is contained in:
Saurabh Misra 2024-07-09 17:22:30 -07:00
parent 25e6bf86a7
commit be38b46e1e
3 changed files with 107 additions and 17 deletions

View file

@ -9,6 +9,8 @@ from libcst.codemod import CodemodContext
from libcst.codemod.visitors import AddImportsVisitor, GatherImportsVisitor, RemoveImportsVisitor
from libcst.helpers import calculate_module_and_package
from codeflash.discovery.functions_to_optimize import FunctionParent
if TYPE_CHECKING:
from libcst.helpers import ModuleNameAndPackage
@ -237,3 +239,22 @@ def extract_code(
)
return None, set()
return edited_code, contextual_dunder_methods
def find_preexisting_objects(source_code: str) -> "list[tuple[str, list[FunctionParent]]]":
    """Collect the functions, classes and class methods defined in ``source_code``.

    Only the module's top-level statements and the direct bodies of top-level
    classes are inspected; definitions nested any deeper are not reported.

    Args:
        source_code: Python source text to parse.

    Returns:
        ``(name, parents)`` pairs in source order. ``parents`` is empty for
        top-level functions and classes; for a method it holds a single
        ``FunctionParent`` built from the enclosing class name and ``"ClassDef"``.
        An empty list is returned (and the error is logged) when the source
        cannot be parsed.
    """
    preexisting_objects: list[tuple[str, list[FunctionParent]]] = []
    function_nodes = (ast.FunctionDef, ast.AsyncFunctionDef)
    try:
        module_node: ast.Module = ast.parse(source_code)
    except (SyntaxError, ValueError):
        # ValueError covers source containing null bytes, which some Python
        # versions report as ValueError rather than SyntaxError.
        logging.exception("find_preexisting_objects - Syntax error while parsing code")
        return preexisting_objects
    for node in module_node.body:
        if isinstance(node, function_nodes):
            preexisting_objects.append((node.name, []))
        elif isinstance(node, ast.ClassDef):
            # Record the class itself first, then each method defined directly in its body.
            preexisting_objects.append((node.name, []))
            preexisting_objects.extend(
                (member.name, [FunctionParent(node.name, "ClassDef")])
                for member in node.body
                if isinstance(member, function_nodes)
            )
    return preexisting_objects

View file

@ -20,7 +20,11 @@ from codeflash.api.aiservice import (
OptimizedCandidate,
)
from codeflash.code_utils import env_utils
from codeflash.code_utils.code_extractor import add_needed_imports_from_module, extract_code
from codeflash.code_utils.code_extractor import (
add_needed_imports_from_module,
extract_code,
find_preexisting_objects,
)
from codeflash.code_utils.code_replacer import replace_function_definitions_in_module
from codeflash.code_utils.code_utils import (
get_run_tmp_file,
@ -606,11 +610,6 @@ class Optimizer:
)
if code_to_optimize is None:
return Failure("Could not find function to optimize.")
preexisting_objects: list[tuple[str, list[FunctionParent]]] = [
(name, [FunctionParent(name=class_name, type="ClassDef")])
for class_name, name in contextual_dunder_methods
]
preexisting_objects.append((function_to_optimize.function_name, function_to_optimize.parents))
(
helper_code,
helper_functions,
@ -653,14 +652,7 @@ class Optimizer:
project_root,
helper_functions,
)
preexisting_objects.extend(
[
(qualified_name_list[-1], ([FunctionParent(name=qualified_name_list[-2], type="ClassDef")]))
if len(qualified_name_list := fn.qualified_name.split(".")) > 1
else (qualified_name_list[-1], [])
for fn in helper_functions
],
)
preexisting_objects = find_preexisting_objects(code_to_optimize_with_helpers)
contextual_dunder_methods.update(helper_dunder_methods)
return Success(
CodeOptimizationContext(

View file

@ -6,7 +6,6 @@ from argparse import Namespace
from collections import defaultdict
from pathlib import Path
import pytest
from codeflash.code_utils.code_replacer import replace_functions_and_add_imports, replace_functions_in_file
from codeflash.discovery.functions_to_optimize import FunctionParent, FunctionToOptimize
from codeflash.optimization.optimizer import Optimizer
@ -946,7 +945,6 @@ def test_test_libcst_code_replacement13() -> None:
assert new_code == original_code
@pytest.mark.skip()
def test_different_class_code_replacement():
original_code = """from __future__ import annotations
import sys
@ -1118,6 +1116,8 @@ class TestResults(BaseModel):
("__getitem__", [FunctionParent(name="TestResults", type="ClassDef")]),
("get_test_pass_fail_report_by_type", [FunctionParent(name="TestResults", type="ClassDef")]),
("TestType", []),
("TestResults", []),
("to_name", [FunctionParent(name="TestType", type="ClassDef")]),
]
contextual_functions = {
@ -1152,6 +1152,7 @@ class TestResults(BaseModel):
contextual_functions=contextual_functions,
project_root_path=str(Path(__file__).parent.resolve()),
)
helper_functions_by_module_abspath = defaultdict(set)
for helper_function in helper_functions:
if helper_function.jedi_definition.type != "class":
@ -1173,7 +1174,83 @@ class TestResults(BaseModel):
project_root_path=str(Path(__file__).parent.resolve()),
)
print("hi")
assert (
new_code
== """from __future__ import annotations
import sys
from codeflash.verification.comparator import comparator
from enum import Enum
from pydantic import BaseModel
from typing import Iterator
class TestType(Enum):
EXISTING_UNIT_TEST = 1
INSPIRED_REGRESSION = 2
GENERATED_REGRESSION = 3
REPLAY_TEST = 4
def to_name(self) -> str:
names = {
TestType.EXISTING_UNIT_TEST: "⚙️ Existing Unit Tests",
TestType.INSPIRED_REGRESSION: "🎨 Inspired Regression Tests",
TestType.GENERATED_REGRESSION: "🌀 Generated Regression Tests",
TestType.REPLAY_TEST: "⏪ Replay Tests",
}
return names[self]
class TestResults(BaseModel):
def __iter__(self) -> Iterator[FunctionTestInvocation]:
return iter(self.test_results)
def __len__(self) -> int:
return len(self.test_results)
def __getitem__(self, index: int) -> FunctionTestInvocation:
return self.test_results[index]
def __setitem__(self, index: int, value: FunctionTestInvocation) -> None:
self.test_results[index] = value
def __delitem__(self, index: int) -> None:
del self.test_results[index]
def __contains__(self, value: FunctionTestInvocation) -> bool:
return value in self.test_results
def __bool__(self) -> bool:
return bool(self.test_results)
def __eq__(self, other: object) -> bool:
# Unordered comparison
if type(self) != type(other):
return False
if len(self) != len(other):
return False
original_recursion_limit = sys.getrecursionlimit()
for test_result in self:
other_test_result = other.get_by_id(test_result.id)
if other_test_result is None:
return False
if original_recursion_limit < 5000:
sys.setrecursionlimit(5000)
if (
test_result.file_name != other_test_result.file_name
or test_result.did_pass != other_test_result.did_pass
or test_result.runtime != other_test_result.runtime
or test_result.test_framework != other_test_result.test_framework
or test_result.test_type != other_test_result.test_type
or not comparator(
test_result.return_value,
other_test_result.return_value,
)
):
sys.setrecursionlimit(original_recursion_limit)
return False
sys.setrecursionlimit(original_recursion_limit)
return True
def get_test_pass_fail_report_by_type(self) -> dict[TestType, dict[str, int]]:
report = {test_type: {"passed": 0, "failed": 0} for test_type in TestType}
for test_result in self.test_results:
if test_result.test_type != TestType.EXISTING_UNIT_TEST or test_result.id.function_getting_tested:
key = "passed" if test_result.did_pass else "failed"
report[test_result.test_type][key] += 1
return report"""
)
def test_code_replacement_type_annotation():