mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
Fix the logic to get pre-existing objects from the code to optimize
Fixes the bug ivan ran into
This commit is contained in:
parent
25e6bf86a7
commit
be38b46e1e
3 changed files with 107 additions and 17 deletions
|
|
@ -9,6 +9,8 @@ from libcst.codemod import CodemodContext
|
|||
from libcst.codemod.visitors import AddImportsVisitor, GatherImportsVisitor, RemoveImportsVisitor
|
||||
from libcst.helpers import calculate_module_and_package
|
||||
|
||||
from codeflash.discovery.functions_to_optimize import FunctionParent
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from libcst.helpers import ModuleNameAndPackage
|
||||
|
||||
|
|
@ -237,3 +239,22 @@ def extract_code(
|
|||
)
|
||||
return None, set()
|
||||
return edited_code, contextual_dunder_methods
|
||||
|
||||
|
||||
def find_preexisting_objects(source_code: str):
|
||||
"""Find all preexisting functions, classes or class methods in the source code"""
|
||||
preexisting_objects: list[tuple[str, list[FunctionParent]]] = []
|
||||
try:
|
||||
module_node: ast.Module = ast.parse(source_code)
|
||||
except SyntaxError:
|
||||
logging.exception("find_preexisting_objects - Syntax error while parsing code")
|
||||
return preexisting_objects
|
||||
for node in module_node.body:
|
||||
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
|
||||
preexisting_objects.append((node.name, []))
|
||||
elif isinstance(node, ast.ClassDef):
|
||||
preexisting_objects.append((node.name, []))
|
||||
for cnode in node.body:
|
||||
if isinstance(cnode, (ast.FunctionDef, ast.AsyncFunctionDef)):
|
||||
preexisting_objects.append((cnode.name, [FunctionParent(node.name, "ClassDef")]))
|
||||
return preexisting_objects
|
||||
|
|
|
|||
|
|
@ -20,7 +20,11 @@ from codeflash.api.aiservice import (
|
|||
OptimizedCandidate,
|
||||
)
|
||||
from codeflash.code_utils import env_utils
|
||||
from codeflash.code_utils.code_extractor import add_needed_imports_from_module, extract_code
|
||||
from codeflash.code_utils.code_extractor import (
|
||||
add_needed_imports_from_module,
|
||||
extract_code,
|
||||
find_preexisting_objects,
|
||||
)
|
||||
from codeflash.code_utils.code_replacer import replace_function_definitions_in_module
|
||||
from codeflash.code_utils.code_utils import (
|
||||
get_run_tmp_file,
|
||||
|
|
@ -606,11 +610,6 @@ class Optimizer:
|
|||
)
|
||||
if code_to_optimize is None:
|
||||
return Failure("Could not find function to optimize.")
|
||||
preexisting_objects: list[tuple[str, list[FunctionParent]]] = [
|
||||
(name, [FunctionParent(name=class_name, type="ClassDef")])
|
||||
for class_name, name in contextual_dunder_methods
|
||||
]
|
||||
preexisting_objects.append((function_to_optimize.function_name, function_to_optimize.parents))
|
||||
(
|
||||
helper_code,
|
||||
helper_functions,
|
||||
|
|
@ -653,14 +652,7 @@ class Optimizer:
|
|||
project_root,
|
||||
helper_functions,
|
||||
)
|
||||
preexisting_objects.extend(
|
||||
[
|
||||
(qualified_name_list[-1], ([FunctionParent(name=qualified_name_list[-2], type="ClassDef")]))
|
||||
if len(qualified_name_list := fn.qualified_name.split(".")) > 1
|
||||
else (qualified_name_list[-1], [])
|
||||
for fn in helper_functions
|
||||
],
|
||||
)
|
||||
preexisting_objects = find_preexisting_objects(code_to_optimize_with_helpers)
|
||||
contextual_dunder_methods.update(helper_dunder_methods)
|
||||
return Success(
|
||||
CodeOptimizationContext(
|
||||
|
|
|
|||
|
|
@ -6,7 +6,6 @@ from argparse import Namespace
|
|||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from codeflash.code_utils.code_replacer import replace_functions_and_add_imports, replace_functions_in_file
|
||||
from codeflash.discovery.functions_to_optimize import FunctionParent, FunctionToOptimize
|
||||
from codeflash.optimization.optimizer import Optimizer
|
||||
|
|
@ -946,7 +945,6 @@ def test_test_libcst_code_replacement13() -> None:
|
|||
assert new_code == original_code
|
||||
|
||||
|
||||
@pytest.mark.skip()
|
||||
def test_different_class_code_replacement():
|
||||
original_code = """from __future__ import annotations
|
||||
import sys
|
||||
|
|
@ -1118,6 +1116,8 @@ class TestResults(BaseModel):
|
|||
("__getitem__", [FunctionParent(name="TestResults", type="ClassDef")]),
|
||||
("get_test_pass_fail_report_by_type", [FunctionParent(name="TestResults", type="ClassDef")]),
|
||||
("TestType", []),
|
||||
("TestResults", []),
|
||||
("to_name", [FunctionParent(name="TestType", type="ClassDef")]),
|
||||
]
|
||||
|
||||
contextual_functions = {
|
||||
|
|
@ -1152,6 +1152,7 @@ class TestResults(BaseModel):
|
|||
contextual_functions=contextual_functions,
|
||||
project_root_path=str(Path(__file__).parent.resolve()),
|
||||
)
|
||||
|
||||
helper_functions_by_module_abspath = defaultdict(set)
|
||||
for helper_function in helper_functions:
|
||||
if helper_function.jedi_definition.type != "class":
|
||||
|
|
@ -1173,7 +1174,83 @@ class TestResults(BaseModel):
|
|||
project_root_path=str(Path(__file__).parent.resolve()),
|
||||
)
|
||||
|
||||
print("hi")
|
||||
assert (
|
||||
new_code
|
||||
== """from __future__ import annotations
|
||||
import sys
|
||||
from codeflash.verification.comparator import comparator
|
||||
from enum import Enum
|
||||
from pydantic import BaseModel
|
||||
from typing import Iterator
|
||||
|
||||
class TestType(Enum):
|
||||
EXISTING_UNIT_TEST = 1
|
||||
INSPIRED_REGRESSION = 2
|
||||
GENERATED_REGRESSION = 3
|
||||
REPLAY_TEST = 4
|
||||
|
||||
def to_name(self) -> str:
|
||||
names = {
|
||||
TestType.EXISTING_UNIT_TEST: "⚙️ Existing Unit Tests",
|
||||
TestType.INSPIRED_REGRESSION: "🎨 Inspired Regression Tests",
|
||||
TestType.GENERATED_REGRESSION: "🌀 Generated Regression Tests",
|
||||
TestType.REPLAY_TEST: "⏪ Replay Tests",
|
||||
}
|
||||
return names[self]
|
||||
|
||||
class TestResults(BaseModel):
|
||||
def __iter__(self) -> Iterator[FunctionTestInvocation]:
|
||||
return iter(self.test_results)
|
||||
def __len__(self) -> int:
|
||||
return len(self.test_results)
|
||||
def __getitem__(self, index: int) -> FunctionTestInvocation:
|
||||
return self.test_results[index]
|
||||
def __setitem__(self, index: int, value: FunctionTestInvocation) -> None:
|
||||
self.test_results[index] = value
|
||||
def __delitem__(self, index: int) -> None:
|
||||
del self.test_results[index]
|
||||
def __contains__(self, value: FunctionTestInvocation) -> bool:
|
||||
return value in self.test_results
|
||||
def __bool__(self) -> bool:
|
||||
return bool(self.test_results)
|
||||
def __eq__(self, other: object) -> bool:
|
||||
# Unordered comparison
|
||||
if type(self) != type(other):
|
||||
return False
|
||||
if len(self) != len(other):
|
||||
return False
|
||||
original_recursion_limit = sys.getrecursionlimit()
|
||||
for test_result in self:
|
||||
other_test_result = other.get_by_id(test_result.id)
|
||||
if other_test_result is None:
|
||||
return False
|
||||
|
||||
if original_recursion_limit < 5000:
|
||||
sys.setrecursionlimit(5000)
|
||||
if (
|
||||
test_result.file_name != other_test_result.file_name
|
||||
or test_result.did_pass != other_test_result.did_pass
|
||||
or test_result.runtime != other_test_result.runtime
|
||||
or test_result.test_framework != other_test_result.test_framework
|
||||
or test_result.test_type != other_test_result.test_type
|
||||
or not comparator(
|
||||
test_result.return_value,
|
||||
other_test_result.return_value,
|
||||
)
|
||||
):
|
||||
sys.setrecursionlimit(original_recursion_limit)
|
||||
return False
|
||||
sys.setrecursionlimit(original_recursion_limit)
|
||||
return True
|
||||
|
||||
def get_test_pass_fail_report_by_type(self) -> dict[TestType, dict[str, int]]:
|
||||
report = {test_type: {"passed": 0, "failed": 0} for test_type in TestType}
|
||||
for test_result in self.test_results:
|
||||
if test_result.test_type != TestType.EXISTING_UNIT_TEST or test_result.id.function_getting_tested:
|
||||
key = "passed" if test_result.did_pass else "failed"
|
||||
report[test_result.test_type][key] += 1
|
||||
return report"""
|
||||
)
|
||||
|
||||
|
||||
def test_code_replacement_type_annotation():
|
||||
|
|
|
|||
Loading…
Reference in a new issue