fix: cherry-pick main improvements into omni-java branch

- Take main's JS improvements: Mocha CJS support, ESM/CJS handling,
  sanitize_mocha_imports, vitest benchmarking fixes
- Update instrument_existing_test API: remove test_string param, read from
  file internally (aligned across Python, JS, Java support classes)
- Take main's equivalence.py with pass_fail_only parameter
- Take main's models.py, critic.py, env_utils.py, replay_test.py fixes
- Take main's PythonFunctionOptimizer parse_line_profile fix
- Skip files where our branch has Java-specific additions main doesn't
  have (create_pr, discover_unit_tests, parse_test_output, optimizer,
  verification_utils, config_parser, cmd_init, detector, support.py
  protocol methods)
Kevin Turcios 2026-03-03 23:59:26 -05:00
parent bccc02aade
commit af7ce7fce2
26 changed files with 901 additions and 176 deletions

View file

@ -6,6 +6,11 @@ import textwrap
from pathlib import Path
from typing import TYPE_CHECKING, Any
from codeflash.cli_cmds.console import logger
from codeflash.code_utils.formatter import sort_imports
from codeflash.discovery.functions_to_optimize import inspect_top_level_functions_or_methods
from codeflash.verification.verification_utils import get_test_file_path
if TYPE_CHECKING:
from collections.abc import Generator
@ -227,11 +232,6 @@ def generate_replay_test(trace_file_path: Path, output_dir: Path, max_run_count:
The number of replay tests generated
"""
from codeflash.cli_cmds.console import logger
from codeflash.code_utils.formatter import sort_imports
from codeflash.discovery.functions_to_optimize import inspect_top_level_functions_or_methods
from codeflash.verification.verification_utils import get_test_file_path
count = 0
try:
# Connect to the database

View file

@ -13,6 +13,7 @@ from codeflash.cli_cmds.console import logger
from codeflash.code_utils.code_utils import exit_with_message
from codeflash.code_utils.formatter import format_code
from codeflash.code_utils.shell_utils import read_api_key_from_shell_config, save_api_key_to_rc
from codeflash.languages.registry import get_language_support_by_common_formatters
from codeflash.lsp.helpers import is_LSP_enabled
@ -22,11 +23,7 @@ def check_formatter_installed(
if not formatter_cmds or formatter_cmds[0] == "disabled":
return True
first_cmd = formatter_cmds[0]
# Fast path: avoid expensive shlex.split for simple strings without quotes
if " " not in first_cmd or ('"' not in first_cmd and "'" not in first_cmd):
cmd_tokens = first_cmd.split()
else:
cmd_tokens = shlex.split(first_cmd)
cmd_tokens = shlex.split(first_cmd) if isinstance(first_cmd, str) else [first_cmd]
if not cmd_tokens:
return True
@ -41,9 +38,6 @@ def check_formatter_installed(
)
return False
# Import here to avoid circular import
from codeflash.languages.registry import get_language_support_by_common_formatters
lang_support = get_language_support_by_common_formatters(formatter_cmds)
if not lang_support:
logger.debug(f"Could not determine language for formatter: {formatter_cmds}")

View file

@ -631,15 +631,16 @@ class FunctionImportedAsVisitor(ast.NodeVisitor):
def inject_async_profiling_into_existing_test(
test_string: str,
test_path: Path,
call_positions: list[CodePosition],
function_to_optimize: FunctionToOptimize,
tests_project_root: Path,
mode: TestingMode = TestingMode.BEHAVIOR,
test_path: Path | None = None,
) -> tuple[bool, str | None]:
"""Inject profiling for async function calls by setting environment variables before each call."""
test_code = test_string
with test_path.open(encoding="utf8") as f:
test_code = f.read()
try:
tree = ast.parse(test_code)
except SyntaxError:
@ -702,7 +703,6 @@ def detect_frameworks_from_code(code: str) -> dict[str, str]:
def inject_profiling_into_existing_test(
test_string: str,
test_path: Path,
call_positions: list[CodePosition],
function_to_optimize: FunctionToOptimize,
@ -712,17 +712,15 @@ def inject_profiling_into_existing_test(
tests_project_root = tests_project_root.resolve()
if function_to_optimize.is_async:
return inject_async_profiling_into_existing_test(
test_string=test_string,
call_positions=call_positions,
function_to_optimize=function_to_optimize,
tests_project_root=tests_project_root,
mode=mode.value,
test_path=test_path,
test_path, call_positions, function_to_optimize, tests_project_root, mode
)
used_frameworks = detect_frameworks_from_code(test_string)
with test_path.open(encoding="utf8") as f:
test_code = f.read()
used_frameworks = detect_frameworks_from_code(test_code)
try:
tree = ast.parse(test_string)
tree = ast.parse(test_code)
except SyntaxError:
logger.exception(f"Syntax error in code in file - {test_path}")
return False, None

View file

@ -31,7 +31,7 @@ class PrComment:
if name:
report_table[name] = counts
json_result: dict[str, Union[str, int, dict[str, dict[str, int]], list[BenchmarkDetail], None]] = {
result: dict[str, Union[str, int, dict[str, dict[str, int]], list[BenchmarkDetail], None]] = {
"optimization_explanation": self.optimization_explanation,
"best_runtime": humanize_runtime(self.best_runtime),
"original_runtime": humanize_runtime(self.original_runtime),
@ -45,10 +45,10 @@ class PrComment:
}
if self.original_async_throughput is not None and self.best_async_throughput is not None:
json_result["original_async_throughput"] = str(self.original_async_throughput)
json_result["best_async_throughput"] = str(self.best_async_throughput)
result["original_async_throughput"] = self.original_async_throughput
result["best_async_throughput"] = self.best_async_throughput
return json_result
return result
class FileDiffContent(BaseModel):

View file

@ -850,12 +850,11 @@ class LanguageSupport(Protocol):
def instrument_existing_test(
self,
test_string: str,
test_path: Path,
call_positions: Sequence[Any],
function_to_optimize: Any,
tests_project_root: Path,
mode: str,
test_path: Path | None,
) -> tuple[bool, str | None]:
"""Inject profiling code into an existing test file.

View file

@ -504,14 +504,14 @@ class JavaSupport(LanguageSupport):
def instrument_existing_test(
self,
test_string: str,
test_path: Path,
call_positions: Sequence[Any],
function_to_optimize: Any,
tests_project_root: Path,
mode: str,
test_path: Path | None,
) -> tuple[bool, str | None]:
"""Inject profiling code into an existing test file."""
test_string = test_path.read_text(encoding="utf-8")
return instrument_existing_test(
test_string=test_string, function_to_optimize=function_to_optimize, mode=mode, test_path=test_path
)

View file

@ -195,24 +195,30 @@ def normalize_codeflash_imports(source: str) -> str:
# Author: ali <mohammed18200118@gmail.com>
def inject_test_globals(generated_tests: GeneratedTestsList, test_framework: str = "jest") -> GeneratedTestsList:
def inject_test_globals(
generated_tests: GeneratedTestsList, test_framework: str = "jest", module_system: str = "esm"
) -> GeneratedTestsList:
# TODO: in the prompt, tell the LLM whether it should import the jest functions or whether they are already injected as globals
"""Inject test globals into all generated tests.
Args:
generated_tests: List of generated tests.
test_framework: The test framework being used ("jest", "vitest", or "mocha").
module_system: The module system ("esm" or "commonjs").
Returns:
Generated tests with test globals injected.
"""
# we only inject test globals for esm modules
is_cjs = module_system == "commonjs"
# Use vitest imports for vitest projects, jest imports for jest projects
if test_framework == "vitest":
global_import = "import { vi, describe, it, expect, beforeEach, afterEach, beforeAll, test } from 'vitest'\n"
elif test_framework == "mocha":
global_import = "import assert from 'node:assert/strict';\n"
if is_cjs:
global_import = "const assert = require('node:assert/strict');\n"
else:
global_import = "import assert from 'node:assert/strict';\n"
else:
# Default to jest imports for jest and other frameworks
global_import = (
@ -220,12 +226,283 @@ def inject_test_globals(generated_tests: GeneratedTestsList, test_framework: str
)
for test in generated_tests.generated_tests:
test.generated_original_test_source = global_import + test.generated_original_test_source
test.instrumented_behavior_test_source = global_import + test.instrumented_behavior_test_source
test.instrumented_perf_test_source = global_import + test.instrumented_perf_test_source
# Skip injection if the source already has the import (LLM may have included it)
if global_import.strip() not in test.generated_original_test_source:
test.generated_original_test_source = global_import + test.generated_original_test_source
if global_import.strip() not in test.instrumented_behavior_test_source:
test.instrumented_behavior_test_source = global_import + test.instrumented_behavior_test_source
if global_import.strip() not in test.instrumented_perf_test_source:
test.instrumented_perf_test_source = global_import + test.instrumented_perf_test_source
return generated_tests
_VITEST_IMPORT_RE = re.compile(r"^.*import\s+\{[^}]*\}\s+from\s+['\"]vitest['\"].*\n?", re.MULTILINE)
_VITEST_REQUIRE_RE = re.compile(
r"^.*(?:const|let|var)\s+\{[^}]*\}\s*=\s*require\s*\(\s*['\"]vitest['\"]\s*\).*\n?", re.MULTILINE
)
_JEST_GLOBALS_IMPORT_RE = re.compile(r"^.*import\s+\{[^}]*\}\s+from\s+['\"]@jest/globals['\"].*\n?", re.MULTILINE)
_JEST_GLOBALS_REQUIRE_RE = re.compile(
r"^.*(?:const|let|var)\s+\{[^}]*\}\s*=\s*require\s*\(\s*['\"]@jest/globals['\"]\s*\).*\n?", re.MULTILINE
)
_MOCHA_REQUIRE_RE = re.compile(
r"^.*(?:const|let|var)\s+\{[^}]*\}\s*=\s*require\s*\(\s*['\"]mocha['\"]\s*\).*\n?", re.MULTILINE
)
_VITEST_COMMENT_RE = re.compile(r"^.*//.*vitest imports.*\n?", re.MULTILINE | re.IGNORECASE)
# Chai import patterns — LLMs sometimes associate Mocha with Chai
_CHAI_IMPORT_RE = re.compile(r"^.*import\s+.*\s+from\s+['\"]chai['\"].*\n?", re.MULTILINE)
_CHAI_REQUIRE_RE = re.compile(r"^.*(?:const|let|var)\s+.*\s*=\s*require\s*\(\s*['\"]chai['\"]\s*\).*\n?", re.MULTILINE)
# Pattern to convert test() → it() — Mocha uses it(), not test()
_TEST_CALL_RE = re.compile(r"(\s*)test\s*\(")
def sanitize_mocha_imports(source: str) -> str:
"""Remove vitest/jest/mocha-require/chai imports from Mocha test source.
The AI service sometimes generates vitest or jest-style imports when the
framework is mocha. Mocha provides describe/it/before*/after* as globals,
so these imports must be removed. Also removes ``require('mocha')``
destructures since Mocha doesn't export those.
Additionally converts ``test()`` calls to ``it()`` since Mocha only
supports ``it()`` as its test function.
Args:
source: Generated test source code.
Returns:
Source with incorrect framework imports stripped and test() converted to it().
"""
source = _VITEST_IMPORT_RE.sub("", source)
source = _VITEST_REQUIRE_RE.sub("", source)
source = _JEST_GLOBALS_IMPORT_RE.sub("", source)
source = _JEST_GLOBALS_REQUIRE_RE.sub("", source)
source = _MOCHA_REQUIRE_RE.sub("", source)
source = _VITEST_COMMENT_RE.sub("", source)
source = _CHAI_IMPORT_RE.sub("", source)
source = _CHAI_REQUIRE_RE.sub("", source)
source = _TEST_CALL_RE.sub(r"\1it(", source)
return convert_expect_to_assert(source)
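An illustrative run of the sanitizer (the sample test source and the add() it references are placeholders; sanitize_mocha_imports is the function defined above):

generated = (
    "import { describe, it, expect } from 'vitest'\n"
    "test('adds', () => {\n"
    "  expect(add(1, 2)).toBe(3);\n"
    "});\n"
)
print(sanitize_mocha_imports(generated))
# The vitest import line is stripped, test( becomes it(, and the expect line is
# rewritten by convert_expect_to_assert to:  assert.strictEqual(add(1, 2), 3);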
def _find_matching_paren(source: str, open_pos: int) -> int:
"""Find the position of the closing parenthesis matching the one at open_pos."""
depth = 0
in_string = False
string_char = None
i = open_pos
while i < len(source):
char = source[i]
if char in ('"', "'", "`") and (i == 0 or source[i - 1] != "\\"):
if not in_string:
in_string = True
string_char = char
elif char == string_char:
in_string = False
string_char = None
elif not in_string:
if char == "(":
depth += 1
elif char == ")":
depth -= 1
if depth == 0:
return i
i += 1
return -1
def convert_expect_to_assert(source: str) -> str:
"""Convert expect()-style assertions to node:assert/strict equivalents.
LLMs frequently generate Chai-style (``expect(x).to.equal(y)``) or
Jest-style (``expect(x).toBe(y)``) assertions for Mocha tests despite
being instructed to use ``assert``. This function converts the common
patterns to their ``node:assert/strict`` equivalents so that
instrumentation and Mocha execution work correctly.
Any ``expect()`` calls that cannot be converted are commented out with
``// SKIPPED`` to prevent ``ReferenceError: expect is not defined``.
Args:
source: Test source code that may contain expect() calls.
Returns:
Source with expect() calls converted to assert equivalents.
"""
if "expect(" not in source:
return source
lines = source.split("\n")
converted: list[str] = []
for line in lines:
converted_line = _convert_expect_line(line)
converted.append(converted_line)
return "\n".join(converted)
# Patterns mapping (chain_suffix → conversion_type)
# "simple" = assert.func(actual, value), "ok_cmp" = assert.ok(actual OP value)
# "ok_method" = assert.ok(actual.method(value)), "type" = assert.ok(typeof actual === ...)
# "truthy" = assert.ok(actual) / assert.strictEqual(actual, bool)
# "throws" = assert.throws, "noop" = assert.ok(actual !== undefined)
_EXPECT_CHAIN_MAP: list[tuple[str, str, str | None]] = [
# Jest patterns (most common)
(".toBe(", "simple_strictEqual", None),
(".toEqual(", "simple_deepStrictEqual", None),
(".toStrictEqual(", "simple_deepStrictEqual", None),
(".toBeGreaterThan(", "ok_gt", None),
(".toBeGreaterThanOrEqual(", "ok_gte", None),
(".toBeLessThan(", "ok_lt", None),
(".toBeLessThanOrEqual(", "ok_lte", None),
(".toContain(", "ok_includes", None),
(".toHaveLength(", "ok_length", None),
(".toBeNull(", "null_check", None),
(".toBeUndefined(", "undef_check", None),
(".toBeTruthy(", "truthy", None),
(".toBeFalsy(", "falsy", None),
(".toThrow(", "throws", None),
(".toMatch(", "ok_match", None),
# Chai .to. patterns
(".to.equal(", "simple_strictEqual", None),
(".to.eql(", "simple_deepStrictEqual", None),
(".to.deep.equal(", "simple_deepStrictEqual", None),
(".to.be.greaterThan(", "ok_gt", None),
(".to.be.lessThan(", "ok_lt", None),
(".to.be.above(", "ok_gt", None),
(".to.be.below(", "ok_lt", None),
(".to.be.at.least(", "ok_gte", None),
(".to.be.at.most(", "ok_lte", None),
(".to.include(", "ok_includes", None),
(".to.contain(", "ok_includes", None),
(".to.not.include(", "ok_not_includes", None),
(".to.not.contain(", "ok_not_includes", None),
(".to.have.length(", "ok_length", None),
(".to.have.lengthOf(", "ok_length", None),
(".to.throw(", "throws", None),
(".to.match(", "ok_match", None),
(".to.be.a(", "noop", None),
(".to.be.an(", "noop", None),
(".to.be.instanceOf(", "noop", None),
(".to.be.instanceof(", "noop", None),
(".to.exist", "truthy_no_arg", None),
(".to.not.exist", "falsy_no_arg", None),
(".to.be.true", "true_no_arg", None),
(".to.be.false", "false_no_arg", None),
(".to.be.null", "null_no_arg", None),
(".to.be.undefined", "undef_no_arg", None),
(".to.be.ok", "truthy_no_arg", None),
(".to.not.be.ok", "falsy_no_arg", None),
]
def _convert_expect_line(line: str) -> str:
"""Convert a single line containing expect() to an assert equivalent."""
stripped = line.lstrip()
if "expect(" not in stripped:
return line
indent = line[: len(line) - len(stripped)]
expect_idx = line.find("expect(")
if expect_idx == -1:
return line
open_paren = expect_idx + len("expect")
close_paren = _find_matching_paren(line, open_paren)
if close_paren == -1:
# Multi-line expect or malformed — comment out to prevent ReferenceError
return f"{indent}// SKIPPED (unconvertible expect): {stripped}"
actual_expr = line[open_paren + 1 : close_paren]
rest = line[close_paren + 1 :].strip()
trailing_semi = ";" if rest.endswith(";") else ""
# Try each chain pattern
for chain_prefix, conversion_type, _ in _EXPECT_CHAIN_MAP:
if not rest.startswith(chain_prefix):
continue
# No-argument chains (e.g. .to.be.true, .to.exist)
if conversion_type.endswith("_no_arg"):
return _convert_no_arg(indent, actual_expr, conversion_type, trailing_semi)
# Extract the argument inside the chain's parentheses
chain_open = rest.find("(")
if chain_open == -1:
break
chain_close = _find_matching_paren(rest, chain_open)
if chain_close == -1:
break
value_expr = rest[chain_open + 1 : chain_close]
return _convert_with_arg(indent, actual_expr, value_expr, conversion_type, trailing_semi)
# Fallback: comment out unconvertible expect() to prevent ReferenceError
return f"{indent}// SKIPPED (unconvertible expect): {stripped}"
def _convert_no_arg(indent: str, actual: str, conversion_type: str, semi: str) -> str:
"""Convert expect patterns that take no argument (e.g., .to.be.true)."""
if conversion_type == "true_no_arg":
return f"{indent}assert.strictEqual({actual}, true){semi}"
if conversion_type == "false_no_arg":
return f"{indent}assert.strictEqual({actual}, false){semi}"
if conversion_type == "null_no_arg":
return f"{indent}assert.strictEqual({actual}, null){semi}"
if conversion_type == "undef_no_arg":
return f"{indent}assert.strictEqual({actual}, undefined){semi}"
if conversion_type == "truthy_no_arg":
return f"{indent}assert.ok({actual}){semi}"
if conversion_type == "falsy_no_arg":
return f"{indent}assert.ok(!({actual})){semi}"
return f"{indent}assert.ok({actual} !== undefined){semi}"
def _convert_with_arg(indent: str, actual: str, value: str, conversion_type: str, semi: str) -> str:
"""Convert expect patterns that take an argument."""
if conversion_type == "simple_strictEqual":
return f"{indent}assert.strictEqual({actual}, {value}){semi}"
if conversion_type == "simple_deepStrictEqual":
return f"{indent}assert.deepStrictEqual({actual}, {value}){semi}"
if conversion_type == "ok_gt":
return f"{indent}assert.ok(({actual}) > ({value})){semi}"
if conversion_type == "ok_gte":
return f"{indent}assert.ok(({actual}) >= ({value})){semi}"
if conversion_type == "ok_lt":
return f"{indent}assert.ok(({actual}) < ({value})){semi}"
if conversion_type == "ok_lte":
return f"{indent}assert.ok(({actual}) <= ({value})){semi}"
if conversion_type == "ok_includes":
return f"{indent}assert.ok(String({actual}).includes({value})){semi}"
if conversion_type == "ok_not_includes":
return f"{indent}assert.ok(!String({actual}).includes({value})){semi}"
if conversion_type == "ok_length":
return f"{indent}assert.strictEqual(({actual}).length, {value}){semi}"
if conversion_type == "ok_match":
return f"{indent}assert.match(String({actual}), {value}){semi}"
if conversion_type == "null_check":
return f"{indent}assert.strictEqual({actual}, null){semi}"
if conversion_type == "undef_check":
return f"{indent}assert.strictEqual({actual}, undefined){semi}"
if conversion_type == "truthy":
return f"{indent}assert.ok({actual}){semi}"
if conversion_type == "falsy":
return f"{indent}assert.ok(!({actual})){semi}"
if conversion_type == "throws":
if value:
return f"{indent}assert.throws(() => {{ {actual}; }}, {value}){semi}"
return f"{indent}assert.throws(() => {{ {actual}; }}){semi}"
# noop: type checks like .to.be.a('string') — just verify defined
if conversion_type == "noop":
return f"{indent}assert.ok({actual} !== undefined){semi}"
return f"{indent}assert.ok({actual} !== undefined){semi}"
# Author: ali <mohammed18200118@gmail.com>
def disable_ts_check(generated_tests: GeneratedTestsList) -> GeneratedTestsList:
"""Disable TypeScript type checking in all generated tests.

View file

@ -151,7 +151,9 @@ class JavaScriptFunctionOptimizer(FunctionOptimizer):
)
candidate_sqlite.unlink(missing_ok=True)
else:
match, diffs = compare_test_results(baseline_results.behavior_test_results, candidate_behavior_results)
match, diffs = compare_test_results(
baseline_results.behavior_test_results, candidate_behavior_results, pass_fail_only=True
)
return match, diffs
def should_skip_sqlite_cleanup(self, testing_type: TestingMode, optimization_iteration: int) -> bool:

View file

@ -755,12 +755,11 @@ def transform_expect_calls(
def inject_profiling_into_existing_js_test(
test_string: str,
test_path: Path,
call_positions: list[CodePosition],
function_to_optimize: FunctionToOptimize,
tests_project_root: Path,
mode: str = TestingMode.BEHAVIOR,
test_path: Path | None = None,
) -> tuple[bool, str | None]:
"""Inject profiling code into an existing JavaScript test file.
@ -768,7 +767,6 @@ def inject_profiling_into_existing_js_test(
to enable behavioral verification and performance benchmarking.
Args:
test_string: String contents of the test file.
test_path: Path to the test file.
call_positions: List of code positions where the function is called.
function_to_optimize: The function being optimized.
@ -779,7 +777,13 @@ def inject_profiling_into_existing_js_test(
Tuple of (success, instrumented_code).
"""
test_code = test_string
try:
with test_path.open(encoding="utf8") as f:
test_code = f.read()
except Exception as e:
logger.error(f"Failed to read test file {test_path}: {e}")
return False, None
# Get the relative path for test identification
try:
rel_path = test_path.relative_to(tests_project_root)

View file

@ -8,6 +8,7 @@ and converts the output to JUnit XML in Python, avoiding extra npm dependencies.
from __future__ import annotations
import json
import re
import subprocess
import time
from pathlib import Path
@ -86,7 +87,7 @@ def _ensure_runtime_files(project_root: Path) -> None:
logger.error(f"Could not install codeflash. Please install it manually: {' '.join(install_cmd)}")
def mocha_json_to_junit_xml(json_str: str, output_file: Path) -> None:
def mocha_json_to_junit_xml(json_str: str, output_file: Path, test_files: list[Path] | None = None) -> None:
"""Convert Mocha's JSON reporter output to JUnit XML.
Mocha JSON format:
@ -94,9 +95,16 @@ def mocha_json_to_junit_xml(json_str: str, output_file: Path) -> None:
Each test object has: fullTitle, title, duration, err, ...
Mocha's JSON reporter does NOT include a ``file`` field on test objects,
so we accept the known ``test_files`` list from the caller and set the
``file`` attribute on testcase/testsuite elements. This allows
``parse_jest_test_xml()`` to resolve the test file via its ``file``
attribute lookup path.
Args:
json_str: JSON string from Mocha's --reporter json output.
output_file: Path to write the JUnit XML file.
test_files: Optional list of test file paths that were passed to Mocha.
"""
try:
@ -125,11 +133,35 @@ def mocha_json_to_junit_xml(json_str: str, output_file: Path) -> None:
suite_name = suite_name or "root"
suites.setdefault(suite_name, []).append(test)
# Build a mapping from describe block names to file paths by reading test files.
# Each generated test file wraps tests in describe('functionName', ...) so we
# can map suite names back to their source file.
suite_to_file: dict[str, str] = {}
if test_files:
for tf in test_files:
suite_to_file[tf.name] = str(tf)
# Try to extract the top-level describe name from the file content
try:
content = tf.read_text(encoding="utf-8")
m = re.search(r"describe\s*\(\s*['\"]([^'\"]+)['\"]", content)
if m:
suite_to_file[m.group(1)] = str(tf)
except Exception:
pass
# Fallback: if we have test files, use the first one as default for any unmatched suites
default_file = str(test_files[0]) if test_files else ""
for suite_name, suite_tests in suites.items():
testsuite = SubElement(testsuites, "testsuite")
testsuite.set("name", suite_name)
testsuite.set("tests", str(len(suite_tests)))
# Resolve file path: try suite name match, then use default
resolved_file = suite_to_file.get(suite_name, default_file)
if resolved_file:
testsuite.set("file", resolved_file)
suite_failures = 0
suite_time = 0.0
@ -140,6 +172,9 @@ def mocha_json_to_junit_xml(json_str: str, output_file: Path) -> None:
duration_ms = test.get("duration", 0) or 0
duration_s = duration_ms / 1000.0
testcase.set("time", str(duration_s))
if resolved_file:
testcase.set("file", resolved_file)
suite_time += duration_s
err = test.get("err", {})
@ -292,6 +327,7 @@ def _run_mocha_and_convert(
result_file_path: Path,
subprocess_timeout: int,
label: str,
test_files: list[Path] | None = None,
) -> subprocess.CompletedProcess:
"""Run Mocha subprocess, extract JSON output, and convert to JUnit XML.
@ -302,6 +338,7 @@ def _run_mocha_and_convert(
result_file_path: Path to write JUnit XML.
subprocess_timeout: Timeout in seconds.
label: Label for log messages (e.g. "behavioral", "benchmarking").
test_files: Test file paths passed to Mocha (for file attribute in XML).
Returns:
CompletedProcess with combined stdout/stderr.
@ -343,7 +380,7 @@ def _run_mocha_and_convert(
if result.stdout:
mocha_json = _extract_mocha_json(result.stdout)
if mocha_json:
mocha_json_to_junit_xml(mocha_json, result_file_path)
mocha_json_to_junit_xml(mocha_json, result_file_path, test_files=test_files)
logger.debug(f"Converted Mocha JSON to JUnit XML: {result_file_path}")
else:
logger.warning(f"Could not extract Mocha JSON from stdout (len={len(result.stdout)})")
@ -414,6 +451,7 @@ def run_mocha_behavioral_tests(
result_file_path=result_file_path,
subprocess_timeout=subprocess_timeout,
label="behavioral",
test_files=test_files,
)
finally:
wall_clock_ns = time.perf_counter_ns() - start_time_ns
@ -515,6 +553,7 @@ def run_mocha_benchmarking_tests(
result_file_path=result_file_path,
subprocess_timeout=total_timeout,
label="benchmarking",
test_files=test_files,
)
finally:
wall_clock_seconds = time.time() - total_start_time
@ -589,6 +628,7 @@ def run_mocha_line_profile_tests(
result_file_path=result_file_path,
subprocess_timeout=subprocess_timeout,
label="line_profile",
test_files=test_files,
)
finally:
wall_clock_ns = time.perf_counter_ns() - start_time_ns

View file

@ -172,22 +172,12 @@ def parse_jest_test_xml(
if global_stdout:
marker_count = len(jest_start_pattern.findall(global_stdout))
if marker_count > 0:
logger.debug(f"Found {marker_count} timing start markers in Jest stdout")
else:
logger.debug(f"No timing start markers found in Jest stdout (len={len(global_stdout)})")
# Check for END markers with duration (perf test markers)
logger.debug(f"Found {marker_count} timing start markers in subprocess stdout")
end_marker_count = len(jest_end_pattern.findall(global_stdout))
if end_marker_count > 0:
logger.debug(
f"[PERF-DEBUG] Found {end_marker_count} END timing markers with duration in Jest stdout"
)
# Sample a few markers to verify loop indices
end_samples = list(jest_end_pattern.finditer(global_stdout))[:5]
for sample in end_samples:
groups = sample.groups()
logger.debug(f"[PERF-DEBUG] Sample END marker: loopIndex={groups[3]}, duration={groups[5]}")
logger.debug(f"Found {end_marker_count} END timing markers with duration in subprocess stdout")
else:
logger.debug("[PERF-DEBUG] No END markers with duration found in Jest stdout")
logger.debug(f"No END markers found in subprocess stdout (len={len(global_stdout)})")
except (AttributeError, UnicodeDecodeError):
global_stdout = ""
@ -215,6 +205,10 @@ def parse_jest_test_xml(
# Key: (testName, testName2, funcName, loopIndex, lineId)
key = match.groups()[:5]
end_matches_dict[key] = match
logger.debug(
f"Suite {suite_count}: combined_stdout len={len(combined_stdout)}, "
f"start_matches={len(start_matches)}, end_matches={len(end_matches_dict)}"
)
# Debug: log suite-level END marker parsing for perf tests
if end_matches_dict:
@ -371,6 +365,25 @@ def parse_jest_test_xml(
# end_key is (module, testName, funcName, loopIndex, invocationId)
if len(end_key) >= 2 and sanitized_test_name in end_key[1]:
matching_ends_direct.append(end_match)
# Fallback: If no matches found but END markers exist with "unknown" test name
# (happens in Vitest where beforeEach hook doesn't fire to set currentTestName),
# match ALL "unknown" markers to this testcase, removing them from the dict so
# the same marker isn't assigned to multiple testcases.
if not matching_ends_direct and end_matches_dict:
unknown_markers = [(k, m) for k, m in end_matches_dict.items() if len(k) >= 2 and k[1] == "unknown"]
if unknown_markers:
# Assign all unconsumed unknown markers to this testcase
for _, end_match in unknown_markers:
matching_ends_direct.append(end_match)
# Remove consumed markers so they aren't double-assigned to other testcases
for end_key, _ in unknown_markers:
end_matches_dict.pop(end_key, None)
logger.debug(
f"[PERF-UNKNOWN-MATCH] Testcase '{test_name[:40]}': matched {len(matching_ends_direct)} "
f"'unknown' END markers (Vitest fallback)"
)
# Debug: log matching results for perf tests
if matching_ends_direct:
loop_indices = [int(m.groups()[3]) if m.groups()[3].isdigit() else 1 for m in matching_ends_direct]

View file

@ -2256,25 +2256,20 @@ class JavaScriptSupport:
def instrument_existing_test(
self,
test_string: str,
test_path: Path,
call_positions: Sequence[Any],
function_to_optimize: Any,
tests_project_root: Path,
mode: str,
test_path: Path | None,
) -> tuple[bool, str | None]:
"""Inject profiling code into an existing JavaScript test file.
Wraps function calls with codeflash.capture() or codeflash.capturePerf()
for behavioral verification and performance benchmarking.
Args:
test_string: The test source code string.
test_path: Path to the test file.
call_positions: List of code positions where the function is called.
function_to_optimize: The function being optimized.
tests_project_root: Root directory of tests.
mode: Testing mode - "behavior" or "performance".
test_path: Path to the test file.
Returns:
Tuple of (success, instrumented_code).
@ -2282,6 +2277,7 @@ class JavaScriptSupport:
"""
from codeflash.languages.javascript.instrument import inject_profiling_into_existing_js_test
test_string = test_path.read_text(encoding="utf-8")
return inject_profiling_into_existing_js_test(
test_string=test_string,
test_path=test_path,

View file

@ -228,6 +228,8 @@ export default mergeConfig(originalConfig, {{
test: {{
// Override include pattern to match all test files including generated ones
include: ['**/*.test.ts', '**/*.test.js', '**/*.test.tsx', '**/*.test.jsx'],
// Use forks pool so timing markers from process.stdout.write flow to parent stdout
pool: 'forks',
}},
}});
"""
@ -242,6 +244,8 @@ export default defineConfig({
include: ['**/*.test.ts', '**/*.test.js', '**/*.test.tsx', '**/*.test.jsx'],
// Exclude common non-test directories
exclude: ['**/node_modules/**', '**/dist/**'],
// Use forks pool so timing markers from process.stdout.write flow to parent stdout
pool: 'forks',
},
});
"""
@ -280,6 +284,7 @@ def _build_vitest_behavioral_command(
"--reporter=default",
"--reporter=junit",
"--no-file-parallelism", # Serial execution for deterministic timing
"--pool=forks", # Use child processes so timing markers flow to parent stdout
]
# For monorepos with restrictive vitest configs (e.g., include: test/**/*.test.ts),
@ -329,6 +334,7 @@ def _build_vitest_benchmarking_command(
"--reporter=default",
"--reporter=junit",
"--no-file-parallelism", # Serial execution for consistent benchmarking
"--pool=forks", # Use child processes so timing markers flow to parent stdout
]
# Use codeflash vitest config to override restrictive include patterns
@ -659,11 +665,6 @@ def run_vitest_benchmarking_tests(
)
else:
logger.debug(f"[VITEST-BENCH] No perf END markers found in stdout (len={len(result.stdout)})")
# Check if there are behavior END markers instead
behavior_end_pattern = re.compile(r"!######[^:]+:[^:]+:[^:]+:\d+:[^#]+######!")
behavior_matches = list(behavior_end_pattern.finditer(result.stdout))
if behavior_matches:
logger.debug(f"[VITEST-BENCH] Found {len(behavior_matches)} behavior END markers instead (no duration)")
return result_file_path, result
@ -722,6 +723,7 @@ def run_vitest_line_profile_tests(
"--reporter=default",
"--reporter=junit",
"--no-file-parallelism", # Serial execution for consistent line profiling
"--pool=forks", # Use child processes so timing markers flow to parent stdout
]
# Use codeflash vitest config to override restrictive include patterns

View file

@ -749,6 +749,7 @@ def detect_unused_helper_functions(
"""
# Skip this analysis for non-Python languages since we use Python's ast module
if current_language() != Language.PYTHON:
logger.debug("Skipping unused helper function detection for non-Python languages")
return []
if isinstance(optimized_code, CodeStringsMarkdown) and len(optimized_code.code_strings) > 0:

View file

@ -127,9 +127,9 @@ class PythonFunctionOptimizer(FunctionOptimizer):
def parse_line_profile_test_results(
self, line_profiler_output_file: Path | None
) -> tuple[TestResults | dict, CoverageData | None]:
if line_profiler_output_file is None:
return {"timings": {}, "unit": 0, "str_out": ""}, None
return self.language_support.parse_line_profile_results(line_profiler_output_file), None
from codeflash.verification.parse_line_profile_test_output import parse_line_profile_results
return parse_line_profile_results(line_profiler_output_file=line_profiler_output_file)
def compare_candidate_results(
self,

View file

@ -918,22 +918,20 @@ class PythonSupport:
def instrument_existing_test(
self,
test_string: str,
test_path: Path,
call_positions: Sequence[Any],
function_to_optimize: Any,
tests_project_root: Path,
mode: str,
test_path: Path | None,
) -> tuple[bool, str | None]:
"""Inject profiling code into an existing Python test file.
Args:
test_string: The test file content as a string.
test_path: Path to the test file.
call_positions: List of code positions where the function is called.
function_to_optimize: The function being optimized.
tests_project_root: Root directory of tests.
mode: Testing mode - "behavior" or "performance".
test_path: Path to the test file.
Returns:
Tuple of (success, instrumented_code).
@ -945,7 +943,6 @@ class PythonSupport:
testing_mode = TestingMode.BEHAVIOR if mode == "behavior" else TestingMode.PERFORMANCE
return inject_profiling_into_existing_test(
test_string=test_string,
test_path=test_path,
call_positions=list(call_positions),
function_to_optimize=function_to_optimize,

View file

@ -330,11 +330,12 @@ class CodeStringsMarkdown(BaseModel):
dict[str, str]: Mapping from file path (as string) to code.
"""
if "file_to_path" in self._cache:
if self._cache.get("file_to_path") is not None:
return self._cache["file_to_path"]
result = {str(code_string.file_path): code_string.code for code_string in self.code_strings}
self._cache["file_to_path"] = result
return result
self._cache["file_to_path"] = {
str(code_string.file_path): code_string.code for code_string in self.code_strings
}
return self._cache["file_to_path"]
@staticmethod
def parse_markdown_code(markdown_code: str, expected_language: str = "python") -> CodeStringsMarkdown:
@ -663,7 +664,7 @@ class CoverageData:
from rich.tree import Tree
tree = Tree("Test Coverage Results")
tree.add(f"Main Function: {self.main_func_coverage.name}: {self.main_func_coverage.coverage:.2f}%")
tree.add(f"Main Function: {self.main_func_coverage.name}: {self.coverage:.2f}%")
if self.dependent_func_coverage:
tree.add(
f"Dependent Function: {self.dependent_func_coverage.name}: {self.dependent_func_coverage.coverage:.2f}%"

View file

@ -539,9 +539,6 @@ class FunctionOptimizer:
) -> tuple[TestResults | dict, CoverageData | None]:
return TestResults(test_results=[]), None
def fixup_generated_tests(self, generated_tests: GeneratedTestsList) -> GeneratedTestsList:
return generated_tests
# --- End hooks ---
def can_be_optimized(self) -> Result[tuple[bool, CodeOptimizationContext, dict[Path, str]], str]:
@ -644,8 +641,6 @@ class FunctionOptimizer:
source_file_path=self.function_to_optimize.file_path,
)
generated_tests = self.fixup_generated_tests(generated_tests)
logger.debug(f"[PIPELINE] Processing {count_tests} generated tests")
for i, generated_test in enumerate(generated_tests.generated_tests):
logger.debug(
@ -1601,27 +1596,24 @@ class FunctionOptimizer:
for (test_file, test_type), tests_in_file_list in test_file_invocation_positions.items():
path_obj_test_file = Path(test_file)
test_string = path_obj_test_file.read_text(encoding="utf-8")
# Use language-specific instrumentation
success, injected_behavior_test = self.language_support.instrument_existing_test(
test_string=test_string,
test_path=path_obj_test_file,
call_positions=[test.position for test in tests_in_file_list],
function_to_optimize=self.function_to_optimize,
tests_project_root=self.test_cfg.tests_project_rootdir,
mode="behavior",
test_path=path_obj_test_file,
)
if not success:
logger.debug(f"Failed to instrument test file {test_file} for behavior testing")
continue
success, injected_perf_test = self.language_support.instrument_existing_test(
test_string=test_string,
test_path=path_obj_test_file,
call_positions=[test.position for test in tests_in_file_list],
function_to_optimize=self.function_to_optimize,
tests_project_root=self.test_cfg.tests_project_rootdir,
mode="performance",
test_path=path_obj_test_file,
)
if not success:
logger.debug(f"Failed to instrument test file {test_file} for performance testing")

View file

@ -11,7 +11,6 @@ from codeflash.code_utils.config_consts import (
MIN_TESTCASE_PASSED_THRESHOLD,
MIN_THROUGHPUT_IMPROVEMENT_THRESHOLD,
)
from codeflash.models.models import CoverageStatus
from codeflash.models.test_type import TestType
if TYPE_CHECKING:
@ -205,19 +204,7 @@ def quantity_of_tests_critic(candidate_result: OptimizedCandidateResult | Origin
def coverage_critic(original_code_coverage: CoverageData | None) -> bool:
"""Check if the coverage meets the threshold.
Returns True when:
- Coverage data exists, was parsed successfully, and meets the threshold, OR
- No coverage data is available (skip the check for languages/projects without coverage support), OR
- Coverage data exists but was NOT_FOUND (e.g., JaCoCo report not generated in multi-module projects)
"""
"""Check if the coverage meets the threshold."""
if original_code_coverage:
# If coverage data was not found (e.g., JaCoCo report doesn't exist in multi-module projects),
# skip the coverage check instead of failing with 0% coverage
if original_code_coverage.status == CoverageStatus.NOT_FOUND:
return True
return original_code_coverage.coverage >= COVERAGE_THRESHOLD
# When no coverage data is available (e.g., JavaScript, Java multi-module projects),
# skip the coverage check and allow optimization to proceed
return True
return False

View file

@ -27,7 +27,9 @@ def safe_repr(obj: object) -> str:
return f"<repr failed: {type(e).__name__}: {e}>"
def compare_test_results(original_results: TestResults, candidate_results: TestResults) -> tuple[bool, list[TestDiff]]:
def compare_test_results(
original_results: TestResults, candidate_results: TestResults, pass_fail_only: bool = False
) -> tuple[bool, list[TestDiff]]:
# This is meant to be only called with test results for the first loop index
if len(original_results) == 0 or len(candidate_results) == 0:
return False, [] # empty test results are not equal
@ -100,7 +102,9 @@ def compare_test_results(original_results: TestResults, candidate_results: TestR
)
)
elif not comparator(original_test_result.return_value, cdd_test_result.return_value, superset_obj=superset_obj):
elif not pass_fail_only and not comparator(
original_test_result.return_value, cdd_test_result.return_value, superset_obj=superset_obj
):
test_diffs.append(
TestDiff(
scope=TestDiffScope.RETURN_VALUE,
@ -125,8 +129,10 @@ def compare_test_results(original_results: TestResults, candidate_results: TestR
)
except Exception as e:
logger.error(e)
elif (original_test_result.stdout and cdd_test_result.stdout) and not comparator(
original_test_result.stdout, cdd_test_result.stdout
elif (
not pass_fail_only
and (original_test_result.stdout and cdd_test_result.stdout)
and not comparator(original_test_result.stdout, cdd_test_result.stdout)
):
test_diffs.append(
TestDiff(
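A hedged usage sketch of the new parameter (the result objects are placeholders; the same call appears in the JavaScript optimizer hunk earlier in this commit): with pass_fail_only=True, only pass/fail mismatches produce TestDiffs, and the return-value and stdout comparisons above are skipped.

match, diffs = compare_test_results(
    original_results=baseline.behavior_test_results,    # placeholder baseline results
    candidate_results=candidate_behavior_results,       # placeholder candidate results
    pass_fail_only=True,
)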

View file

@ -266,10 +266,96 @@ if (RANDOM_SEED !== 0) {
}
}
// Current test context (set by Jest hooks)
// Current test context (set by Jest hooks or Vitest worker)
let currentTestName = null;
let currentTestPath = null; // Test file path from Jest
/**
* Get the current test name from Vitest's internal worker API.
* Vitest doesn't inject beforeEach as a global, so the Jest-style hook doesn't fire.
* Instead, we query Vitest's worker directly for the current test name.
*
* In Vitest's fork pool, `__vitest_worker__.current` is the current task object
* with properties: name, fullName, fullTestName, suite, type, file, etc.
* Also, `expect.getState().currentTestName` works from within a test.
*
* @returns {string|null} The current test name, or null if not in Vitest
*/
function getVitestTestName() {
// Prefer expect.getState().currentTestName — returns full path including describe blocks
// e.g., "Performance tests > should return true for basic HTML tags"
// This matches what Jest's beforeEach hook would set.
try {
if (typeof expect !== 'undefined' && expect.getState) {
const state = expect.getState();
if (state?.currentTestName) {
return state.currentTestName;
}
}
} catch (e) {
// expect not available
}
// Fallback: Vitest worker API — worker.current.fullTestName includes describe path
try {
const worker = globalThis.__vitest_worker__;
if (worker?.current?.fullTestName) {
return worker.current.fullTestName;
}
if (worker?.current?.fullName) {
return worker.current.fullName;
}
if (worker?.current?.name) {
return worker.current.name;
}
} catch (e) {
// Not in Vitest context
}
return null;
}
/**
* Get the current test file path from Vitest's internal worker API.
* @returns {string|null} The current test file path, or null if not in Vitest
*/
function getVitestTestPath() {
try {
const worker = globalThis.__vitest_worker__;
if (worker?.filepath) {
return worker.filepath;
}
} catch (e) {
// Not in Vitest context
}
// Fallback: try expect.getState() for testPath
try {
if (typeof expect !== 'undefined' && expect.getState) {
const state = expect.getState();
if (state?.testPath) {
return state.testPath;
}
}
} catch (e) {
// expect not available
}
return null;
}
/**
* Get the effective test name, trying Jest hooks first, then Vitest API, then fallback.
* @returns {string} The current test name
*/
function getEffectiveTestName() {
return currentTestName || getVitestTestName() || 'unknown';
}
/**
* Get the effective test path, trying Jest hooks first, then Vitest API, then fallback.
* @returns {string|null} The current test file path
*/
function getEffectiveTestPath() {
return currentTestPath || getVitestTestPath() || null;
}
// Invocation counter map: tracks how many times each testId has been seen
// Key: testId (testModule:testClass:testFunction:lineId:loopIndex)
// Value: count (starts at 0, increments each time same key is seen)
@ -549,13 +635,14 @@ function capture(funcName, lineId, fn, ...args) {
// Get test context (raw values for SQLite storage)
// Use TEST_MODULE env var if set, otherwise derive from test file path
const effectiveTestPath = getEffectiveTestPath();
let testModulePath;
if (TEST_MODULE) {
testModulePath = TEST_MODULE;
} else if (currentTestPath) {
} else if (effectiveTestPath) {
// Get relative path from cwd and convert to module-style path
const path = require('path');
const relativePath = path.relative(process.cwd(), currentTestPath);
const relativePath = path.relative(process.cwd(), effectiveTestPath);
// Convert to Python module-style path (e.g., "tests/test_foo.test.js" -> "tests.test_foo.test")
// This matches what Jest's junit XML produces
testModulePath = relativePath
@ -564,10 +651,10 @@ function capture(funcName, lineId, fn, ...args) {
.replace(/\.test$/, '.test') // Keep .test suffix
.replace(/\//g, '.'); // Convert path separators to dots
} else {
testModulePath = currentTestName || 'unknown';
testModulePath = getEffectiveTestName();
}
const testClassName = null; // Jest doesn't use classes like Python
const testFunctionName = currentTestName || 'unknown';
const testFunctionName = getEffectiveTestName();
// Sanitized versions for stdout tags (avoid regex conflicts)
const safeModulePath = sanitizeTestId(testModulePath);
@ -583,8 +670,8 @@ function capture(funcName, lineId, fn, ...args) {
// Format stdout tag (matches Python format, uses sanitized names)
const testStdoutTag = `${safeModulePath}:${testClassName ? testClassName + '.' : ''}${safeTestFunctionName}:${funcName}:${LOOP_INDEX}:${invocationId}`;
// Print start tag
console.log(`!$######${testStdoutTag}######$!`);
// Print start tag (use process.stdout.write to bypass test framework console interception)
process.stdout.write(`!$######${testStdoutTag}######$!\n`);
// Timing with nanosecond precision
const startTime = getTimeNs();
@ -602,14 +689,14 @@ function capture(funcName, lineId, fn, ...args) {
const durationNs = getDurationNs(startTime, endTime);
recordResult(testModulePath, testClassName, testFunctionName, funcName, invocationId, args, resolved, null, durationNs);
// Print end tag (no duration for behavior mode)
console.log(`!######${testStdoutTag}######!`);
process.stdout.write(`!######${testStdoutTag}######!\n`);
return resolved;
},
(err) => {
const endTime = getTimeNs();
const durationNs = getDurationNs(startTime, endTime);
recordResult(testModulePath, testClassName, testFunctionName, funcName, invocationId, args, null, err, durationNs);
console.log(`!######${testStdoutTag}######!`);
process.stdout.write(`!######${testStdoutTag}######!\n`);
throw err;
}
);
@ -623,7 +710,7 @@ function capture(funcName, lineId, fn, ...args) {
recordResult(testModulePath, testClassName, testFunctionName, funcName, invocationId, args, returnValue, error, durationNs);
// Print end tag (no duration for behavior mode, matching Python)
console.log(`!######${testStdoutTag}######!`);
process.stdout.write(`!######${testStdoutTag}######!\n`);
if (error) throw error;
return returnValue;
@ -656,22 +743,24 @@ function capturePerf(funcName, lineId, fn, ...args) {
const shouldLoop = getPerfLoopCount() > 1 && !checkSharedTimeLimit();
// Get test context (computed once, reused across batch)
// Uses Vitest worker API as fallback when Jest-style beforeEach hook doesn't fire
const effectiveTestPath = getEffectiveTestPath();
let testModulePath;
if (TEST_MODULE) {
testModulePath = TEST_MODULE;
} else if (currentTestPath) {
} else if (effectiveTestPath) {
const path = require('path');
const relativePath = path.relative(process.cwd(), currentTestPath);
const relativePath = path.relative(process.cwd(), effectiveTestPath);
testModulePath = relativePath
.replace(/\\/g, '/')
.replace(/\.js$/, '')
.replace(/\.test$/, '.test')
.replace(/\//g, '.');
} else {
testModulePath = currentTestName || 'unknown';
testModulePath = getEffectiveTestName();
}
const testClassName = null;
const testFunctionName = currentTestName || 'unknown';
const testFunctionName = getEffectiveTestName();
const safeModulePath = sanitizeTestId(testModulePath);
const safeTestFunctionName = sanitizeTestId(testFunctionName);
@ -767,8 +856,8 @@ function capturePerf(funcName, lineId, fn, ...args) {
lastError = e;
}
// Print end tag with timing
console.log(`!######${testStdoutTag}:${durationNs}######!`);
// Print end tag with timing (use process.stdout.write to bypass test framework console interception)
process.stdout.write(`!######${testStdoutTag}:${durationNs}######!\n`);
// Update shared loop counter
sharedPerfState.totalLoopsCompleted++;
@ -808,7 +897,7 @@ function capturePerf(funcName, lineId, fn, ...args) {
* @private
*/
function _recordAsyncTiming(startTime, testStdoutTag, durationNs, runtimes) {
console.log(`!######${testStdoutTag}:${durationNs}######!`);
process.stdout.write(`!######${testStdoutTag}:${durationNs}######!\n`);
sharedPerfState.totalLoopsCompleted++;
if (durationNs > 0) {
runtimes.push(durationNs / 1000);
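The markers written above with process.stdout.write are scanned back out of the captured subprocess stdout by parse_jest_test_xml on the Python side. A hedged illustration of that matching (the regex below only approximates the perf END marker layout shown in capturePerf and is not the actual jest_end_pattern; the marker string is a placeholder):

import re

# Approximate perf END marker shape, per the testStdoutTag format above:
#   !######<module>:<testName>:<funcName>:<loopIndex>:<invocationId>:<durationNs>######!
perf_end = re.compile(r"!######([^:#]+):([^:#]+):([^:#]+):(\d+):([^:#]+):(\d+)######!")

captured_stdout = "!######tests.sorter.test:sorts numbers:sorter:1:1_0:123456######!\n"
for m in perf_end.finditer(captured_stdout):
    loop_index, duration_ns = int(m.group(4)), int(m.group(6))
    print(loop_index, duration_ns)  # -> 1 123456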

View file

@ -141,24 +141,22 @@ def run_codeflash_command(
def build_command(
cwd: pathlib.Path, config: TestConfig, test_root: pathlib.Path, benchmarks_root: pathlib.Path | None = None
) -> list[str]:
repo_root = pathlib.Path(__file__).parent.parent.parent
python_path = os.path.relpath(repo_root / "codeflash" / "main.py", cwd)
python_path = "../../../codeflash/main.py" if "code_directories" in str(cwd) else "../codeflash/main.py"
base_command = ["uv", "run", "--no-project", python_path, "--file", config.file_path, "--no-pr"]
if config.function_name:
base_command.extend(["--function", config.function_name])
# Check if config exists (pyproject.toml or codeflash.toml) - if so, don't override it
has_codeflash_config = (cwd / "codeflash.toml").exists()
if not has_codeflash_config:
pyproject_path = cwd / "pyproject.toml"
if pyproject_path.exists():
with contextlib.suppress(Exception), open(pyproject_path, "rb") as f:
pyproject_data = tomllib.load(f)
has_codeflash_config = "tool" in pyproject_data and "codeflash" in pyproject_data["tool"]
# Check if pyproject.toml exists with codeflash config - if so, don't override it
pyproject_path = cwd / "pyproject.toml"
has_codeflash_config = False
if pyproject_path.exists():
with contextlib.suppress(Exception), open(pyproject_path, "rb") as f:
pyproject_data = tomllib.load(f)
has_codeflash_config = "tool" in pyproject_data and "codeflash" in pyproject_data["tool"]
# Only pass --tests-root and --module-root if they're not configured in config files
# Only pass --tests-root and --module-root if they're not configured in pyproject.toml
if not has_codeflash_config:
base_command.extend(["--tests-root", str(test_root), "--module-root", str(cwd)])

View file

@ -116,7 +116,7 @@ def test_sort():
func = FunctionToOptimize(function_name="sorter", parents=[], file_path=Path(fto_path))
os.chdir(run_cwd)
success, new_test = inject_profiling_into_existing_test(
code, test_path, [CodePosition(6, 13), CodePosition(10, 13)], func, project_root_path, mode=TestingMode.BEHAVIOR
test_path, [CodePosition(6, 13), CodePosition(10, 13)], func, project_root_path, mode=TestingMode.BEHAVIOR
)
os.chdir(original_cwd)
assert success
@ -287,7 +287,7 @@ def test_sort():
tmp_test_path.write_text(code, encoding="utf-8")
success, new_test = inject_profiling_into_existing_test(
code, tmp_test_path, [CodePosition(7, 13), CodePosition(12, 13)], fto, tmp_test_path.parent
tmp_test_path, [CodePosition(7, 13), CodePosition(12, 13)], fto, tmp_test_path.parent
)
assert success
assert new_test.replace('"', "'") == expected.format(
@ -557,7 +557,7 @@ def test_sort():
tmp_test_path.write_text(code, encoding="utf-8")
success, new_test = inject_profiling_into_existing_test(
code, tmp_test_path, [CodePosition(6, 13), CodePosition(10, 13)], fto, tmp_test_path.parent
tmp_test_path, [CodePosition(6, 13), CodePosition(10, 13)], fto, tmp_test_path.parent
)
assert success
assert new_test.replace('"', "'") == expected.format(
@ -728,7 +728,7 @@ def test_sort():
tmp_test_path.write_text(code, encoding="utf-8")
success, new_test = inject_profiling_into_existing_test(
code, tmp_test_path, [CodePosition(6, 13), CodePosition(10, 13)], fto, tmp_test_path.parent
tmp_test_path, [CodePosition(6, 13), CodePosition(10, 13)], fto, tmp_test_path.parent
)
assert success
assert new_test.replace('"', "'") == expected.format(

View file

@ -299,7 +299,7 @@ async def test_async_function():
assert (temp_dir / ASYNC_HELPER_FILENAME).exists()
success, instrumented_test_code = inject_profiling_into_existing_test(
async_test_code, test_file, [CodePosition(8, 18), CodePosition(11, 19)], func, temp_dir, mode=TestingMode.BEHAVIOR
test_file, [CodePosition(8, 18), CodePosition(11, 19)], func, temp_dir, mode=TestingMode.BEHAVIOR
)
# For async functions, once source is decorated, test injection should fail
@ -362,7 +362,7 @@ async def test_async_function():
# Now test the full pipeline with source module path
success, instrumented_test_code = inject_profiling_into_existing_test(
async_test_code, test_file, [CodePosition(8, 18)], func, temp_dir, mode=TestingMode.PERFORMANCE
test_file, [CodePosition(8, 18)], func, temp_dir, mode=TestingMode.PERFORMANCE
)
# For async functions, once source is decorated, test injection should fail
@ -431,7 +431,7 @@ async def test_mixed_functions():
assert (temp_dir / ASYNC_HELPER_FILENAME).exists()
success, instrumented_test_code = inject_profiling_into_existing_test(
mixed_test_code, test_file, [CodePosition(8, 18), CodePosition(11, 19)], async_func, temp_dir, mode=TestingMode.BEHAVIOR
test_file, [CodePosition(8, 18), CodePosition(11, 19)], async_func, temp_dir, mode=TestingMode.BEHAVIOR
)
# Async functions should not be instrumented at the test level
@ -605,7 +605,7 @@ async def test_multiple_calls():
assert len(call_positions) == 4
success, instrumented_test_code = inject_profiling_into_existing_test(
test_code_multiple_calls, test_file, call_positions, func, temp_dir, mode=TestingMode.BEHAVIOR
test_file, call_positions, func, temp_dir, mode=TestingMode.BEHAVIOR
)
assert success

View file

@ -15,9 +15,8 @@ from codeflash.code_utils.instrument_existing_tests import (
FunctionImportedAsVisitor,
inject_profiling_into_existing_test,
)
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.python.function_optimizer import PythonFunctionOptimizer
from codeflash.languages.python.static_analysis.line_profile_utils import add_decorator_imports
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import (
CodeOptimizationContext,
CodePosition,
@ -28,6 +27,7 @@ from codeflash.models.models import (
TestsInFile,
TestType,
)
from codeflash.languages.python.function_optimizer import PythonFunctionOptimizer
from codeflash.verification.verification_utils import TestConfig
codeflash_wrap_string = """def codeflash_wrap(codeflash_wrapped, codeflash_test_module_name, codeflash_test_class_name, codeflash_test_name, codeflash_function_name, codeflash_line_id, codeflash_loop_index, codeflash_cur, codeflash_con, *args, **kwargs):
@ -194,7 +194,7 @@ import dill as pickle"""
run_cwd = Path(__file__).parent.parent.resolve()
os.chdir(run_cwd)
success, new_test = inject_profiling_into_existing_test(
code, Path(f.name), [CodePosition(9, 17), CodePosition(13, 17), CodePosition(17, 17)], func, Path(f.name).parent
Path(f.name), [CodePosition(9, 17), CodePosition(13, 17), CodePosition(17, 17)], func, Path(f.name).parent
)
os.chdir(original_cwd)
assert success
@ -293,7 +293,7 @@ def test_prepare_image_for_yolo():
run_cwd = Path(__file__).parent.parent.resolve()
os.chdir(run_cwd)
success, new_test = inject_profiling_into_existing_test(
code, Path(f.name), [CodePosition(10, 14)], func, Path(f.name).parent
Path(f.name), [CodePosition(10, 14)], func, Path(f.name).parent
)
os.chdir(original_cwd)
assert success
@ -398,7 +398,7 @@ def test_sort():
func = FunctionToOptimize(function_name="sorter", parents=[], file_path=code_path)
os.chdir(run_cwd)
success, new_test = inject_profiling_into_existing_test(
code, test_path, [CodePosition(8, 14), CodePosition(12, 14)], func, project_root_path, mode=TestingMode.BEHAVIOR
test_path, [CodePosition(8, 14), CodePosition(12, 14)], func, project_root_path, mode=TestingMode.BEHAVIOR
)
os.chdir(original_cwd)
assert success
@ -409,7 +409,7 @@ def test_sort():
).replace('"', "'")
success, new_perf_test = inject_profiling_into_existing_test(
code, test_path,
test_path,
[CodePosition(8, 14), CodePosition(12, 14)],
func,
project_root_path,
@ -650,11 +650,11 @@ def test_sort_parametrized(input, expected_output):
func = FunctionToOptimize(function_name="sorter", parents=[], file_path=code_path)
os.chdir(run_cwd)
success, new_test = inject_profiling_into_existing_test(
code, test_path, [CodePosition(14, 13)], func, project_root_path, mode=TestingMode.BEHAVIOR
test_path, [CodePosition(14, 13)], func, project_root_path, mode=TestingMode.BEHAVIOR
)
assert success
success, new_test_perf = inject_profiling_into_existing_test(
code, test_path, [CodePosition(14, 13)], func, project_root_path, mode=TestingMode.PERFORMANCE
test_path, [CodePosition(14, 13)], func, project_root_path, mode=TestingMode.PERFORMANCE
)
os.chdir(original_cwd)
@ -927,11 +927,11 @@ def test_sort_parametrized_loop(input, expected_output):
func = FunctionToOptimize(function_name="sorter", parents=[], file_path=code_path)
os.chdir(run_cwd)
success, new_test = inject_profiling_into_existing_test(
code, test_path, [CodePosition(15, 17)], func, project_root_path, mode=TestingMode.BEHAVIOR
test_path, [CodePosition(15, 17)], func, project_root_path, mode=TestingMode.BEHAVIOR
)
assert success
success, new_test_perf = inject_profiling_into_existing_test(
code, test_path, [CodePosition(15, 17)], func, project_root_path, mode=TestingMode.PERFORMANCE
test_path, [CodePosition(15, 17)], func, project_root_path, mode=TestingMode.PERFORMANCE
)
os.chdir(original_cwd)
@ -1287,11 +1287,11 @@ def test_sort():
func = FunctionToOptimize(function_name="sorter", parents=[], file_path=code_path)
os.chdir(str(run_cwd))
success, new_test_behavior = inject_profiling_into_existing_test(
code, test_path, [CodePosition(11, 17)], func, project_root_path, mode=TestingMode.BEHAVIOR
test_path, [CodePosition(11, 17)], func, project_root_path, mode=TestingMode.BEHAVIOR
)
assert success
success, new_test_perf = inject_profiling_into_existing_test(
code, test_path, [CodePosition(11, 17)], func, project_root_path, mode=TestingMode.PERFORMANCE
test_path, [CodePosition(11, 17)], func, project_root_path, mode=TestingMode.PERFORMANCE
)
os.chdir(original_cwd)
assert success
@ -1661,7 +1661,7 @@ class TestPigLatin(unittest.TestCase):
func = FunctionToOptimize(function_name="sorter", parents=[], file_path=code_path)
os.chdir(run_cwd)
success, new_test_behavior = inject_profiling_into_existing_test(
code, test_path,
test_path,
[CodePosition(9, 17), CodePosition(13, 17), CodePosition(17, 17)],
func,
project_root_path,
@ -1669,7 +1669,7 @@ class TestPigLatin(unittest.TestCase):
)
assert success
success, new_test_perf = inject_profiling_into_existing_test(
code, test_path,
test_path,
[CodePosition(9, 17), CodePosition(13, 17), CodePosition(17, 17)],
func,
project_root_path,
@ -1917,11 +1917,11 @@ import unittest
func = FunctionToOptimize(function_name="sorter", parents=[], file_path=code_path)
os.chdir(run_cwd)
success, new_test_behavior = inject_profiling_into_existing_test(
code, test_path, [CodePosition(16, 17)], func, project_root_path, mode=TestingMode.BEHAVIOR
test_path, [CodePosition(16, 17)], func, project_root_path, mode=TestingMode.BEHAVIOR
)
assert success
success, new_test_perf = inject_profiling_into_existing_test(
code, test_path, [CodePosition(16, 17)], func, project_root_path, mode=TestingMode.PERFORMANCE
test_path, [CodePosition(16, 17)], func, project_root_path, mode=TestingMode.PERFORMANCE
)
os.chdir(original_cwd)
@ -2177,11 +2177,11 @@ import unittest
func = FunctionToOptimize(function_name="sorter", parents=[], file_path=code_path)
os.chdir(run_cwd)
success, new_test_behavior = inject_profiling_into_existing_test(
code, test_path, [CodePosition(14, 21)], func, project_root_path, mode=TestingMode.BEHAVIOR
test_path, [CodePosition(14, 21)], func, project_root_path, mode=TestingMode.BEHAVIOR
)
assert success
success, new_test_perf = inject_profiling_into_existing_test(
code, test_path, [CodePosition(14, 21)], func, project_root_path, mode=TestingMode.PERFORMANCE
test_path, [CodePosition(14, 21)], func, project_root_path, mode=TestingMode.PERFORMANCE
)
os.chdir(original_cwd)
assert success
@ -2428,10 +2428,10 @@ import unittest
f = FunctionToOptimize(function_name="sorter", file_path=code_path, parents=[])
os.chdir(run_cwd)
success, new_test_behavior = inject_profiling_into_existing_test(
code, test_path, [CodePosition(17, 21)], f, project_root_path, mode=TestingMode.BEHAVIOR
test_path, [CodePosition(17, 21)], f, project_root_path, mode=TestingMode.BEHAVIOR
)
success, new_test_perf = inject_profiling_into_existing_test(
code, test_path, [CodePosition(17, 21)], f, project_root_path, mode=TestingMode.PERFORMANCE
test_path, [CodePosition(17, 21)], f, project_root_path, mode=TestingMode.PERFORMANCE
)
os.chdir(original_cwd)
assert success
@ -2734,7 +2734,7 @@ def test_class_name_A_function_name():
)
os.chdir(str(run_cwd))
success, new_test = inject_profiling_into_existing_test(
code, test_path, [CodePosition(4, 23)], func, project_root_path
test_path, [CodePosition(4, 23)], func, project_root_path
)
os.chdir(original_cwd)
finally:
@ -2811,7 +2811,7 @@ def test_common_tags_1():
os.chdir(str(run_cwd))
success, new_test = inject_profiling_into_existing_test(
code, test_path, [CodePosition(7, 11), CodePosition(11, 11)], func, project_root_path
test_path, [CodePosition(7, 11), CodePosition(11, 11)], func, project_root_path
)
os.chdir(original_cwd)
assert success
@ -2877,7 +2877,7 @@ def test_sort():
os.chdir(str(run_cwd))
success, new_test = inject_profiling_into_existing_test(
code, test_path, [CodePosition(7, 15)], func, project_root_path
test_path, [CodePosition(7, 15)], func, project_root_path
)
os.chdir(original_cwd)
assert success
@ -2960,7 +2960,7 @@ def test_sort():
os.chdir(run_cwd)
success, new_test = inject_profiling_into_existing_test(
code, test_path, [CodePosition(6, 26), CodePosition(10, 26)], function_to_optimize, project_root_path
test_path, [CodePosition(6, 26), CodePosition(10, 26)], function_to_optimize, project_root_path
)
os.chdir(original_cwd)
assert success
@ -3061,7 +3061,7 @@ def test_code_replacement10() -> None:
run_cwd = Path(__file__).parent.parent.resolve()
os.chdir(run_cwd)
success, new_test = inject_profiling_into_existing_test(
code, test_file_path, [CodePosition(22, 28), CodePosition(28, 28)], func, test_file_path.parent
test_file_path, [CodePosition(22, 28), CodePosition(28, 28)], func, test_file_path.parent
)
os.chdir(original_cwd)
assert success
@ -3119,7 +3119,7 @@ def test_sleepfunc_sequence_short(n, expected_total_sleep_time):
func = FunctionToOptimize(function_name="accurate_sleepfunc", parents=[], file_path=code_path)
os.chdir(run_cwd)
success, new_test = inject_profiling_into_existing_test(
code, test_path, [CodePosition(8, 13)], func, project_root_path, mode=TestingMode.PERFORMANCE
test_path, [CodePosition(8, 13)], func, project_root_path, mode=TestingMode.PERFORMANCE
)
os.chdir(original_cwd)
@ -3236,7 +3236,7 @@ import unittest
func = FunctionToOptimize(function_name="accurate_sleepfunc", parents=[], file_path=code_path)
os.chdir(run_cwd)
success, new_test = inject_profiling_into_existing_test(
code, test_path, [CodePosition(12, 17)], func, project_root_path, mode=TestingMode.PERFORMANCE
test_path, [CodePosition(12, 17)], func, project_root_path, mode=TestingMode.PERFORMANCE
)
os.chdir(original_cwd)
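Every call in the hunks above drops the leading test-source string and passes the test file path as the first argument, so the instrumentation helper is now expected to read the test source from disk itself. A minimal sketch of the call shape these tests exercise (illustrative only; the comments and argument roles are inferred from the diff, not taken from the project's signature):

success, new_test = inject_profiling_into_existing_test(
    test_path,                                    # path to the existing test file; source is read internally
    [CodePosition(8, 14), CodePosition(12, 14)],  # positions of the calls to instrument
    func,                                         # FunctionToOptimize describing the target function
    project_root_path,
    mode=TestingMode.BEHAVIOR,                    # or TestingMode.PERFORMANCE for the perf variant
)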

View file

@ -147,6 +147,108 @@ class TestMochaJsonToJunitXml:
assert "suite A" in suite_names
assert "suite B" in suite_names
def test_file_attribute_set_from_test_files(self):
"""When test_files are passed, the file attribute should be set on testcase elements."""
from codeflash.languages.javascript.mocha_runner import mocha_json_to_junit_xml
mocha_json = json.dumps(
{
"stats": {"tests": 2, "passes": 2, "failures": 0, "duration": 50},
"tests": [
{"title": "test1", "fullTitle": "escapeHtml test1", "duration": 10, "err": {}},
{"title": "test2", "fullTitle": "escapeHtml test2", "duration": 20, "err": {}},
],
"passes": [],
"failures": [],
"pending": [],
}
)
with tempfile.TemporaryDirectory() as tmpdir:
tmpdir_path = Path(tmpdir)
# Create a test file whose describe block matches the suite name
test_file = tmpdir_path / "test_escapeHtml__unit_test_0.test.js"
test_file.write_text(
"const assert = require('node:assert/strict');\n"
"const escapeHtml = require('../index.js');\n"
"describe('escapeHtml', () => {\n"
" it('test1', () => { assert.ok(true); });\n"
" it('test2', () => { assert.ok(true); });\n"
"});\n",
encoding="utf-8",
)
output_file = tmpdir_path / "results.xml"
mocha_json_to_junit_xml(mocha_json, output_file, test_files=[test_file])
# Parse the XML and verify file attributes
import xml.etree.ElementTree as ET
tree = ET.parse(output_file)
root = tree.getroot()
testcases = root.findall(".//testcase")
assert len(testcases) == 2
for tc in testcases:
assert tc.get("file") == str(test_file)
def test_file_attribute_uses_default_when_no_describe_match(self):
"""When describe name doesn't match, the default (first) test file should be used."""
from codeflash.languages.javascript.mocha_runner import mocha_json_to_junit_xml
mocha_json = json.dumps(
{
"stats": {"tests": 1, "passes": 1, "failures": 0, "duration": 10},
"tests": [
{"title": "test1", "fullTitle": "someOtherSuite test1", "duration": 10, "err": {}},
],
"passes": [],
"failures": [],
"pending": [],
}
)
with tempfile.TemporaryDirectory() as tmpdir:
tmpdir_path = Path(tmpdir)
test_file = tmpdir_path / "test.test.js"
test_file.write_text("// no describe block", encoding="utf-8")
output_file = tmpdir_path / "results.xml"
mocha_json_to_junit_xml(mocha_json, output_file, test_files=[test_file])
import xml.etree.ElementTree as ET
tree = ET.parse(output_file)
testcases = tree.getroot().findall(".//testcase")
assert len(testcases) == 1
assert testcases[0].get("file") == str(test_file)
def test_no_file_attribute_when_no_test_files(self):
"""When test_files is not passed, no file attribute should be set."""
from codeflash.languages.javascript.mocha_runner import mocha_json_to_junit_xml
mocha_json = json.dumps(
{
"stats": {"tests": 1, "passes": 1, "failures": 0, "duration": 10},
"tests": [
{"title": "test1", "fullTitle": "suite test1", "duration": 10, "err": {}},
],
"passes": [],
"failures": [],
"pending": [],
}
)
with tempfile.TemporaryDirectory() as tmpdir:
output_file = Path(tmpdir) / "results.xml"
mocha_json_to_junit_xml(mocha_json, output_file)
import xml.etree.ElementTree as ET
tree = ET.parse(output_file)
testcases = tree.getroot().findall(".//testcase")
assert len(testcases) == 1
assert testcases[0].get("file") is None
class TestExtractMochaJson:
"""Tests for extracting Mocha JSON from mixed stdout."""
@ -456,6 +558,166 @@ class TestRunMochaBenchmarkingTests:
assert env.get("CODEFLASH_PERF_STABILITY_CHECK") == "false"
class TestSanitizeMochaImports:
"""Tests for stripping wrong framework imports from Mocha tests."""
def test_strips_vitest_import(self):
from codeflash.languages.javascript.edit_tests import sanitize_mocha_imports
source = "import { describe, test, expect, vi } from 'vitest'\nconst x = 1;\n"
result = sanitize_mocha_imports(source)
assert "vitest" not in result
assert "const x = 1;" in result
def test_strips_jest_globals_import(self):
from codeflash.languages.javascript.edit_tests import sanitize_mocha_imports
source = "import { jest, describe, it, expect } from '@jest/globals'\nconst x = 1;\n"
result = sanitize_mocha_imports(source)
assert "@jest/globals" not in result
assert "const x = 1;" in result
def test_strips_mocha_require(self):
from codeflash.languages.javascript.edit_tests import sanitize_mocha_imports
source = "const { describe, it, expect } = require('mocha');\nconst x = 1;\n"
result = sanitize_mocha_imports(source)
assert "require('mocha')" not in result
assert "const x = 1;" in result
def test_strips_vitest_comment(self):
from codeflash.languages.javascript.edit_tests import sanitize_mocha_imports
source = "// vitest imports (REQUIRED for vitest)\nimport { describe } from 'vitest'\nconst x = 1;\n"
result = sanitize_mocha_imports(source)
assert "vitest" not in result
assert "const x = 1;" in result
def test_strips_vitest_require_cjs(self):
from codeflash.languages.javascript.edit_tests import sanitize_mocha_imports
source = "const { describe, test, expect, vi, beforeEach, afterEach } = require('vitest');\nconst x = 1;\n"
result = sanitize_mocha_imports(source)
assert "vitest" not in result
assert "const x = 1;" in result
def test_strips_jest_globals_require_cjs(self):
from codeflash.languages.javascript.edit_tests import sanitize_mocha_imports
source = "const { jest, describe, it } = require('@jest/globals');\nconst x = 1;\n"
result = sanitize_mocha_imports(source)
assert "@jest/globals" not in result
assert "const x = 1;" in result
def test_strips_vitest_comment_and_cjs_require(self):
from codeflash.languages.javascript.edit_tests import sanitize_mocha_imports
source = "// vitest imports (REQUIRED for vitest - globals are NOT enabled by default)\nconst { describe, test, expect, vi, beforeEach, afterEach } = require('vitest');\nconst { setCharset } = require('../lib/utils');\n"
result = sanitize_mocha_imports(source)
assert "vitest" not in result
assert "require('../lib/utils')" in result
def test_preserves_unrelated_imports(self):
from codeflash.languages.javascript.edit_tests import sanitize_mocha_imports
source = "const sinon = require('sinon');\nconst assert = require('node:assert/strict');\n"
result = sanitize_mocha_imports(source)
assert "sinon" in result
assert "node:assert/strict" in result
class TestInjectTestGlobalsModuleSystem:
"""Tests for inject_test_globals with different module systems."""
def test_mocha_esm_uses_import(self):
from codeflash.languages.javascript.edit_tests import inject_test_globals
from codeflash.models.models import GeneratedTests, GeneratedTestsList
tests = GeneratedTestsList(
generated_tests=[
GeneratedTests(
generated_original_test_source="describe('test', () => {});",
instrumented_behavior_test_source="describe('test', () => {});",
instrumented_perf_test_source="describe('test', () => {});",
behavior_file_path=Path("test.test.js"),
perf_file_path=Path("test.perf.test.js"),
)
]
)
result = inject_test_globals(tests, test_framework="mocha", module_system="esm")
assert "import assert from 'node:assert/strict'" in result.generated_tests[0].generated_original_test_source
def test_mocha_cjs_uses_require(self):
from codeflash.languages.javascript.edit_tests import inject_test_globals
from codeflash.models.models import GeneratedTests, GeneratedTestsList
tests = GeneratedTestsList(
generated_tests=[
GeneratedTests(
generated_original_test_source="describe('test', () => {});",
instrumented_behavior_test_source="describe('test', () => {});",
instrumented_perf_test_source="describe('test', () => {});",
behavior_file_path=Path("test.test.js"),
perf_file_path=Path("test.perf.test.js"),
)
]
)
result = inject_test_globals(tests, test_framework="mocha", module_system="commonjs")
src = result.generated_tests[0].generated_original_test_source
assert "const assert = require('node:assert/strict')" in src
assert "import assert" not in src
def test_vitest_always_uses_import(self):
from codeflash.languages.javascript.edit_tests import inject_test_globals
from codeflash.models.models import GeneratedTests, GeneratedTestsList
tests = GeneratedTestsList(
generated_tests=[
GeneratedTests(
generated_original_test_source="describe('test', () => {});",
instrumented_behavior_test_source="describe('test', () => {});",
instrumented_perf_test_source="describe('test', () => {});",
behavior_file_path=Path("test.test.js"),
perf_file_path=Path("test.perf.test.js"),
)
]
)
result = inject_test_globals(tests, test_framework="vitest", module_system="commonjs")
assert "from 'vitest'" in result.generated_tests[0].generated_original_test_source
class TestEnsureModuleSystemCompatibilityMixed:
"""Tests for ensure_module_system_compatibility with mixed ESM+CJS code."""
def test_converts_imports_in_mixed_code_to_cjs(self):
from codeflash.languages.javascript.module_system import ensure_module_system_compatibility
# Code with both import (from inject_test_globals) and require (from backend)
code = "import assert from 'node:assert/strict';\nconst { foo } = require('./module');\n"
result = ensure_module_system_compatibility(code, "commonjs")
assert "require('node:assert/strict')" in result
assert "import assert" not in result
def test_converts_require_in_mixed_code_to_esm(self):
from codeflash.languages.javascript.module_system import ensure_module_system_compatibility
code = "import { describe } from 'vitest';\nconst foo = require('./module');\n"
result = ensure_module_system_compatibility(code, "esm")
assert "require" not in result
assert "import" in result
def test_pure_esm_to_cjs(self):
from codeflash.languages.javascript.module_system import ensure_module_system_compatibility
code = "import assert from 'node:assert/strict';\nimport { foo } from './module';\n"
result = ensure_module_system_compatibility(code, "commonjs")
assert "require('node:assert/strict')" in result
assert "import" not in result
class TestRunMochaLineProfileTests:
"""Tests for running Mocha line profile tests with mocked subprocess."""
@ -500,3 +762,70 @@ class TestRunMochaLineProfileTests:
env = call_kwargs.kwargs.get("env") or call_kwargs[1].get("env", {})
assert env.get("CODEFLASH_MODE") == "line_profile"
assert env.get("CODEFLASH_LINE_PROFILE_OUTPUT") == str(profile_output)
class TestParserUnknownTestNameFallback:
"""Tests for the parser's fallback when perf markers have 'unknown' test name."""
def test_unknown_markers_matched_to_first_testcase(self):
"""When capturePerf markers have 'unknown' test name (Vitest beforeEach not firing),
the parser should still match them to testcases via the fallback logic."""
from codeflash.languages.javascript.parse import parse_jest_test_xml
from codeflash.models.models import TestFile, TestFiles
from codeflash.models.test_type import TestType
with tempfile.TemporaryDirectory() as tmpdir:
tmpdir_path = Path(tmpdir)
# Create a JUnit XML with one test suite and one testcase
xml_content = """<?xml version="1.0" encoding="UTF-8"?>
<testsuites>
<testsuite name="src/test_func__perf_test_0.test.ts" tests="1" failures="0" time="10.5">
<testcase name="should compute correctly" classname="src/test_func__perf_test_0.test.ts" time="10.5">
</testcase>
</testsuite>
</testsuites>"""
xml_path = tmpdir_path / "results.xml"
xml_path.write_text(xml_content, encoding="utf-8")
# Create test files
test_file = tmpdir_path / "test_func__perf_test_0.test.ts"
test_file.write_text("// perf test", encoding="utf-8")
test_files = TestFiles(
test_files=[
TestFile(
instrumented_behavior_file_path=test_file,
benchmarking_file_path=test_file,
test_type=TestType.GENERATED_REGRESSION,
)
]
)
# Create a mock subprocess result with perf markers using "unknown" test name
# This simulates what happens when Vitest's beforeEach doesn't fire
markers = []
for i in range(1, 6):
markers.append(f"!######test_mod:unknown:computeFunc:{i}:1_0:{1000 + i * 100}######!")
stdout = "\n".join(markers)
mock_result = MagicMock()
mock_result.stdout = stdout
test_config = MagicMock()
test_config.tests_project_rootdir = tmpdir_path
test_config.test_framework = "vitest"
results = parse_jest_test_xml(
test_xml_file_path=xml_path,
test_files=test_files,
test_config=test_config,
run_result=mock_result,
)
# The "unknown" fallback should assign all 5 markers to the testcase
assert len(results.test_results) == 5
# Verify runtimes were extracted (not the 10.5s XML fallback)
runtimes = [r.runtime for r in results.test_results if r.runtime is not None]
assert len(runtimes) == 5
assert all(r < 100_000 for r in runtimes)  # runtimes are in nanoseconds, so all under 100 microseconds
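The fallback test above builds its stdout from markers shaped like !######<module>:<test name>:<function>:<index>:<invocation id>:<runtime ns>######!. A rough sketch of how such markers could be scanned out of captured stdout (the field names are inferred from the test string and are assumptions, not the parser's actual implementation):

import re

MARKER_RE = re.compile(
    r"!######(?P<module>[^:]+):(?P<test>[^:]+):(?P<func>[^:]+)"
    r":(?P<index>\d+):(?P<invocation>[^:]+):(?P<runtime_ns>\d+)######!"
)

def scan_perf_markers(stdout: str) -> list[dict[str, str]]:
    # Collect every marker; entries whose test field is "unknown" would then be
    # matched to a testcase by the fallback behavior exercised in the test above.
    return [m.groupdict() for m in MARKER_RE.finditer(stdout)]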