437 lines
19 KiB
Python
437 lines
19 KiB
Python
from __future__ import annotations
|
|
|
|
import ast
|
|
import os
|
|
import re
|
|
from pathlib import Path
|
|
from typing import TYPE_CHECKING, Optional
|
|
|
|
import libcst as cst
|
|
from libcst import MetadataWrapper
|
|
from libcst.metadata import PositionProvider
|
|
|
|
from codeflash.cli_cmds.console import logger
|
|
from codeflash.code_utils.time_utils import format_perf, format_time
|
|
from codeflash.languages.registry import get_language_support
|
|
from codeflash.models.models import GeneratedTests, GeneratedTestsList
|
|
from codeflash.result.critic import performance_gain
|
|
|
|
if TYPE_CHECKING:
|
|
from codeflash.models.models import InvocationId
|
|
|
|
|
|
class CommentMapper(ast.NodeVisitor):
|
|
def __init__(
|
|
self, test: GeneratedTests, original_runtimes: dict[str, int], optimized_runtimes: dict[str, int]
|
|
) -> None:
|
|
self.results: dict[int, str] = {}
|
|
self.test: GeneratedTests = test
|
|
self.original_runtimes = original_runtimes
|
|
self.optimized_runtimes = optimized_runtimes
|
|
self.abs_path = test.behavior_file_path.with_suffix("")
|
|
self.context_stack: list[str] = []
|
|
|
|
def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
|
|
self.context_stack.append(node.name)
|
|
for inner_node in node.body:
|
|
if isinstance(inner_node, ast.FunctionDef):
|
|
self.visit_FunctionDef(inner_node)
|
|
elif isinstance(inner_node, ast.AsyncFunctionDef):
|
|
self.visit_AsyncFunctionDef(inner_node)
|
|
self.context_stack.pop()
|
|
return node
|
|
|
|
def get_comment(self, match_key: str) -> str:
|
|
# calculate speedup and output comment
|
|
original_time = self.original_runtimes[match_key]
|
|
optimized_time = self.optimized_runtimes[match_key]
|
|
perf_gain = format_perf(
|
|
abs(performance_gain(original_runtime_ns=original_time, optimized_runtime_ns=optimized_time) * 100)
|
|
)
|
|
status = "slower" if optimized_time > original_time else "faster"
|
|
# Create the runtime comment
|
|
return f"# {format_time(original_time)} -> {format_time(optimized_time)} ({perf_gain}% {status})"
|
|
|
|
def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
|
|
self._process_function_def_common(node)
|
|
return node
|
|
|
|
def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> ast.AsyncFunctionDef:
|
|
self._process_function_def_common(node)
|
|
return node
|
|
|
|
def _process_function_def_common(self, node: ast.FunctionDef | ast.AsyncFunctionDef) -> None:
|
|
self.context_stack.append(node.name)
|
|
i = len(node.body) - 1
|
|
test_qualified_name = ".".join(self.context_stack)
|
|
key = test_qualified_name + "#" + str(self.abs_path)
|
|
while i >= 0:
|
|
line_node = node.body[i]
|
|
if isinstance(line_node, (ast.With, ast.For, ast.While, ast.If)):
|
|
j = len(line_node.body) - 1
|
|
while j >= 0:
|
|
compound_line_node: ast.stmt = line_node.body[j]
|
|
nodes_to_check = [compound_line_node]
|
|
nodes_to_check.extend(getattr(compound_line_node, "body", []))
|
|
for internal_node in nodes_to_check:
|
|
if isinstance(internal_node, (ast.stmt, ast.Assign)):
|
|
inv_id = str(i) + "_" + str(j)
|
|
match_key = key + "#" + inv_id
|
|
if match_key in self.original_runtimes and match_key in self.optimized_runtimes:
|
|
self.results[internal_node.lineno] = self.get_comment(match_key)
|
|
j -= 1
|
|
else:
|
|
inv_id = str(i)
|
|
match_key = key + "#" + inv_id
|
|
if match_key in self.original_runtimes and match_key in self.optimized_runtimes:
|
|
self.results[line_node.lineno] = self.get_comment(match_key)
|
|
i -= 1
|
|
self.context_stack.pop()
|
|
|
|
|
|
def get_fn_call_linenos(
|
|
test: GeneratedTests, original_runtimes: dict[str, int], optimized_runtimes: dict[str, int]
|
|
) -> dict[int, str]:
|
|
line_comment_ast_mapper = CommentMapper(test, original_runtimes, optimized_runtimes)
|
|
source_code = test.generated_original_test_source
|
|
tree = ast.parse(source_code)
|
|
line_comment_ast_mapper.visit(tree)
|
|
return line_comment_ast_mapper.results
|
|
|
|
|
|
class CommentAdder(cst.CSTTransformer):
|
|
"""Transformer that adds comments to specified lines."""
|
|
|
|
# Declare metadata dependencies
|
|
METADATA_DEPENDENCIES = (PositionProvider,)
|
|
|
|
def __init__(self, line_to_comments: dict[int, str]) -> None:
|
|
"""Initialize the transformer with target line numbers.
|
|
|
|
Args:
|
|
line_to_comments: Mapping of line numbers (1-indexed) to comments
|
|
|
|
"""
|
|
self.line_to_comments = line_to_comments
|
|
super().__init__()
|
|
|
|
def leave_SimpleStatementLine(
|
|
self, original_node: cst.SimpleStatementLine, updated_node: cst.SimpleStatementLine
|
|
) -> cst.SimpleStatementLine:
|
|
"""Add comment to simple statement lines."""
|
|
pos = self.get_metadata(PositionProvider, original_node)
|
|
|
|
if pos and pos.start.line in self.line_to_comments:
|
|
# Create a comment with trailing whitespace
|
|
comment = cst.TrailingWhitespace(
|
|
whitespace=cst.SimpleWhitespace(" "), comment=cst.Comment(self.line_to_comments[pos.start.line])
|
|
)
|
|
|
|
# Update the trailing whitespace of the line itself
|
|
return updated_node.with_changes(trailing_whitespace=comment)
|
|
|
|
return updated_node
|
|
|
|
def leave_SimpleStatementSuite(
|
|
self, original_node: cst.SimpleStatementSuite, updated_node: cst.SimpleStatementSuite
|
|
) -> cst.SimpleStatementSuite:
|
|
"""Add comment to simple statement suites (e.g., after if/for/while)."""
|
|
pos = self.get_metadata(PositionProvider, original_node)
|
|
|
|
if pos and pos.start.line in self.line_to_comments:
|
|
# Create a comment with trailing whitespace
|
|
comment = cst.TrailingWhitespace(
|
|
whitespace=cst.SimpleWhitespace(" "), comment=cst.Comment(self.line_to_comments[pos.start.line])
|
|
)
|
|
|
|
# Update the trailing whitespace of the suite
|
|
return updated_node.with_changes(trailing_whitespace=comment)
|
|
|
|
return updated_node
|
|
|
|
|
|
def _is_python_file(file_path: Path) -> bool:
|
|
"""Check if a file is a Python file."""
|
|
return file_path.suffix == ".py"
|
|
|
|
|
|
# TODO:{self} Needs cleanup for jest logic in else block
|
|
def unique_inv_id(inv_id_runtimes: dict[InvocationId, list[int]], tests_project_rootdir: Path) -> dict[str, int]:
|
|
unique_inv_ids: dict[str, int] = {}
|
|
logger.debug(f"[unique_inv_id] Processing {len(inv_id_runtimes)} invocation IDs")
|
|
for inv_id, runtimes in inv_id_runtimes.items():
|
|
test_qualified_name = (
|
|
inv_id.test_class_name + "." + inv_id.test_function_name # type: ignore[operator]
|
|
if inv_id.test_class_name
|
|
else inv_id.test_function_name
|
|
)
|
|
|
|
# Detect if test_module_path is a file path (like in js tests) or a Python module name
|
|
# File paths contain slashes, module names use dots
|
|
test_module_path = inv_id.test_module_path
|
|
if "/" in test_module_path or "\\" in test_module_path:
|
|
# Already a file path - use directly
|
|
abs_path = tests_project_rootdir / Path(test_module_path)
|
|
else:
|
|
# Check for Jest test file extensions (e.g., tests.fibonacci.test.ts)
|
|
# These need special handling to avoid converting .test.ts -> /test/ts
|
|
jest_test_extensions = (
|
|
".test.ts",
|
|
".test.js",
|
|
".test.tsx",
|
|
".test.jsx",
|
|
".spec.ts",
|
|
".spec.js",
|
|
".spec.tsx",
|
|
".spec.jsx",
|
|
".ts",
|
|
".js",
|
|
".tsx",
|
|
".jsx",
|
|
".mjs",
|
|
".mts",
|
|
)
|
|
matched_ext = None
|
|
for ext in jest_test_extensions:
|
|
if test_module_path.endswith(ext):
|
|
matched_ext = ext
|
|
break
|
|
|
|
if matched_ext:
|
|
# JavaScript/TypeScript: convert module-style path to file path
|
|
# "tests.fibonacci__perfonlyinstrumented.test.ts" -> "tests/fibonacci__perfonlyinstrumented.test.ts"
|
|
base_path = test_module_path[: -len(matched_ext)]
|
|
file_path = base_path.replace(".", os.sep) + matched_ext
|
|
# Check if the module path includes the tests directory name
|
|
tests_dir_name = tests_project_rootdir.name
|
|
if file_path.startswith((tests_dir_name + os.sep, tests_dir_name + "/")):
|
|
# Module path includes "tests." - use parent directory
|
|
abs_path = tests_project_rootdir.parent / Path(file_path)
|
|
else:
|
|
# Module path doesn't include tests dir - use tests root directly
|
|
abs_path = tests_project_rootdir / Path(file_path)
|
|
else:
|
|
# Python module name - convert dots to path separators and add .py
|
|
abs_path = tests_project_rootdir / Path(test_module_path.replace(".", os.sep)).with_suffix(".py")
|
|
|
|
abs_path_str = str(abs_path.resolve().with_suffix(""))
|
|
# Include both unit test and perf test paths for runtime annotations
|
|
# (performance test runtimes are used for annotations)
|
|
if ("__unit_test_" not in abs_path_str and "__perf_test_" not in abs_path_str) or not test_qualified_name:
|
|
logger.debug(f"[unique_inv_id] Skipping: path={abs_path_str}, test_qualified_name={test_qualified_name}")
|
|
continue
|
|
key = test_qualified_name + "#" + abs_path_str
|
|
parts = inv_id.iteration_id.split("_").__len__() # type: ignore[union-attr]
|
|
cur_invid = inv_id.iteration_id.split("_")[0] if parts < 3 else "_".join(inv_id.iteration_id.split("_")[:-1]) # type: ignore[union-attr]
|
|
match_key = key + "#" + cur_invid
|
|
logger.debug(f"[unique_inv_id] Adding key: {match_key} with runtime {min(runtimes)}")
|
|
if match_key not in unique_inv_ids:
|
|
unique_inv_ids[match_key] = 0
|
|
unique_inv_ids[match_key] += min(runtimes)
|
|
logger.debug(f"[unique_inv_id] Result has {len(unique_inv_ids)} entries")
|
|
return unique_inv_ids
|
|
|
|
|
|
def add_runtime_comments_to_generated_tests(
|
|
generated_tests: GeneratedTestsList,
|
|
original_runtimes: dict[InvocationId, list[int]],
|
|
optimized_runtimes: dict[InvocationId, list[int]],
|
|
tests_project_rootdir: Optional[Path] = None,
|
|
) -> GeneratedTestsList:
|
|
"""Add runtime performance comments to function calls in generated tests."""
|
|
original_runtimes_dict = unique_inv_id(original_runtimes, tests_project_rootdir or Path())
|
|
optimized_runtimes_dict = unique_inv_id(optimized_runtimes, tests_project_rootdir or Path())
|
|
# Process each generated test
|
|
modified_tests = []
|
|
for test in generated_tests.generated_tests:
|
|
is_python = _is_python_file(test.behavior_file_path)
|
|
|
|
if is_python:
|
|
# Use Python libcst-based comment insertion
|
|
try:
|
|
tree = cst.parse_module(test.generated_original_test_source)
|
|
wrapper = MetadataWrapper(tree)
|
|
line_to_comments = get_fn_call_linenos(test, original_runtimes_dict, optimized_runtimes_dict)
|
|
comment_adder = CommentAdder(line_to_comments)
|
|
modified_tree = wrapper.visit(comment_adder)
|
|
modified_source = modified_tree.code
|
|
modified_test = GeneratedTests(
|
|
generated_original_test_source=modified_source,
|
|
instrumented_behavior_test_source=test.instrumented_behavior_test_source,
|
|
instrumented_perf_test_source=test.instrumented_perf_test_source,
|
|
behavior_file_path=test.behavior_file_path,
|
|
perf_file_path=test.perf_file_path,
|
|
)
|
|
modified_tests.append(modified_test)
|
|
except Exception as e:
|
|
# If parsing fails, keep the original test
|
|
logger.debug(f"Failed to add runtime comments to test: {e}")
|
|
modified_tests.append(test)
|
|
else:
|
|
try:
|
|
language_support = get_language_support(test.behavior_file_path)
|
|
modified_source = language_support.add_runtime_comments(
|
|
test.generated_original_test_source, original_runtimes_dict, optimized_runtimes_dict
|
|
)
|
|
modified_test = GeneratedTests(
|
|
generated_original_test_source=modified_source,
|
|
instrumented_behavior_test_source=test.instrumented_behavior_test_source,
|
|
instrumented_perf_test_source=test.instrumented_perf_test_source,
|
|
behavior_file_path=test.behavior_file_path,
|
|
perf_file_path=test.perf_file_path,
|
|
)
|
|
modified_tests.append(modified_test)
|
|
except Exception as e:
|
|
logger.debug(f"Failed to add runtime comments to test: {e}")
|
|
modified_tests.append(test)
|
|
|
|
return GeneratedTestsList(generated_tests=modified_tests)
|
|
|
|
|
|
def remove_functions_from_generated_tests(
|
|
generated_tests: GeneratedTestsList, test_functions_to_remove: list[str]
|
|
) -> GeneratedTestsList:
|
|
# Pre-compile patterns for all function names to remove
|
|
function_patterns = _compile_function_patterns(test_functions_to_remove)
|
|
new_generated_tests = []
|
|
|
|
for generated_test in generated_tests.generated_tests:
|
|
source = generated_test.generated_original_test_source
|
|
|
|
# Apply all patterns without redundant searches
|
|
for pattern in function_patterns:
|
|
# Use finditer and sub only if necessary to avoid unnecessary .search()/.sub() calls
|
|
for match in pattern.finditer(source):
|
|
# Skip if "@pytest.mark.parametrize" present
|
|
# Only the matched function's code is targeted
|
|
if "@pytest.mark.parametrize" in match.group(0):
|
|
continue
|
|
# Remove function from source
|
|
# If match, remove the function by substitution in the source
|
|
# Replace using start/end indices for efficiency
|
|
start, end = match.span()
|
|
source = source[:start] + source[end:]
|
|
# After removal, break since .finditer() is from left to right, and only one match expected per function in source
|
|
break
|
|
|
|
generated_test.generated_original_test_source = source
|
|
new_generated_tests.append(generated_test)
|
|
|
|
return GeneratedTestsList(generated_tests=new_generated_tests)
|
|
|
|
|
|
# Pre-compile all function removal regexes upfront for efficiency.
|
|
def _compile_function_patterns(test_functions_to_remove: list[str]) -> list[re.Pattern[str]]:
|
|
return [
|
|
re.compile(
|
|
rf"(@pytest\.mark\.parametrize\(.*?\)\s*)?(async\s+)?def\s+{re.escape(func)}\(.*?\):.*?(?=\n(async\s+)?def\s|$)",
|
|
re.DOTALL,
|
|
)
|
|
for func in test_functions_to_remove
|
|
]
|
|
|
|
|
|
# Patterns for normalizing codeflash imports (legacy -> npm package)
|
|
_CODEFLASH_REQUIRE_PATTERN = re.compile(
|
|
r"(const|let|var)\s+(\w+)\s*=\s*require\s*\(\s*['\"]\.?/?codeflash-jest-helper['\"]\s*\)"
|
|
)
|
|
_CODEFLASH_IMPORT_PATTERN = re.compile(r"import\s+(?:\*\s+as\s+)?(\w+)\s+from\s+['\"]\.?/?codeflash-jest-helper['\"]")
|
|
|
|
|
|
def normalize_codeflash_imports(source: str) -> str:
|
|
"""Normalize codeflash imports to use the npm package.
|
|
|
|
Replaces legacy local file imports:
|
|
const codeflash = require('./codeflash-jest-helper')
|
|
import codeflash from './codeflash-jest-helper'
|
|
|
|
With npm package imports:
|
|
const codeflash = require('codeflash')
|
|
|
|
Args:
|
|
source: JavaScript/TypeScript source code.
|
|
|
|
Returns:
|
|
Source code with normalized imports.
|
|
|
|
"""
|
|
# Replace CommonJS require
|
|
source = _CODEFLASH_REQUIRE_PATTERN.sub(r"\1 \2 = require('codeflash')", source)
|
|
# Replace ES module import
|
|
return _CODEFLASH_IMPORT_PATTERN.sub(r"import \1 from 'codeflash'", source)
|
|
|
|
|
|
def inject_test_globals(generated_tests: GeneratedTestsList, test_framework: str = "jest") -> GeneratedTestsList:
|
|
# TODO: inside the prompt tell the llm if it should import jest functions or it's already injected in the global window
|
|
"""Inject test globals into all generated tests.
|
|
|
|
Args:
|
|
generated_tests: List of generated tests.
|
|
test_framework: The test framework being used ("jest", "vitest", or "mocha").
|
|
|
|
Returns:
|
|
Generated tests with test globals injected.
|
|
|
|
"""
|
|
# we only inject test globals for esm modules
|
|
# Use vitest imports for vitest projects, jest imports for jest projects
|
|
if test_framework == "vitest":
|
|
global_import = "import { vi, describe, it, expect, beforeEach, afterEach, beforeAll, test } from 'vitest'\n"
|
|
else:
|
|
# Default to jest imports for jest and other frameworks
|
|
global_import = (
|
|
"import { jest, describe, it, expect, beforeEach, afterEach, beforeAll, test } from '@jest/globals'\n"
|
|
)
|
|
|
|
for test in generated_tests.generated_tests:
|
|
test.generated_original_test_source = global_import + test.generated_original_test_source
|
|
test.instrumented_behavior_test_source = global_import + test.instrumented_behavior_test_source
|
|
test.instrumented_perf_test_source = global_import + test.instrumented_perf_test_source
|
|
return generated_tests
|
|
|
|
|
|
def disable_ts_check(generated_tests: GeneratedTestsList) -> GeneratedTestsList:
|
|
"""Disable TypeScript type checking in all generated tests.
|
|
|
|
Args:
|
|
generated_tests: List of generated tests.
|
|
|
|
Returns:
|
|
Generated tests with TypeScript type checking disabled.
|
|
|
|
"""
|
|
# we only inject test globals for esm modules
|
|
ts_nocheck = "// @ts-nocheck\n"
|
|
|
|
for test in generated_tests.generated_tests:
|
|
test.generated_original_test_source = ts_nocheck + test.generated_original_test_source
|
|
test.instrumented_behavior_test_source = ts_nocheck + test.instrumented_behavior_test_source
|
|
test.instrumented_perf_test_source = ts_nocheck + test.instrumented_perf_test_source
|
|
return generated_tests
|
|
|
|
|
|
def normalize_generated_tests_imports(generated_tests: GeneratedTestsList) -> GeneratedTestsList:
|
|
"""Normalize codeflash imports in all generated tests.
|
|
|
|
Args:
|
|
generated_tests: List of generated tests.
|
|
|
|
Returns:
|
|
Generated tests with normalized imports.
|
|
|
|
"""
|
|
normalized_tests = []
|
|
for test in generated_tests.generated_tests:
|
|
# Only normalize JS/TS files
|
|
if test.behavior_file_path.suffix in (".js", ".ts", ".jsx", ".tsx", ".mjs", ".mts"):
|
|
normalized_test = GeneratedTests(
|
|
generated_original_test_source=normalize_codeflash_imports(test.generated_original_test_source),
|
|
instrumented_behavior_test_source=normalize_codeflash_imports(test.instrumented_behavior_test_source),
|
|
instrumented_perf_test_source=normalize_codeflash_imports(test.instrumented_perf_test_source),
|
|
behavior_file_path=test.behavior_file_path,
|
|
perf_file_path=test.perf_file_path,
|
|
)
|
|
normalized_tests.append(normalized_test)
|
|
else:
|
|
normalized_tests.append(test)
|
|
return GeneratedTestsList(generated_tests=normalized_tests)
|