- ast-parse a few other important places

- add a test for equality check postprocessing
- remove some dead code
This commit is contained in:
Saurabh Misra 2024-06-22 17:40:47 -07:00
parent 09850564f7
commit 7ce3509020
5 changed files with 61 additions and 100 deletions

View file

@ -1,79 +1,13 @@
from __future__ import annotations
import ast
import os
from _ast import AsyncFunctionDef, ClassDef, FunctionDef
from typing import Dict, List
from _ast import AsyncFunctionDef, FunctionDef
import libcst as cst
from libcst import CSTNode
from pydantic.dataclasses import dataclass
# TODO these data models should prob be shared between aiservice and cli
class ReturnStatementVisitor(cst.CSTVisitor):
    """Records whether the visited CST subtree contains a `return` statement."""

    def __init__(self) -> None:
        super().__init__()
        # Flipped to True the first time any Return node is visited;
        # never reset, so one hit is enough for the whole subtree.
        self.has_return_statement: bool = False

    def visit_Return(self, node: cst.Return) -> None:
        self.has_return_statement = True
class FunctionVisitor(cst.CSTVisitor):
    """Collects every function definition in a file that contains a return statement.

    Requires position and parent-node metadata to be resolved on the tree
    before visiting (see METADATA_DEPENDENCIES).
    """

    METADATA_DEPENDENCIES = (cst.metadata.PositionProvider, cst.metadata.ParentNodeProvider)

    def __init__(self, file_path: str) -> None:
        super().__init__()
        self.file_path: str = file_path
        # Accumulated results, in visitation order.
        self.functions: list[FunctionToOptimize] = []

    def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
        # Only functions that actually return a value are candidates.
        return_checker = ReturnStatementVisitor()
        node.visit(return_checker)
        if not return_checker.has_return_statement:
            return
        position = self.get_metadata(cst.metadata.PositionProvider, node)
        # Climb the parent chain, collecting enclosing functions/classes
        # innermost-first, then reverse to get outermost-first order.
        enclosing: list[FunctionParent] = []
        ancestor: CSTNode | None = self.get_metadata(cst.metadata.ParentNodeProvider, node)
        while ancestor is not None:
            if isinstance(ancestor, (cst.FunctionDef, cst.ClassDef)):
                enclosing.append(
                    FunctionParent(ancestor.name.value, ancestor.__class__.__name__),
                )
            ancestor = self.get_metadata(cst.metadata.ParentNodeProvider, ancestor, default=None)
        enclosing.reverse()
        self.functions.append(
            FunctionToOptimize(
                function_name=node.name.value,
                file_path=self.file_path,
                parents=enclosing,
                starting_line=position.start.line,
                ending_line=position.end.line,
            ),
        )
class FunctionWithReturnStatement(ast.NodeVisitor):
    """AST visitor that records functions containing a return statement.

    Maintains `ast_path`, a stack of the enclosing function/class scopes,
    so each recorded function carries its chain of parents.
    """

    def __init__(self, file_path: str) -> None:
        self.functions: list[FunctionToOptimize] = []
        self.ast_path: list[FunctionParent] = []
        self.file_path: str = file_path

    def visit_FunctionDef(self, node: FunctionDef) -> None:
        # Record the function if it has a return statement, then keep
        # walking its body so nested definitions are discovered too.
        if function_has_return_statement(node):
            self.functions.append(FunctionToOptimize(node.name, self.file_path, self.ast_path[:]))
        self.generic_visit(node)

    def generic_visit(self, node: ast.AST) -> None:
        # Push/pop the scope stack around any def / async def / class node
        # so ast_path always reflects the current nesting while descending.
        is_scope = isinstance(node, (FunctionDef, AsyncFunctionDef, ClassDef))
        if is_scope:
            self.ast_path.append(FunctionParent(node.name, node.__class__.__name__))
        super().generic_visit(node)
        if is_scope:
            self.ast_path.pop()
@dataclass(frozen=True)
class FunctionParent:
name: str
@ -100,32 +34,3 @@ class FunctionToOptimize:
def __str__(self) -> str:
    """Render as `path:Qualified.Parents.function_name` (parents omitted when empty)."""
    qualifier = ".".join(parent.name for parent in self.parents)
    prefix = f"{qualifier}." if self.parents else ""
    return f"{self.file_path}:{prefix}{self.function_name}"
def get_all_files_and_functions(project_root_path: str) -> Dict[str, List[FunctionToOptimize]]:
    """Recursively discover optimizable functions under a project root.

    Walks the directory tree rooted at ``project_root_path``, parses every
    ``.py`` file, and returns a mapping of file path -> functions found in
    that file. Files that cannot be parsed are skipped by the per-file helper.
    """
    functions: Dict[str, List[FunctionToOptimize]] = {}
    # `_dirs` is unused: we rely on os.walk to recurse for us.
    for root, _dirs, files in os.walk(project_root_path):
        for file in files:
            if not file.endswith(".py"):
                continue
            file_path = os.path.join(root, file)
            # Find all the qualifying functions in this file.
            functions.update(find_all_functions_in_file(file_path))
    return functions
def find_all_functions_in_file(file_path: str) -> Dict[str, List[FunctionToOptimize]]:
    """Parse one Python file and collect its functions that have a return statement.

    Returns ``{file_path: [functions]}``, or an empty dict when the file
    cannot be decoded or parsed (best-effort discovery: bad files are skipped).
    """
    functions: Dict[str, List[FunctionToOptimize]] = {}
    # Python source defaults to UTF-8 (PEP 3120); don't rely on the locale
    # encoding, which differs per platform (e.g. cp1252 on Windows).
    with open(file_path, encoding="utf-8") as f:
        try:
            # A UnicodeDecodeError from f.read() or a SyntaxError from
            # ast.parse() both land here; either way we skip the file.
            ast_module = ast.parse(f.read())
        except Exception:
            return functions
    function_name_visitor = FunctionWithReturnStatement(file_path)
    function_name_visitor.visit(ast_module)
    functions[file_path] = function_name_visitor.functions
    return functions
