Merge branch 'main' into mypy-gha
This commit is contained in:
commit
cbd1b5b97a
5 changed files with 1318 additions and 103 deletions
|
|
@ -890,7 +890,7 @@ class Optimizer:
|
|||
if result.test_type == TestType.GENERATED_REGRESSION:
|
||||
generated_test_results.add(result)
|
||||
|
||||
total_timing = unittest_results.total_passed_runtime()
|
||||
total_timing = unittest_results.total_passed_runtime() # caution: doesn't handle the loop index
|
||||
|
||||
functions_to_remove = [
|
||||
result.id.test_function_name for result in generated_test_results.test_results if not result.did_pass
|
||||
|
|
|
|||
|
|
@ -86,7 +86,6 @@ isort = ">=5.11.0"
|
|||
dill = "^0.3.8"
|
||||
rich = "^13.8.1"
|
||||
lxml = "^5.3.0"
|
||||
crosshair-tool = ">=0.0.77"
|
||||
|
||||
|
||||
[tool.poetry.group.dev]
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from pydantic.dataclasses import dataclass
|
||||
|
||||
import os
|
||||
|
||||
@dataclass
|
||||
class LLM:
|
||||
|
|
@ -10,7 +10,7 @@ class LLM:
|
|||
# name of the model deployment on Azure OpenAI Service
|
||||
@dataclass
|
||||
class GPT_4_OMNI(LLM):
|
||||
name: str = "gpt-4o-2"
|
||||
name: str = "gpt-4o-2" if os.environ.get("OPENAI_API_TYPE") == "azure" else "gpt-4o"
|
||||
max_tokens: int = 128000
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -14,6 +14,17 @@ class FunctionCallNodeArguments:
|
|||
keywords: list[ast.keyword]
|
||||
|
||||
|
||||
# Mapping of RNG modules to the attribute path of their seed-setting function,
# relative to the module itself. A dotted value (e.g. "random.seed") is split
# on "." and resolved attribute by attribute when the seed call is built.
RNG_MODULES_SEEDS = {
    "random": "seed",  # random.seed(...)
    "numpy.random": "seed",  # numpy.random.seed(...)
    "numpy": "random.seed",  # numpy has no top-level seed(); it lives in numpy.random
    "torch": "manual_seed",  # torch.manual_seed(...)
    "tensorflow": "random.set_seed",  # the seed setter is tensorflow.random.set_seed
    # Add more modules as needed
}
|
||||
|
||||
|
||||
class FunctionImportedAsVisitor(ast.NodeVisitor):
|
||||
"""This checks if a function has been imported as an alias. We only care about the alias then.
|
||||
from numpy import array as np_array
|
||||
|
|
@ -71,6 +82,11 @@ class InjectPerfAndLogging(ast.NodeTransformer):
|
|||
self.individual_test_timeout = test_timeout
|
||||
self.test_module_path = test_module_path
|
||||
self.random_test = False
|
||||
self.random_modules: dict[str, str] = {} # alias -> module_name
|
||||
self.present_modules = set() # Tracks original module names or aliases from 'import' statements
|
||||
self.present_aliases = set() # Tracks aliases from 'from ... import ...' statements
|
||||
self.additional_imports = [] # Stores additional import nodes to be added
|
||||
|
||||
if len(function.parents) == 1 and function.parents[0].type == "ClassDef":
|
||||
self.class_name = function.top_level_parent_name
|
||||
self.helper_function_names = helper_function_names # Other functional dependencies that were injected
|
||||
|
|
@ -78,14 +94,43 @@ class InjectPerfAndLogging(ast.NodeTransformer):
|
|||
def visit_ImportFrom(self, node: ast.ImportFrom):
|
||||
if any([name.name in [self.only_function_name, self.class_name] for name in node.names]):
|
||||
return None # Remove the import of the function the test generation code
|
||||
if node.module == "random":
|
||||
self.random_test = True
|
||||
return [node, ast.Import(names=[ast.alias(name="random")])]
|
||||
module_name = node.module # e.g., 'numpy.random'
|
||||
if module_name is None:
|
||||
# Handle relative imports or cases where module_name is None
|
||||
return node
|
||||
|
||||
for alias in node.names: # e.g., 'rand'
|
||||
imported_name = alias.name # e.g., 'random', 'randint'
|
||||
module_alias = alias.asname or imported_name
|
||||
# Track the alias from 'from ... import ...' statements
|
||||
self.present_aliases.add(module_alias)
|
||||
if module_name in RNG_MODULES_SEEDS:
|
||||
self.random_test = True
|
||||
self.random_modules[module_alias] = module_name
|
||||
|
||||
if module_name not in self.present_modules:
|
||||
# Add an import statement for the module if it's not already imported
|
||||
self.present_modules.add(module_name)
|
||||
import_node = ast.Import(names=[ast.alias(name=module_name)])
|
||||
self.additional_imports.append(import_node)
|
||||
return node
|
||||
|
||||
def visit_Import(self, node: ast.Import) -> ast.Import:
    """Record every module bound by a plain ``import`` statement.

    Each bound name (the ``as`` alias when present, otherwise the module
    name) is added to ``self.present_modules``. When the imported module is
    a key of ``RNG_MODULES_SEEDS``, the test is flagged as random and the
    bound name is mapped to the module name in ``self.random_modules``.

    Returns:
        The import node, unchanged.
    """
    for alias in node.names:
        module_name = alias.name
        bound_name = alias.asname or module_name
        self.present_modules.add(bound_name)
        # "random" is itself a key of RNG_MODULES_SEEDS, so the stdlib module
        # needs no separate check before this loop.
        if module_name in RNG_MODULES_SEEDS:
            self.random_test = True
            self.random_modules[bound_name] = module_name
    return node
|
||||
|
||||
def visit_Module(self, node: ast.Module) -> ast.Module:
    """Visit a module and splice in the imports collected during the visit.

    ``self.additional_imports`` is reset, the module's children are visited
    (which may append import nodes to that list), and the collected imports
    are then inserted near the top of the module body.

    The imports go after a leading module docstring and after any
    ``from __future__ import ...`` statements: a docstring is only recognised
    as the first statement, and future statements must precede all other
    imports, so inserting at index 0 would break both.

    Returns:
        The same module node, with its body updated.
    """
    # Reset additional_imports at the start of each module visit
    self.additional_imports = []
    self.generic_visit(node)

    body = node.body
    insert_at = 0
    # Skip a leading docstring expression, if any.
    if (
        body
        and isinstance(body[0], ast.Expr)
        and isinstance(body[0].value, ast.Constant)
        and isinstance(body[0].value.value, str)
    ):
        insert_at = 1
    # Skip the run of __future__ imports that follows.
    while (
        insert_at < len(body)
        and isinstance(body[insert_at], ast.ImportFrom)
        and body[insert_at].module == "__future__"
    ):
        insert_at += 1

    node.body = body[:insert_at] + self.additional_imports + body[insert_at:]
    return node
|
||||
|
||||
def visit_ClassDef(self, node: ast.ClassDef) -> Optional[ast.ClassDef]:
|
||||
|
|
@ -220,6 +265,77 @@ class InjectPerfAndLogging(ast.NodeTransformer):
|
|||
),
|
||||
]
|
||||
|
||||
def create_seed_statement(self, alias: str, full_module_name: str) -> ast.stmt | None:
    """Build a statement that seeds an RNG module with the constant 42.

    Args:
        alias: Name under which the module (or a member of it) was imported.
        full_module_name: Module name used to look up the seed function in
            ``RNG_MODULES_SEEDS``.

    Returns:
        An ``ast.Expr`` wrapping the call ``<module ref>.<seed path>(42)``,
        or ``None`` when the module has no entry in ``RNG_MODULES_SEEDS``.
    """
    seed_path = RNG_MODULES_SEEDS.get(full_module_name)
    if not seed_path:
        return None

    # Refer to the module through its alias only when that alias is tracked
    # in present_modules; otherwise fall back to the full module name.
    reference = alias if alias in self.present_modules else full_module_name

    # Both the module reference and the seed path may be dotted, so resolve
    # them as one attribute chain hanging off the first name.
    root, *attributes = [*reference.split("."), *seed_path.split(".")]
    callee = ast.Name(id=root, ctx=ast.Load())
    for attribute in attributes:
        callee = ast.Attribute(value=callee, attr=attribute, ctx=ast.Load())

    seed_invocation = ast.Call(func=callee, args=[ast.Constant(value=42)], keywords=[])
    return ast.Expr(value=seed_invocation)
|
||||
|
||||
def find_target_function_call(self, node: ast.AST) -> ast.Call | None:
|
||||
for child in ast.walk(node):
|
||||
if isinstance(child, ast.Call) and self.is_target_function_node(child):
|
||||
return child
|
||||
return None
|
||||
|
||||
def process_statements(
|
||||
self,
|
||||
statements: list[ast.stmt],
|
||||
node_name: str,
|
||||
class_name: str | None,
|
||||
offset: int = 0,
|
||||
parent_index: Optional[str] = None,
|
||||
) -> None:
|
||||
i = 0
|
||||
while i < len(statements):
|
||||
stmt = statements[i]
|
||||
if parent_index is None:
|
||||
current_index = str(i + offset)
|
||||
else:
|
||||
current_index = parent_index + "_" + str(i)
|
||||
if isinstance(stmt, (ast.With, ast.For, ast.While, ast.If, ast.Try, ast.AsyncWith, ast.AsyncFor)):
|
||||
# Recursively process the body of the control flow statement
|
||||
self.process_statements(stmt.body, node_name, class_name, offset=0, parent_index=current_index)
|
||||
# Handle else, finally blocks, and exception handlers similarly
|
||||
if hasattr(stmt, "orelse"): # Some statements (like if, for, while) can have an else block
|
||||
self.process_statements(stmt.orelse, node_name, class_name, offset=0, parent_index=current_index)
|
||||
if hasattr(stmt, "finalbody"): # try statements can have a finally block
|
||||
self.process_statements(stmt.finalbody, node_name, class_name, offset=0, parent_index=current_index)
|
||||
if isinstance(stmt, ast.Try) and stmt.handlers:
|
||||
for idx, handler in enumerate(stmt.handlers):
|
||||
self.process_statements(
|
||||
handler.body, node_name, class_name, offset=0, parent_index=current_index
|
||||
)
|
||||
else:
|
||||
# Check if the statement contains a call to the target function
|
||||
test_node = self.find_target_function_call(stmt)
|
||||
if test_node is not None:
|
||||
# Replace the statement with the updated nodes
|
||||
updated_nodes = self.update_line_node(
|
||||
test_node, get_call_arguments(test_node), node_name, current_index, class_name
|
||||
)
|
||||
statements[i : i + 1] = updated_nodes
|
||||
i += len(updated_nodes) - 1 # Adjust index for inserted nodes
|
||||
|
||||
i += 1
|
||||
|
||||
def visit_FunctionDef(self, node: ast.FunctionDef, class_name: Optional[str] = None) -> Optional[ast.FunctionDef]:
|
||||
random_index = 0
|
||||
random_test_function = False
|
||||
|
|
@ -227,86 +343,65 @@ class InjectPerfAndLogging(ast.NodeTransformer):
|
|||
if node.name == self.only_function_name or node.name in self.helper_function_names:
|
||||
return None # Remove the re-definition of the function and its dependencies from the test generation code
|
||||
if node.name.startswith("test_"):
|
||||
node.body.insert(
|
||||
0,
|
||||
ast.Assign(
|
||||
targets=[ast.Name(id="codeflash_loop_index", ctx=ast.Store())],
|
||||
value=ast.Call(
|
||||
func=ast.Name(id="int", ctx=ast.Load()),
|
||||
args=[
|
||||
ast.Subscript(
|
||||
value=ast.Attribute(
|
||||
value=ast.Name(id="os", ctx=ast.Load()), attr="environ", ctx=ast.Load()
|
||||
),
|
||||
slice=ast.Constant(value="CODEFLASH_LOOP_INDEX"),
|
||||
ctx=ast.Load(),
|
||||
)
|
||||
],
|
||||
keywords=[],
|
||||
),
|
||||
lineno=node.lineno + 1,
|
||||
col_offset=node.col_offset,
|
||||
),
|
||||
)
|
||||
self.present_modules = set() # Reset tracked modules
|
||||
self.present_aliases = set() # Reset tracked aliases
|
||||
new_body = []
|
||||
# Insert seed-setting statements
|
||||
for alias, full_module_name in self.random_modules.items():
|
||||
seed_statement = self.create_seed_statement(alias, full_module_name)
|
||||
if seed_statement:
|
||||
new_body.append(seed_statement)
|
||||
|
||||
i: int = len(node.body) - 1
|
||||
while i >= 0:
|
||||
line_node = node.body[i]
|
||||
did_delete = False
|
||||
# TODO: Validate if the functional call actually did not raise any exceptions
|
||||
if isinstance(line_node, ast.Import):
|
||||
for name in line_node.names:
|
||||
if name.name == "random":
|
||||
random_test_function = True
|
||||
random_index = i + 1
|
||||
i -= 1
|
||||
continue
|
||||
if isinstance(line_node, ast.ImportFrom):
|
||||
if line_node.module == "random":
|
||||
random_test_function = True
|
||||
random_index = i + 1
|
||||
random_importfrom = True
|
||||
i -= 1
|
||||
continue
|
||||
|
||||
if isinstance(line_node, (ast.With, ast.For, ast.While, ast.If)):
|
||||
j = len(line_node.body) - 1
|
||||
while j >= 0:
|
||||
with_line_node: ast.stmt = line_node.body[j]
|
||||
with_node: Optional[ast.Call] = next(
|
||||
(node for node in ast.walk(with_line_node) if self.is_target_function_node(node)), None
|
||||
)
|
||||
if with_node is not None:
|
||||
line_node.body[j : j + 1] = self.update_line_node(
|
||||
with_node, get_call_arguments(with_node), node.name, str(i) + "_" + str(j), class_name
|
||||
)
|
||||
did_delete = True
|
||||
j -= 1
|
||||
else:
|
||||
test_node: Optional[ast.Call] = next(
|
||||
(node for node in ast.walk(line_node) if self.is_target_function_node(node)), None
|
||||
)
|
||||
if test_node is not None:
|
||||
node.body[i : i + 1] = self.update_line_node(
|
||||
test_node, get_call_arguments(test_node), node.name, str(i), class_name
|
||||
)
|
||||
did_delete = True
|
||||
i -= 1
|
||||
if self.random_test or random_test_function:
|
||||
node.body.insert(
|
||||
random_index,
|
||||
ast.Expr(
|
||||
value=ast.Call(
|
||||
func=ast.Attribute(
|
||||
value=ast.Name(id="random", ctx=ast.Load()), attr="seed", ctx=ast.Load()
|
||||
loop_index_assign = ast.Assign(
|
||||
targets=[ast.Name(id="codeflash_loop_index", ctx=ast.Store())],
|
||||
value=ast.Call(
|
||||
func=ast.Name(id="int", ctx=ast.Load()),
|
||||
args=[
|
||||
ast.Subscript(
|
||||
value=ast.Attribute(
|
||||
value=ast.Name(id="os", ctx=ast.Load()), attr="environ", ctx=ast.Load()
|
||||
),
|
||||
args=[ast.Constant(value=42)],
|
||||
keywords=[],
|
||||
slice=ast.Constant(value="CODEFLASH_LOOP_INDEX"),
|
||||
ctx=ast.Load(),
|
||||
)
|
||||
),
|
||||
)
|
||||
if random_importfrom:
|
||||
node.body.insert(random_index, ast.Import(names=[ast.alias(name="random")]))
|
||||
],
|
||||
keywords=[],
|
||||
),
|
||||
lineno=node.lineno + 1,
|
||||
col_offset=node.col_offset,
|
||||
)
|
||||
new_body.append(loop_index_assign)
|
||||
offset = len(new_body)
|
||||
# Add the original function body
|
||||
# new_body.extend(node.body)
|
||||
|
||||
# Process each statement in the function body
|
||||
for stmt in node.body:
|
||||
# Check if the statement is an import of a random module
|
||||
if isinstance(stmt, (ast.Import, ast.ImportFrom)):
|
||||
self.visit(stmt) # To populate self.random_modules
|
||||
new_body.append(stmt)
|
||||
# Insert seed-setting statement after import
|
||||
alias_list = list(self.random_modules.keys())
|
||||
if alias_list:
|
||||
alias = alias_list[-1] # Get the last alias added
|
||||
full_module_name = self.random_modules[alias]
|
||||
seed_statement = self.create_seed_statement(alias, full_module_name)
|
||||
if seed_statement:
|
||||
new_body.append(seed_statement)
|
||||
# Remove the alias from random_modules to prevent duplicate seeds
|
||||
del self.random_modules[alias]
|
||||
continue
|
||||
|
||||
# For other statements, process them recursively
|
||||
self.generic_visit(stmt)
|
||||
stmt_list = [stmt]
|
||||
self.process_statements(stmt_list, node.name, class_name, offset=offset)
|
||||
new_body.extend(stmt_list)
|
||||
offset += len(stmt_list) # Update offset for each new statement added
|
||||
# Update the function body
|
||||
node.body = new_body
|
||||
|
||||
if self.test_framework == "unittest":
|
||||
# TODO: Make sure that if the test times out,
|
||||
# the test's time is excluded from the total time calculation and comparison
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
Loading…
Reference in a new issue