Merge branch 'main' into mypy-gha

This commit is contained in:
RD 2024-11-13 15:35:48 -08:00 committed by GitHub
commit cbd1b5b97a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 1318 additions and 103 deletions

View file

@ -890,7 +890,7 @@ class Optimizer:
if result.test_type == TestType.GENERATED_REGRESSION:
generated_test_results.add(result)
total_timing = unittest_results.total_passed_runtime()
total_timing = unittest_results.total_passed_runtime() # caution: doesn't handle the loop index
functions_to_remove = [
result.id.test_function_name for result in generated_test_results.test_results if not result.did_pass

View file

@ -86,7 +86,6 @@ isort = ">=5.11.0"
dill = "^0.3.8"
rich = "^13.8.1"
lxml = "^5.3.0"
crosshair-tool = ">=0.0.77"
[tool.poetry.group.dev]

View file

@ -1,5 +1,5 @@
from pydantic.dataclasses import dataclass
import os
@dataclass
class LLM:
@ -10,7 +10,7 @@ class LLM:
# name of the model deployment on Azure OpenAI Service
@dataclass
class GPT_4_OMNI(LLM):
name: str = "gpt-4o-2"
name: str = "gpt-4o-2" if os.environ.get("OPENAI_API_TYPE") == "azure" else "gpt-4o"
max_tokens: int = 128000

View file

@ -14,6 +14,17 @@ class FunctionCallNodeArguments:
keywords: list[ast.keyword]
# Mapping of RNG modules to their seed-setting functions
RNG_MODULES_SEEDS = {
"random": "seed",
"numpy.random": "seed",
"numpy": "seed",
"torch": "manual_seed",
"tensorflow": "set_seed",
# Add more modules as needed
}
class FunctionImportedAsVisitor(ast.NodeVisitor):
"""This checks if a function has been imported as an alias. We only care about the alias then.
from numpy import array as np_array
@ -71,6 +82,11 @@ class InjectPerfAndLogging(ast.NodeTransformer):
self.individual_test_timeout = test_timeout
self.test_module_path = test_module_path
self.random_test = False
self.random_modules: dict[str, str] = {} # alias -> module_name
self.present_modules = set() # Tracks original module names or aliases from 'import' statements
self.present_aliases = set() # Tracks aliases from 'from ... import ...' statements
self.additional_imports = [] # Stores additional import nodes to be added
if len(function.parents) == 1 and function.parents[0].type == "ClassDef":
self.class_name = function.top_level_parent_name
self.helper_function_names = helper_function_names # Other functional dependencies that were injected
@ -78,14 +94,43 @@ class InjectPerfAndLogging(ast.NodeTransformer):
def visit_ImportFrom(self, node: ast.ImportFrom):
if any([name.name in [self.only_function_name, self.class_name] for name in node.names]):
return None # Remove the import of the function the test generation code
if node.module == "random":
self.random_test = True
return [node, ast.Import(names=[ast.alias(name="random")])]
module_name = node.module # e.g., 'numpy.random'
if module_name is None:
# Handle relative imports or cases where module_name is None
return node
for alias in node.names: # e.g., 'rand'
imported_name = alias.name # e.g., 'random', 'randint'
module_alias = alias.asname or imported_name
# Track the alias from 'from ... import ...' statements
self.present_aliases.add(module_alias)
if module_name in RNG_MODULES_SEEDS:
self.random_test = True
self.random_modules[module_alias] = module_name
if module_name not in self.present_modules:
# Add an import statement for the module if it's not already imported
self.present_modules.add(module_name)
import_node = ast.Import(names=[ast.alias(name=module_name)])
self.additional_imports.append(import_node)
return node
def visit_Import(self, node: ast.Import) -> ast.Import:
if any([alias.name == "random" for alias in node.names]):
self.random_test = True
for alias in node.names:
module_name = alias.name
module_alias = alias.asname or module_name
self.present_modules.add(module_alias)
if module_name in RNG_MODULES_SEEDS:
self.random_test = True
self.random_modules[module_alias] = module_name
return node
def visit_Module(self, node: ast.Module) -> ast.Module:
# Reset additional_imports at the start of each module visit
self.additional_imports = []
self.generic_visit(node)
# Insert additional imports at the beginning of the module to avoid conflicts
node.body = self.additional_imports + node.body
return node
def visit_ClassDef(self, node: ast.ClassDef) -> Optional[ast.ClassDef]:
@ -220,6 +265,77 @@ class InjectPerfAndLogging(ast.NodeTransformer):
),
]
def create_seed_statement(self, alias: str, full_module_name: str) -> ast.stmt | None:
seed_method = RNG_MODULES_SEEDS.get(full_module_name)
if seed_method:
if alias in self.present_modules:
module_ref = alias
else:
module_ref = full_module_name
# Build the seed call using the module name
module_parts = module_ref.split(".")
module = ast.Name(id=module_parts[0], ctx=ast.Load())
for part in module_parts[1:]:
module = ast.Attribute(value=module, attr=part, ctx=ast.Load())
# Split the seed method if it's nested (e.g., 'random.seed')
seed_parts = seed_method.split(".")
# Build the attribute chain
seed_func = module
for part in seed_parts:
seed_func = ast.Attribute(value=seed_func, attr=part, ctx=ast.Load())
# Create the call to set the seed
seed_call = ast.Expr(value=ast.Call(func=seed_func, args=[ast.Constant(value=42)], keywords=[]))
return seed_call
return None
def find_target_function_call(self, node: ast.AST) -> ast.Call | None:
for child in ast.walk(node):
if isinstance(child, ast.Call) and self.is_target_function_node(child):
return child
return None
def process_statements(
self,
statements: list[ast.stmt],
node_name: str,
class_name: str | None,
offset: int = 0,
parent_index: Optional[str] = None,
) -> None:
i = 0
while i < len(statements):
stmt = statements[i]
if parent_index is None:
current_index = str(i + offset)
else:
current_index = parent_index + "_" + str(i)
if isinstance(stmt, (ast.With, ast.For, ast.While, ast.If, ast.Try, ast.AsyncWith, ast.AsyncFor)):
# Recursively process the body of the control flow statement
self.process_statements(stmt.body, node_name, class_name, offset=0, parent_index=current_index)
# Handle else, finally blocks, and exception handlers similarly
if hasattr(stmt, "orelse"): # Some statements (like if, for, while) can have an else block
self.process_statements(stmt.orelse, node_name, class_name, offset=0, parent_index=current_index)
if hasattr(stmt, "finalbody"): # try statements can have a finally block
self.process_statements(stmt.finalbody, node_name, class_name, offset=0, parent_index=current_index)
if isinstance(stmt, ast.Try) and stmt.handlers:
for idx, handler in enumerate(stmt.handlers):
self.process_statements(
handler.body, node_name, class_name, offset=0, parent_index=current_index
)
else:
# Check if the statement contains a call to the target function
test_node = self.find_target_function_call(stmt)
if test_node is not None:
# Replace the statement with the updated nodes
updated_nodes = self.update_line_node(
test_node, get_call_arguments(test_node), node_name, current_index, class_name
)
statements[i : i + 1] = updated_nodes
i += len(updated_nodes) - 1 # Adjust index for inserted nodes
i += 1
def visit_FunctionDef(self, node: ast.FunctionDef, class_name: Optional[str] = None) -> Optional[ast.FunctionDef]:
random_index = 0
random_test_function = False
@ -227,86 +343,65 @@ class InjectPerfAndLogging(ast.NodeTransformer):
if node.name == self.only_function_name or node.name in self.helper_function_names:
return None # Remove the re-definition of the function and its dependencies from the test generation code
if node.name.startswith("test_"):
node.body.insert(
0,
ast.Assign(
targets=[ast.Name(id="codeflash_loop_index", ctx=ast.Store())],
value=ast.Call(
func=ast.Name(id="int", ctx=ast.Load()),
args=[
ast.Subscript(
value=ast.Attribute(
value=ast.Name(id="os", ctx=ast.Load()), attr="environ", ctx=ast.Load()
),
slice=ast.Constant(value="CODEFLASH_LOOP_INDEX"),
ctx=ast.Load(),
)
],
keywords=[],
),
lineno=node.lineno + 1,
col_offset=node.col_offset,
),
)
self.present_modules = set() # Reset tracked modules
self.present_aliases = set() # Reset tracked aliases
new_body = []
# Insert seed-setting statements
for alias, full_module_name in self.random_modules.items():
seed_statement = self.create_seed_statement(alias, full_module_name)
if seed_statement:
new_body.append(seed_statement)
i: int = len(node.body) - 1
while i >= 0:
line_node = node.body[i]
did_delete = False
# TODO: Validate if the functional call actually did not raise any exceptions
if isinstance(line_node, ast.Import):
for name in line_node.names:
if name.name == "random":
random_test_function = True
random_index = i + 1
i -= 1
continue
if isinstance(line_node, ast.ImportFrom):
if line_node.module == "random":
random_test_function = True
random_index = i + 1
random_importfrom = True
i -= 1
continue
if isinstance(line_node, (ast.With, ast.For, ast.While, ast.If)):
j = len(line_node.body) - 1
while j >= 0:
with_line_node: ast.stmt = line_node.body[j]
with_node: Optional[ast.Call] = next(
(node for node in ast.walk(with_line_node) if self.is_target_function_node(node)), None
)
if with_node is not None:
line_node.body[j : j + 1] = self.update_line_node(
with_node, get_call_arguments(with_node), node.name, str(i) + "_" + str(j), class_name
)
did_delete = True
j -= 1
else:
test_node: Optional[ast.Call] = next(
(node for node in ast.walk(line_node) if self.is_target_function_node(node)), None
)
if test_node is not None:
node.body[i : i + 1] = self.update_line_node(
test_node, get_call_arguments(test_node), node.name, str(i), class_name
)
did_delete = True
i -= 1
if self.random_test or random_test_function:
node.body.insert(
random_index,
ast.Expr(
value=ast.Call(
func=ast.Attribute(
value=ast.Name(id="random", ctx=ast.Load()), attr="seed", ctx=ast.Load()
loop_index_assign = ast.Assign(
targets=[ast.Name(id="codeflash_loop_index", ctx=ast.Store())],
value=ast.Call(
func=ast.Name(id="int", ctx=ast.Load()),
args=[
ast.Subscript(
value=ast.Attribute(
value=ast.Name(id="os", ctx=ast.Load()), attr="environ", ctx=ast.Load()
),
args=[ast.Constant(value=42)],
keywords=[],
slice=ast.Constant(value="CODEFLASH_LOOP_INDEX"),
ctx=ast.Load(),
)
),
)
if random_importfrom:
node.body.insert(random_index, ast.Import(names=[ast.alias(name="random")]))
],
keywords=[],
),
lineno=node.lineno + 1,
col_offset=node.col_offset,
)
new_body.append(loop_index_assign)
offset = len(new_body)
# Add the original function body
# new_body.extend(node.body)
# Process each statement in the function body
for stmt in node.body:
# Check if the statement is an import of a random module
if isinstance(stmt, (ast.Import, ast.ImportFrom)):
self.visit(stmt) # To populate self.random_modules
new_body.append(stmt)
# Insert seed-setting statement after import
alias_list = list(self.random_modules.keys())
if alias_list:
alias = alias_list[-1] # Get the last alias added
full_module_name = self.random_modules[alias]
seed_statement = self.create_seed_statement(alias, full_module_name)
if seed_statement:
new_body.append(seed_statement)
# Remove the alias from random_modules to prevent duplicate seeds
del self.random_modules[alias]
continue
# For other statements, process them recursively
self.generic_visit(stmt)
stmt_list = [stmt]
self.process_statements(stmt_list, node.name, class_name, offset=offset)
new_body.extend(stmt_list)
offset += len(stmt_list) # Update offset for each new statement added
# Update the function body
node.body = new_body
if self.test_framework == "unittest":
# TODO: Make sure that if the test times out,
# the test's time is excluded from the total time calculation and comparison