def function_has_return_statement(function_node: FunctionDef | AsyncFunctionDef) -> bool:
    """Return True if any node in the function's subtree is a `return` statement.

    Note: ast.walk descends into nested definitions too, so a return inside
    an inner function also counts.
    """
    for descendant in ast.walk(function_node):
        if isinstance(descendant, ast.Return):
            return True
    return False

View file

@ -6,6 +6,7 @@ from optimizer.postprocess import (
cleanup_explanations,
dedup_and_sort_imports,
deduplicate_optimizations,
equality_check,
filter_ellipsis_containing_code,
fix_missing_docstring,
optimizations_postprocessing_pipeline,
@ -91,6 +92,57 @@ def test_postprocess_deduplicates():
assert actual[0].cst_module.deep_equals(expected[0].cst_module)
def test_equality_check():
    # An "optimization" that differs from the original code only in a
    # comment must be treated as functionally identical by equality_check
    # and filtered out of the results.
    original_code = '''from nuitka.utils.SlotMetaClasses import getMetaClassBase
class NuitkaPluginBase(getMetaClassBase("Plugin", require_slots=False)):
    def isRequiredImplicitImport(self, module, full_name):
        """Indicate whether an implicitly imported module should be accepted.
        Notes:
            You may negate importing a module specified as "implicit import",
            although this is an unexpected event.
        Args:
            module: the module object
            full_name: of the implicitly import module
        Returns:
            True or False
        """
        # Virtual method, pylint: disable=no-self-use,unused-argument
        return True
'''
    # Candidate differs from original_code only in the trailing comment
    # ("# Return True for all..." vs "# Virtual method..."), i.e. it is
    # behaviorally identical to the original.
    optimizations = [
        CodeExplanationAndID(
            libcst.parse_module(
                '''from nuitka.utils.SlotMetaClasses import getMetaClassBase
class NuitkaPluginBase(getMetaClassBase("Plugin", require_slots=False)):
    def isRequiredImplicitImport(self, module, full_name):
        """Indicate whether an implicitly imported module should be accepted.
        Notes:
            You may negate importing a module specified as "implicit import",
            although this is an unexpected event.
        Args:
            module: the module object
            full_name: of the implicitly import module
        Returns:
            True or False
        """
        # Return True for all implicitly imported modules
        return True
''',
            ),
            "Simplified print",
            "1",
        ),
    ]
    actual = equality_check(original_code, optimizations)
    # The equivalent optimization must be rejected, leaving nothing.
    assert len(actual) == 0
def test_postprocess_bubble_sort():
original_code = """
def sorter(arr):

View file

@ -3,9 +3,11 @@
# TODO: This is only here as a temporary reference implementation of how an early version of LLM-inspired tests was written.
# It didn't work very well. This should be improved significantly.
import ast # used for detecting whether generated Python code is valid
import platform
from typing import List, Tuple
import openai # used for calling the OpenAI API
from aiservice.models.aimodels import EXECUTE_MODEL, EXPLAIN_MODEL, LLM, PLAN_MODEL
from codeflash.code_utils.code_extractor import get_code
from codeflash.code_utils.code_utils import ellipsis_in_ast, get_imports_from_file
from codeflash.discovery.discover_unit_tests import TestsInFile
@ -14,8 +16,6 @@ from codeflash.verification.gen_regression_tests import (
print_messages,
)
from aiservice.models.aimodels import EXECUTE_MODEL, EXPLAIN_MODEL, LLM, PLAN_MODEL
def regression_tests_from_function_with_inspiration(
function_to_test: str, # Python function to test, as a string
@ -31,6 +31,7 @@ def regression_tests_from_function_with_inspiration(
execute_model: LLM = EXECUTE_MODEL, # model used to generate code in step 3
temperature: float = 0.4, # temperature = 0 can sometimes get stuck in repetitive loops, so we use 0.4
reruns_if_fail: int = 1, # if the output code cannot be parsed, this will re-run the function up to N times
python_version: Tuple[int, int, int] = platform.python_version_tuple(),
) -> str:
"""Returns a unit test for a given Python function, using a 3-step GPT prompt."""
# TODO: This step is exactly the same as the non-inspired test generator. Merge them into one to save on API calls
@ -184,7 +185,7 @@ import {unit_test_package} # used for our unit tests
tests_list = [imp for sublist in inspired_test_imports for imp in sublist]
code = ast.unparse(tests_list) + "\n" + code
try:
module = ast.parse(code)
module = ast.parse(code, feature_version=python_version[:2])
if ellipsis_in_ast(module):
# If the test generator is generating ellipsis, it is punting on generating
# the concrete test cases and we should re-generate

View file

@ -1,6 +1,7 @@
from __future__ import annotations
import ast
import platform
from aiservice.models.functions_to_optimize import FunctionToOptimize
@ -18,8 +19,9 @@ def instrument_test_source(
test_module_path: str,
test_framework: str,
test_timeout: int,
python_version: tuple[int, int, int] = platform.python_version_tuple(),
) -> str:
module_node = ast.parse(test_source)
module_node = ast.parse(test_source, feature_version=python_version[:2])
new_module_node = InjectPerfAndLogging(
function_to_optimize,
helper_function_names=helper_function_names,

View file

@ -420,6 +420,7 @@ async def testgen(
test_module_path=data.test_module_path,
test_framework=data.test_framework,
test_timeout=data.test_timeout,
python_version=python_version,
),
float_to_top=True,
)