- ast parse few other important places
- add a test for equality check postprocessing -remove some dead code
This commit is contained in:
parent
09850564f7
commit
7ce3509020
5 changed files with 61 additions and 100 deletions
|
|
@ -1,79 +1,13 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
import os
|
||||
from _ast import AsyncFunctionDef, ClassDef, FunctionDef
|
||||
from typing import Dict, List
|
||||
from _ast import AsyncFunctionDef, FunctionDef
|
||||
|
||||
import libcst as cst
|
||||
from libcst import CSTNode
|
||||
from pydantic.dataclasses import dataclass
|
||||
|
||||
# TODO these data models should prob be shared between aiservice and cli
|
||||
|
||||
|
||||
class ReturnStatementVisitor(cst.CSTVisitor):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.has_return_statement = False
|
||||
|
||||
def visit_Return(self, node: cst.Return) -> None:
|
||||
self.has_return_statement = True
|
||||
|
||||
|
||||
class FunctionVisitor(cst.CSTVisitor):
|
||||
METADATA_DEPENDENCIES = (cst.metadata.PositionProvider, cst.metadata.ParentNodeProvider)
|
||||
|
||||
def __init__(self, file_path: str) -> None:
|
||||
super().__init__()
|
||||
self.file_path: str = file_path
|
||||
self.functions: list[FunctionToOptimize] = []
|
||||
|
||||
def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
|
||||
return_visitor = ReturnStatementVisitor()
|
||||
node.visit(return_visitor)
|
||||
if return_visitor.has_return_statement:
|
||||
pos = self.get_metadata(cst.metadata.PositionProvider, node)
|
||||
parents: CSTNode | None = self.get_metadata(cst.metadata.ParentNodeProvider, node)
|
||||
ast_parents = []
|
||||
while parents is not None:
|
||||
if isinstance(parents, (cst.FunctionDef, cst.ClassDef)):
|
||||
ast_parents.append(
|
||||
FunctionParent(parents.name.value, parents.__class__.__name__),
|
||||
)
|
||||
parents = self.get_metadata(cst.metadata.ParentNodeProvider, parents, default=None)
|
||||
self.functions.append(
|
||||
FunctionToOptimize(
|
||||
function_name=node.name.value,
|
||||
file_path=self.file_path,
|
||||
parents=list(reversed(ast_parents)),
|
||||
starting_line=pos.start.line,
|
||||
ending_line=pos.end.line,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class FunctionWithReturnStatement(ast.NodeVisitor):
|
||||
def __init__(self, file_path: str) -> None:
|
||||
self.functions: list[FunctionToOptimize] = []
|
||||
self.ast_path: list[FunctionParent] = []
|
||||
self.file_path: str = file_path
|
||||
|
||||
def visit_FunctionDef(self, node: FunctionDef) -> None:
|
||||
# Check if the function has a return statement and add it to the list
|
||||
if function_has_return_statement(node):
|
||||
self.functions.append(FunctionToOptimize(node.name, self.file_path, self.ast_path[:]))
|
||||
# Continue visiting the body of the function to find nested functions
|
||||
self.generic_visit(node)
|
||||
|
||||
def generic_visit(self, node: ast.AST) -> None:
|
||||
if isinstance(node, (FunctionDef, AsyncFunctionDef, ClassDef)):
|
||||
self.ast_path.append(FunctionParent(node.name, node.__class__.__name__))
|
||||
super().generic_visit(node)
|
||||
if isinstance(node, (FunctionDef, AsyncFunctionDef, ClassDef)):
|
||||
self.ast_path.pop()
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class FunctionParent:
|
||||
name: str
|
||||
|
|
@ -100,32 +34,3 @@ class FunctionToOptimize:
|
|||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.file_path}:{'.'.join([p.name for p in self.parents]) + '.' if self.parents else ''}{self.function_name}"
|
||||
|
||||
|
||||
def get_all_files_and_functions(project_root_path: str) -> Dict[str, List[FunctionToOptimize]]:
|
||||
functions = {}
|
||||
for root, dirs, files in os.walk(project_root_path):
|
||||
for file in files:
|
||||
if not file.endswith(".py"):
|
||||
continue
|
||||
file_path = os.path.join(root, file)
|
||||
# Find all the functions in the file
|
||||
functions.update(find_all_functions_in_file(file_path))
|
||||
return functions
|
||||
|
||||
|
||||
def find_all_functions_in_file(file_path: str) -> Dict[str, List[FunctionToOptimize]]:
|
||||
functions: Dict[str, List[FunctionToOptimize]] = {}
|
||||
with open(file_path) as f:
|
||||
try:
|
||||
ast_module = ast.parse(f.read())
|
||||
except Exception:
|
||||
return functions
|
||||
function_name_visitor = FunctionWithReturnStatement(file_path)
|
||||
function_name_visitor.visit(ast_module)
|
||||
functions[file_path] = function_name_visitor.functions
|
||||
return functions
|
||||
|
||||
|
||||
def function_has_return_statement(function_node: FunctionDef | AsyncFunctionDef) -> bool:
|
||||
return any(isinstance(node, ast.Return) for node in ast.walk(function_node))
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ from optimizer.postprocess import (
|
|||
cleanup_explanations,
|
||||
dedup_and_sort_imports,
|
||||
deduplicate_optimizations,
|
||||
equality_check,
|
||||
filter_ellipsis_containing_code,
|
||||
fix_missing_docstring,
|
||||
optimizations_postprocessing_pipeline,
|
||||
|
|
@ -91,6 +92,57 @@ def test_postprocess_deduplicates():
|
|||
assert actual[0].cst_module.deep_equals(expected[0].cst_module)
|
||||
|
||||
|
||||
def test_equality_check():
|
||||
original_code = '''from nuitka.utils.SlotMetaClasses import getMetaClassBase
|
||||
|
||||
class NuitkaPluginBase(getMetaClassBase("Plugin", require_slots=False)):
|
||||
def isRequiredImplicitImport(self, module, full_name):
|
||||
"""Indicate whether an implicitly imported module should be accepted.
|
||||
|
||||
Notes:
|
||||
You may negate importing a module specified as "implicit import",
|
||||
although this is an unexpected event.
|
||||
|
||||
Args:
|
||||
module: the module object
|
||||
full_name: of the implicitly import module
|
||||
Returns:
|
||||
True or False
|
||||
"""
|
||||
# Virtual method, pylint: disable=no-self-use,unused-argument
|
||||
return True
|
||||
'''
|
||||
optimizations = [
|
||||
CodeExplanationAndID(
|
||||
libcst.parse_module(
|
||||
'''from nuitka.utils.SlotMetaClasses import getMetaClassBase
|
||||
|
||||
class NuitkaPluginBase(getMetaClassBase("Plugin", require_slots=False)):
|
||||
def isRequiredImplicitImport(self, module, full_name):
|
||||
"""Indicate whether an implicitly imported module should be accepted.
|
||||
|
||||
Notes:
|
||||
You may negate importing a module specified as "implicit import",
|
||||
although this is an unexpected event.
|
||||
|
||||
Args:
|
||||
module: the module object
|
||||
full_name: of the implicitly import module
|
||||
Returns:
|
||||
True or False
|
||||
"""
|
||||
# Return True for all implicitly imported modules
|
||||
return True
|
||||
''',
|
||||
),
|
||||
"Simplified print",
|
||||
"1",
|
||||
),
|
||||
]
|
||||
actual = equality_check(original_code, optimizations)
|
||||
assert len(actual) == 0
|
||||
|
||||
|
||||
def test_postprocess_bubble_sort():
|
||||
original_code = """
|
||||
def sorter(arr):
|
||||
|
|
|
|||
|
|
@ -3,9 +3,11 @@
|
|||
# TODO: This is only here as a temporary reference implementaion of how an early version of LLM inspired tests was written.
|
||||
# It didn't work very well. This should be improved significantly.
|
||||
import ast # used for detecting whether generated Python code is valid
|
||||
import platform
|
||||
from typing import List, Tuple
|
||||
|
||||
import openai # used for calling the OpenAI API
|
||||
from aiservice.models.aimodels import EXECUTE_MODEL, EXPLAIN_MODEL, LLM, PLAN_MODEL
|
||||
from codeflash.code_utils.code_extractor import get_code
|
||||
from codeflash.code_utils.code_utils import ellipsis_in_ast, get_imports_from_file
|
||||
from codeflash.discovery.discover_unit_tests import TestsInFile
|
||||
|
|
@ -14,8 +16,6 @@ from codeflash.verification.gen_regression_tests import (
|
|||
print_messages,
|
||||
)
|
||||
|
||||
from aiservice.models.aimodels import EXECUTE_MODEL, EXPLAIN_MODEL, LLM, PLAN_MODEL
|
||||
|
||||
|
||||
def regression_tests_from_function_with_inspiration(
|
||||
function_to_test: str, # Python function to test, as a string
|
||||
|
|
@ -31,6 +31,7 @@ def regression_tests_from_function_with_inspiration(
|
|||
execute_model: LLM = EXECUTE_MODEL, # model used to generate code in step 3
|
||||
temperature: float = 0.4, # temperature = 0 can sometimes get stuck in repetitive loops, so we use 0.4
|
||||
reruns_if_fail: int = 1, # if the output code cannot be parsed, this will re-run the function up to N times
|
||||
python_version: Tuple[int, int, int] = platform.python_version_tuple(),
|
||||
) -> str:
|
||||
"""Returns a unit test for a given Python function, using a 3-step GPT prompt."""
|
||||
# TODO: This step is exactly the same as the non-inspired test generator. Merge them into one to save on API calls
|
||||
|
|
@ -184,7 +185,7 @@ import {unit_test_package} # used for our unit tests
|
|||
tests_list = [imp for sublist in inspired_test_imports for imp in sublist]
|
||||
code = ast.unparse(tests_list) + "\n" + code
|
||||
try:
|
||||
module = ast.parse(code)
|
||||
module = ast.parse(code, feature_version=python_version[:2])
|
||||
if ellipsis_in_ast(module):
|
||||
# If the test generator is generating ellipsis, it is punting on generating
|
||||
# the concrete test cases and we should re-generate
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
import platform
|
||||
|
||||
from aiservice.models.functions_to_optimize import FunctionToOptimize
|
||||
|
||||
|
|
@ -18,8 +19,9 @@ def instrument_test_source(
|
|||
test_module_path: str,
|
||||
test_framework: str,
|
||||
test_timeout: int,
|
||||
python_version: tuple[int, int, int] = platform.python_version_tuple(),
|
||||
) -> str:
|
||||
module_node = ast.parse(test_source)
|
||||
module_node = ast.parse(test_source, feature_version=python_version[:2])
|
||||
new_module_node = InjectPerfAndLogging(
|
||||
function_to_optimize,
|
||||
helper_function_names=helper_function_names,
|
||||
|
|
|
|||
|
|
@ -420,6 +420,7 @@ async def testgen(
|
|||
test_module_path=data.test_module_path,
|
||||
test_framework=data.test_framework,
|
||||
test_timeout=data.test_timeout,
|
||||
python_version=python_version,
|
||||
),
|
||||
float_to_top=True,
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in a new issue