Mirror of https://github.com/codeflash-ai/codeflash.git, synced 2026-05-04 18:25:17 +00:00
Merge branch 'main' into fix/js-jest30-loop-runner
This commit is contained in: commit ce13a6d534
9 changed files with 473 additions and 6 deletions
@@ -901,6 +901,115 @@ def validate_and_fix_import_style(test_code: str, source_file_path: Path, functi
    return test_code


def fix_import_path_for_test_location(
    test_code: str, source_file_path: Path, test_file_path: Path, module_root: Path
) -> str:
    """Fix import paths in generated test code to be relative to test file location.

    The AI may generate tests with import paths that are relative to the module root
    (e.g., 'apps/web/app/file') instead of relative to where the test file is located
    (e.g., '../../app/file'). This function fixes such imports.

    Args:
        test_code: The generated test code.
        source_file_path: Absolute path to the source file being tested.
        test_file_path: Absolute path to where the test file will be written.
        module_root: Root directory of the module/project.

    Returns:
        Test code with corrected import paths.

    """
    import os

    # Calculate the correct relative import path from test file to source file
    test_dir = test_file_path.parent
    try:
        correct_rel_path = os.path.relpath(source_file_path, test_dir)
        correct_rel_path = correct_rel_path.replace("\\", "/")
        # Remove file extension for JS/TS imports
        for ext in [".tsx", ".ts", ".jsx", ".js", ".mjs", ".cjs"]:
            if correct_rel_path.endswith(ext):
                correct_rel_path = correct_rel_path[: -len(ext)]
                break
        # Ensure it starts with ./ or ../
        if not correct_rel_path.startswith("."):
            correct_rel_path = "./" + correct_rel_path
    except ValueError:
        # Can't compute relative path (different drives on Windows)
        return test_code

    # Try to compute what incorrect path the AI might have generated
    # The AI often uses module_root-relative paths like 'apps/web/app/...'
    try:
        source_rel_to_module = os.path.relpath(source_file_path, module_root)
        source_rel_to_module = source_rel_to_module.replace("\\", "/")
        # Remove extension
        for ext in [".tsx", ".ts", ".jsx", ".js", ".mjs", ".cjs"]:
            if source_rel_to_module.endswith(ext):
                source_rel_to_module = source_rel_to_module[: -len(ext)]
                break
    except ValueError:
        return test_code

    # Also check for project root-relative paths (including module_root in path)
    try:
        project_root = module_root.parent if module_root.name in ["src", "lib", "app", "web", "apps"] else module_root
        source_rel_to_project = os.path.relpath(source_file_path, project_root)
        source_rel_to_project = source_rel_to_project.replace("\\", "/")
        for ext in [".tsx", ".ts", ".jsx", ".js", ".mjs", ".cjs"]:
            if source_rel_to_project.endswith(ext):
                source_rel_to_project = source_rel_to_project[: -len(ext)]
                break
    except ValueError:
        source_rel_to_project = None

    # Source file name (for matching module paths that end with the file name)
    source_name = source_file_path.stem

    # Patterns to find import statements
    # ESM: import { func } from 'path' or import func from 'path'
    esm_import_pattern = re.compile(r"(import\s+(?:{[^}]+}|\w+)\s+from\s+['\"])([^'\"]+)(['\"])")
    # CommonJS: const { func } = require('path') or const func = require('path')
    cjs_require_pattern = re.compile(
        r"((?:const|let|var)\s+(?:{[^}]+}|\w+)\s*=\s*require\s*\(\s*['\"])([^'\"]+)(['\"])"
    )

    def should_fix_path(import_path: str) -> bool:
        """Check if this import path looks like it should point to our source file."""
        # Skip relative imports that already look correct
        if import_path.startswith(("./", "../")):
            return False
        # Skip package imports (no path separators or start with @)
        if "/" not in import_path and "\\" not in import_path:
            return False
        if import_path.startswith("@") and "/" in import_path:
            # Could be an alias like @/utils - skip these
            return False
        # Check if it looks like it points to our source file
        if import_path == source_rel_to_module:
            return True
        if source_rel_to_project and import_path == source_rel_to_project:
            return True
        if import_path.endswith((source_name, "/" + source_name)):
            return True
        return False

    def fix_import(match: re.Match[str]) -> str:
        """Replace incorrect import path with correct relative path."""
        prefix = match.group(1)
        import_path = match.group(2)
        suffix = match.group(3)

        if should_fix_path(import_path):
            logger.debug(f"Fixing import path: {import_path} -> {correct_rel_path}")
            return f"{prefix}{correct_rel_path}{suffix}"
        return match.group(0)

    test_code = esm_import_pattern.sub(fix_import, test_code)
    return cjs_require_pattern.sub(fix_import, test_code)


def get_instrumented_test_path(original_path: Path, mode: str) -> Path:
    """Generate path for instrumented test file.
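For illustration only (not part of the diff): the path arithmetic the new helper performs, shown with made-up monorepo paths. Only os.path.relpath, the extension stripping, and the ./ prefixing mirror the function above; every path below is invented.

import os
from pathlib import Path

module_root = Path("/repo/apps/web")                        # hypothetical module root
source_file = Path("/repo/apps/web/app/file.ts")            # hypothetical source under test
test_file = Path("/repo/apps/web/tests/unit/file.test.ts")  # hypothetical generated test

# What the AI tends to emit: a module-root-relative specifier.
wrong = os.path.relpath(source_file, module_root).replace("\\", "/")       # 'app/file.ts'

# What the test file actually needs: a specifier relative to its own directory.
right = os.path.relpath(source_file, test_file.parent).replace("\\", "/")  # '../../app/file.ts'
for ext in (".tsx", ".ts", ".jsx", ".js", ".mjs", ".cjs"):
    if right.endswith(ext):
        right = right[: -len(ext)]
        break
if not right.startswith("."):
    right = "./" + right
print(f"{wrong} -> {right}")  # app/file.ts -> ../../app/file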
@@ -175,6 +175,19 @@ def parse_jest_test_xml(
            logger.debug(f"Found {marker_count} timing start markers in Jest stdout")
        else:
            logger.debug(f"No timing start markers found in Jest stdout (len={len(global_stdout)})")
            # Check for END markers with duration (perf test markers)
            end_marker_count = len(jest_end_pattern.findall(global_stdout))
            if end_marker_count > 0:
                logger.debug(
                    f"[PERF-DEBUG] Found {end_marker_count} END timing markers with duration in Jest stdout"
                )
                # Sample a few markers to verify loop indices
                end_samples = list(jest_end_pattern.finditer(global_stdout))[:5]
                for sample in end_samples:
                    groups = sample.groups()
                    logger.debug(f"[PERF-DEBUG] Sample END marker: loopIndex={groups[3]}, duration={groups[5]}")
            else:
                logger.debug("[PERF-DEBUG] No END markers with duration found in Jest stdout")
    except (AttributeError, UnicodeDecodeError):
        global_stdout = ""
@@ -197,6 +210,14 @@ def parse_jest_test_xml(
            key = match.groups()[:5]
            end_matches_dict[key] = match

        # Debug: log suite-level END marker parsing for perf tests
        if end_matches_dict:
            # Get unique loop indices from the parsed END markers
            loop_indices = sorted({int(k[3]) if k[3].isdigit() else 1 for k in end_matches_dict})
            logger.debug(
                f"[PERF-DEBUG] Suite {suite_count}: parsed {len(end_matches_dict)} END markers from suite_stdout, loop_index range: {min(loop_indices)}-{max(loop_indices)}"
            )

        # Also collect timing markers from testcase-level system-out (Vitest puts output at testcase level)
        for tc in suite:
            tc_system_out = tc._elem.find("system-out")  # noqa: SLF001
@@ -327,6 +348,13 @@ def parse_jest_test_xml(
            sanitized_test_name = re.sub(r"[!#: ()\[\]{}|\\/*?^$.+\-]", "_", test_name)
            matching_starts = [m for m in start_matches if sanitized_test_name in m.group(2)]

            # Debug: log which branch we're taking
            logger.debug(
                f"[FLOW-DEBUG] Testcase '{test_name[:50]}': "
                f"total_start_matches={len(start_matches)}, matching_starts={len(matching_starts)}, "
                f"total_end_matches={len(end_matches_dict)}"
            )

            # For performance tests (capturePerf), there are no START markers - only END markers with duration
            # Check for END markers directly if no START markers found
            matching_ends_direct = []
@@ -337,6 +365,28 @@ def parse_jest_test_xml(
                    # end_key is (module, testName, funcName, loopIndex, invocationId)
                    if len(end_key) >= 2 and sanitized_test_name in end_key[1]:
                        matching_ends_direct.append(end_match)
                # Debug: log matching results for perf tests
                if matching_ends_direct:
                    loop_indices = [int(m.groups()[3]) if m.groups()[3].isdigit() else 1 for m in matching_ends_direct]
                    logger.debug(
                        f"[PERF-MATCH] Testcase '{test_name[:40]}': matched {len(matching_ends_direct)} END markers, "
                        f"loop_index range: {min(loop_indices)}-{max(loop_indices)}"
                    )
                elif end_matches_dict:
                    # No matches but we have END markers - check why
                    sample_keys = list(end_matches_dict.keys())[:3]
                    logger.debug(
                        f"[PERF-MISMATCH] Testcase '{test_name[:40]}': no matches found. "
                        f"sanitized_test_name='{sanitized_test_name[:50]}', "
                        f"sample end_keys={[k[1][:30] if len(k) >= 2 else k for k in sample_keys]}"
                    )

            # Log if we're skipping the matching_ends_direct branch
            if matching_starts and end_matches_dict:
                logger.debug(
                    f"[FLOW-SKIP] Testcase '{test_name[:40]}': has {len(matching_starts)} START markers, "
                    f"skipping {len(end_matches_dict)} END markers (behavior test mode)"
                )

            if not matching_starts and not matching_ends_direct:
                # No timing markers found - use JUnit XML time attribute as fallback
@@ -373,11 +423,13 @@ def parse_jest_test_xml(
                )
            elif matching_ends_direct:
                # Performance test format: process END markers directly (no START markers)
                loop_indices_found = []
                for end_match in matching_ends_direct:
                    groups = end_match.groups()
                    # groups: (module, testName, funcName, loopIndex, invocationId, durationNs)
                    func_name = groups[2]
                    loop_index = int(groups[3]) if groups[3].isdigit() else 1
                    loop_indices_found.append(loop_index)
                    line_id = groups[4]
                    try:
                        runtime = int(groups[5])
@@ -403,6 +455,12 @@ def parse_jest_test_xml(
                            stdout="",
                        )
                    )
                if loop_indices_found:
                    logger.debug(
                        f"[LOOP-DEBUG] Testcase '{test_name}': processed {len(matching_ends_direct)} END markers, "
                        f"loop_index range: {min(loop_indices_found)}-{max(loop_indices_found)}, "
                        f"total results so far: {len(test_results.test_results)}"
                    )
            else:
                # Process each timing marker
                for match in matching_starts:
@@ -454,5 +512,19 @@ def parse_jest_test_xml(
        f"Jest XML parsing complete: {len(test_results.test_results)} results "
        f"from {suite_count} suites, {testcase_count} testcases"
    )
    # Debug: show loop_index distribution for perf analysis
    if test_results.test_results:
        loop_indices = [r.loop_index for r in test_results.test_results]
        unique_loop_indices = sorted(set(loop_indices))
        min_idx, max_idx = min(unique_loop_indices), max(unique_loop_indices)
        logger.debug(
            f"[LOOP-SUMMARY] Results loop_index: min={min_idx}, max={max_idx}, "
            f"unique_count={len(unique_loop_indices)}, total_results={len(loop_indices)}"
        )
        if max_idx == 1 and len(loop_indices) > 1:
            logger.warning(
                f"[LOOP-WARNING] All {len(loop_indices)} results have loop_index=1. "
                "Perf test markers may not have been parsed correctly."
            )

    return test_results
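To make the marker format handled by this parser concrete: per the comment in the hunks above, an END marker carries (module, testName, funcName, loopIndex, invocationId, durationNs). A minimal, self-contained sketch of decoding one such marker; the marker text is invented, and the pattern follows the perf END regex that appears later in this diff.

import re

end_pattern = re.compile(r"!######([^:]+):([^:]+):([^:]+):(\d+):([^:]+):(\d+)######!")
line = "!######tests/foo.test:adds numbers:add:3:0_1:125000######!"  # invented example marker
match = end_pattern.search(line)
if match:
    module, test_name, func_name, loop_index, invocation_id, duration_ns = match.groups()
    print(func_name, int(loop_index), int(duration_ns))  # add 3 125000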
@@ -2201,6 +2201,7 @@ class JavaScriptSupport:
        from codeflash.languages.test_framework import get_js_test_framework_or_default

        framework = test_framework or get_js_test_framework_or_default()
        logger.debug("run_benchmarking_tests called with framework=%s", framework)

        # Use JS-specific high max_loops - actual loop count is limited by target_duration
        effective_max_loops = self.JS_BENCHMARKING_MAX_LOOPS
@@ -2208,6 +2209,7 @@ class JavaScriptSupport:
        if framework == "vitest":
            from codeflash.languages.javascript.vitest_runner import run_vitest_benchmarking_tests

            logger.debug("Dispatching to run_vitest_benchmarking_tests")
            return run_vitest_benchmarking_tests(
                test_paths=test_paths,
                test_env=test_env,
@@ -192,7 +192,7 @@ def _ensure_codeflash_vitest_config(project_root: Path) -> Path | None:
        logger.debug("Detected vitest workspace configuration - skipping custom config")
        return None

-   codeflash_config_path = project_root / "codeflash.vitest.config.js"
+   codeflash_config_path = project_root / "codeflash.vitest.config.mjs"

    # If already exists, use it
    if codeflash_config_path.exists():
@@ -281,7 +281,7 @@ def _build_vitest_behavioral_command(

    # For monorepos with restrictive vitest configs (e.g., include: test/**/*.test.ts),
    # we need to create a custom config that allows all test patterns.
-   # This is done by creating a codeflash.vitest.config.js file.
+   # This is done by creating a codeflash.vitest.config.mjs file.
    if project_root:
        codeflash_vitest_config = _ensure_codeflash_vitest_config(project_root)
        if codeflash_vitest_config:
@@ -520,6 +520,9 @@ def run_vitest_benchmarking_tests(
) -> tuple[Path, subprocess.CompletedProcess]:
    """Run Vitest benchmarking tests with external looping from Python.

    NOTE: This function MUST use benchmarking_file_path (perf tests with capturePerf),
    NOT instrumented_behavior_file_path (behavior tests with capture).

    Uses external process-level looping to run tests multiple times and
    collect timing data. This matches the Python pytest approach where
    looping is controlled externally for simplicity.
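The docstring above describes external process-level looping: Python reruns the test command until enough timing samples have been collected. Below is a generic sketch of that pattern under stated assumptions; it is not this project's loop-runner, and the command, thresholds, and stopping rule are illustrative. CODEFLASH_LOOP_INDEX is the environment variable this diff sets elsewhere to tag a run.

import os
import subprocess
import time

def run_with_external_looping(cmd: list[str], target_duration_ms: int, min_loops: int = 3, max_loops: int = 100) -> int:
    """Rerun `cmd` until a target wall-clock duration is reached; return the number of loops."""
    start = time.time()
    loop_index = 0
    while loop_index < max_loops:
        loop_index += 1
        env = {**os.environ, "CODEFLASH_LOOP_INDEX": str(loop_index)}  # tag each run with its loop index
        subprocess.run(cmd, env=env, check=False)
        if loop_index >= min_loops and (time.time() - start) * 1000 >= target_duration_ms:
            break
    return loop_index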
@@ -544,6 +547,26 @@ def run_vitest_benchmarking_tests(
    # Get performance test files
    test_files = [Path(file.benchmarking_file_path) for file in test_paths.test_files if file.benchmarking_file_path]

    # Log test file selection
    total_test_files = len(test_paths.test_files)
    perf_test_files = len(test_files)
    logger.debug(
        f"Vitest benchmark test file selection: {perf_test_files}/{total_test_files} have benchmarking_file_path"
    )
    if perf_test_files == 0:
        logger.warning("No perf test files found! Cannot run benchmarking tests.")
        for tf in test_paths.test_files:
            logger.warning(
                f"Test file: behavior={tf.instrumented_behavior_file_path}, perf={tf.benchmarking_file_path}"
            )
    elif perf_test_files < total_test_files:
        for tf in test_paths.test_files:
            if not tf.benchmarking_file_path:
                logger.warning(f"Missing benchmarking_file_path: behavior={tf.instrumented_behavior_file_path}")
    else:
        for tf in test_files[:3]:  # Log first 3 perf test files
            logger.debug(f"Using perf test file: {tf}")

    # Use provided project_root, or detect it as fallback
    if project_root is None and test_files:
        project_root = _find_vitest_project_root(test_files[0])
@@ -574,14 +597,25 @@ def run_vitest_benchmarking_tests(
    vitest_env["CODEFLASH_PERF_STABILITY_CHECK"] = "true" if stability_check else "false"
    vitest_env["CODEFLASH_LOOP_INDEX"] = "1"

    # Set test module for marker identification (use first test file as reference)
    if test_files:
        test_module_path = str(
            test_files[0].relative_to(effective_cwd)
            if test_files[0].is_relative_to(effective_cwd)
            else test_files[0].name
        )
        vitest_env["CODEFLASH_TEST_MODULE"] = test_module_path
        logger.debug(f"[VITEST-BENCH] Set CODEFLASH_TEST_MODULE={test_module_path}")

    # Total timeout for the entire benchmark run
    total_timeout = max(120, (target_duration_ms // 1000) + 60, timeout or 120)

-   logger.debug(f"Running Vitest benchmarking tests: {' '.join(vitest_cmd)}")
+   logger.debug(f"[VITEST-BENCH] Running Vitest benchmarking tests: {' '.join(vitest_cmd)}")
    logger.debug(
-       f"Vitest benchmarking config: min_loops={min_loops}, max_loops={max_loops}, "
+       f"[VITEST-BENCH] Config: min_loops={min_loops}, max_loops={max_loops}, "
        f"target_duration={target_duration_ms}ms, stability_check={stability_check}"
    )
    logger.debug(f"[VITEST-BENCH] Environment: CODEFLASH_PERF_LOOP_COUNT={vitest_env.get('CODEFLASH_PERF_LOOP_COUNT')}")

    total_start_time = time.time()
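As a worked example of the total_timeout formula above (values invented): with target_duration_ms = 300_000 and timeout = None, total_timeout = max(120, 300_000 // 1000 + 60, 120) = 360, so the run gets the benchmarking budget plus a minute of headroom and never less than 120 seconds.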
@@ -606,7 +640,27 @@ def run_vitest_benchmarking_tests(
        result = subprocess.CompletedProcess(args=vitest_cmd, returncode=-1, stdout="", stderr="Vitest not found")

    wall_clock_seconds = time.time() - total_start_time
-   logger.debug(f"Vitest benchmarking completed in {wall_clock_seconds:.2f}s")
+   logger.debug(f"[VITEST-BENCH] Completed in {wall_clock_seconds:.2f}s, returncode={result.returncode}")

    # Debug: Check for END markers with duration (perf test format)
    if result.stdout:
        import re

        perf_end_pattern = re.compile(r"!######[^:]+:[^:]+:[^:]+:(\d+):[^:]+:(\d+)######!")
        perf_matches = list(perf_end_pattern.finditer(result.stdout))
        if perf_matches:
            loop_indices = [int(m.group(1)) for m in perf_matches]
            logger.debug(
                f"[VITEST-BENCH] Found {len(perf_matches)} perf END markers in stdout, "
                f"loop_index range: {min(loop_indices)}-{max(loop_indices)}"
            )
        else:
            logger.debug(f"[VITEST-BENCH] No perf END markers found in stdout (len={len(result.stdout)})")
            # Check if there are behavior END markers instead
            behavior_end_pattern = re.compile(r"!######[^:]+:[^:]+:[^:]+:\d+:[^#]+######!")
            behavior_matches = list(behavior_end_pattern.finditer(result.stdout))
            if behavior_matches:
                logger.debug(f"[VITEST-BENCH] Found {len(behavior_matches)} behavior END markers instead (no duration)")

    return result_file_path, result
@@ -2368,6 +2368,12 @@ class FunctionOptimizer:
            )
            console.rule()
            with progress_bar("Running performance benchmarks..."):
                logger.debug(
                    f"[BENCHMARK-START] Starting benchmarking tests with {len(self.test_files.test_files)} test files"
                )
                for idx, tf in enumerate(self.test_files.test_files):
                    logger.debug(f"[BENCHMARK-FILES] Test file {idx}: perf_file={tf.benchmarking_file_path}")

                if self.function_to_optimize.is_async and is_python():
                    from codeflash.code_utils.instrument_existing_tests import add_async_decorator_to_function
@@ -2385,6 +2391,7 @@ class FunctionOptimizer:
                        enable_coverage=False,
                        code_context=code_context,
                    )
                    logger.debug(f"[BENCHMARK-DONE] Got {len(benchmarking_results.test_results)} benchmark results")
                finally:
                    if self.function_to_optimize.is_async:
                        self.write_code_and_helpers(
@@ -325,6 +325,7 @@ def run_benchmarking_tests(
    pytest_max_loops: int = 100_000,
    js_project_root: Path | None = None,
) -> tuple[Path, subprocess.CompletedProcess]:
    logger.debug(f"run_benchmarking_tests called: framework={test_framework}, num_files={len(test_paths.test_files)}")
    # Check if there's a language support for this test framework that implements run_benchmarking_tests
    language_support = get_language_support_by_framework(test_framework)
    if language_support is not None and hasattr(language_support, "run_benchmarking_tests"):
@@ -79,8 +79,12 @@ def generate_tests(
    if is_javascript():
        from codeflash.languages.javascript.instrument import (
            TestingMode,
<<<<<<< fix/js-jest30-loop-runner
            fix_imports_inside_test_blocks,
            fix_jest_mock_paths,
=======
            fix_import_path_for_test_location,
>>>>>>> main
            instrument_generated_js_test,
            validate_and_fix_import_style,
        )
@@ -91,12 +95,19 @@ def generate_tests(

        source_file = Path(function_to_optimize.file_path)

<<<<<<< fix/js-jest30-loop-runner
        # Fix import statements that appear inside test blocks (invalid JS syntax)
        generated_test_source = fix_imports_inside_test_blocks(generated_test_source)

        # Fix relative paths in jest.mock() calls
        generated_test_source = fix_jest_mock_paths(
            generated_test_source, test_path, source_file, test_cfg.tests_project_rootdir
=======
        # Fix import paths to be relative to test file location
        # AI may generate imports like 'apps/web/app/file' instead of '../../app/file'
        generated_test_source = fix_import_path_for_test_location(
            generated_test_source, source_file, test_path, module_path
>>>>>>> main
        )

        # Validate and fix import styles (default vs named exports)
@@ -992,7 +992,7 @@ function setTestName(name) {
  resetInvocationCounters();
}

-// Jest lifecycle hooks - these run automatically when this module is imported
+// Jest/Vitest lifecycle hooks - these run automatically when this module is imported
if (typeof beforeEach !== 'undefined') {
  beforeEach(() => {
    // Get current test name and path from Jest's expect state
@@ -1007,6 +1007,17 @@ if (typeof beforeEach !== 'undefined') {
    }
    // Reset invocation counters for each test
    resetInvocationCounters();

    // For Vitest (no external loop-runner), reset perf state for each test
    // so each test gets its own time budget for internal looping.
    // For Jest with loop-runner, CODEFLASH_PERF_CURRENT_BATCH is set,
    // and we want shared state across the test file.
    const hasExternalLoopRunner = process.env.CODEFLASH_PERF_CURRENT_BATCH !== undefined;
    if (!hasExternalLoopRunner) {
      resetPerfState();
      // Also reset invocation loop counts so each test starts fresh
      sharedPerfState.invocationLoopCounts = {};
    }
  });
}
@@ -950,3 +950,203 @@ def test_filter_functions_non_overlapping_tests_root():

        # Strict check: exactly 2 functions remaining
        assert count == 2, f"Expected exactly 2 functions, got {count}"

def test_filter_functions_project_inside_tests_folder():
    """Test that source files are not filtered when project is inside a folder named 'tests'.

    This is a critical regression test for projects located at paths like:
    - /home/user/tests/myproject/
    - /Users/dev/tests/n8n/

    The fix ensures that directory pattern matching (e.g., /tests/) is only checked
    on the relative path from project_root, not on the full absolute path.
    """
    with tempfile.TemporaryDirectory() as outer_temp_dir_str:
        outer_temp_dir = Path(outer_temp_dir_str)

        # Create a "tests" folder to simulate /home/user/tests/
        tests_parent_folder = outer_temp_dir / "tests"
        tests_parent_folder.mkdir()

        # Create project inside the "tests" folder - simulates /home/user/tests/myproject/
        project_dir = tests_parent_folder / "myproject"
        project_dir.mkdir()

        # Create source file inside the project
        src_dir = project_dir / "src"
        src_dir.mkdir()
        source_file = src_dir / "utils.py"
        with source_file.open("w") as f:
            f.write("""
def deep_copy(obj):
    \"\"\"Deep copy an object.\"\"\"
    import copy
    return copy.deepcopy(obj)

def compare_values(a, b):
    \"\"\"Compare two values.\"\"\"
    return a == b
""")

        # Create another source file directly in project root
        root_source_file = project_dir / "main.py"
        with root_source_file.open("w") as f:
            f.write("""
def main():
    \"\"\"Main entry point.\"\"\"
    return 0
""")

        # Create actual test files that should be filtered
        project_tests_dir = project_dir / "test"
        project_tests_dir.mkdir()
        test_file = project_tests_dir / "test_utils.py"
        with test_file.open("w") as f:
            f.write("""
def test_deep_copy():
    return True
""")

        # Discover functions
        all_functions = {}
        for file_path in [source_file, root_source_file, test_file]:
            discovered = find_all_functions_in_file(file_path)
            all_functions.update(discovered)

        # Test: project at /outer/tests/myproject with tests_root overlapping
        # This simulates: /home/user/tests/n8n with tests_root = /home/user/tests/n8n
        with unittest.mock.patch(
            "codeflash.discovery.functions_to_optimize.get_blocklisted_functions", return_value={}
        ):
            filtered, count = filter_functions(
                all_functions,
                tests_root=project_dir,  # Same as project_root (overlapping)
                ignore_paths=[],
                project_root=project_dir,  # /outer/tests/myproject
                module_root=project_dir,
            )

        # Strict check: source files should NOT be filtered even though
        # the full path contains "/tests/" in the parent directory
        expected_files = {source_file, root_source_file}
        actual_files = set(filtered.keys())

        assert actual_files == expected_files, (
            f"Source files were incorrectly filtered when project is inside 'tests' folder.\n"
            f"Expected files: {expected_files}\n"
            f"Got files: {actual_files}\n"
            f"Project path: {project_dir}\n"
            f"This indicates the /tests/ pattern matched the parent directory path."
        )

        # Verify the correct functions are present
        source_functions = sorted([fn.function_name for fn in filtered.get(source_file, [])])
        assert source_functions == ["compare_values", "deep_copy"], (
            f"Expected ['compare_values', 'deep_copy'], got {source_functions}"
        )

        root_functions = [fn.function_name for fn in filtered.get(root_source_file, [])]
        assert root_functions == ["main"], (
            f"Expected ['main'], got {root_functions}"
        )

        # Strict check: exactly 3 functions (2 from utils.py + 1 from main.py)
        assert count == 3, (
            f"Expected exactly 3 functions, got {count}. "
            f"Some source files may have been incorrectly filtered."
        )

        # Verify test file was properly filtered (should not be in results)
        assert test_file not in filtered, (
            f"Test file {test_file} should have been filtered but wasn't"
        )
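To make the docstring's point concrete, here is a small illustration (assumed logic, not the project's actual filter code) of why the /tests/ pattern must be checked against the project-relative path rather than the absolute path:

from pathlib import Path

project_root = Path("/home/user/tests/myproject")  # project living inside a 'tests' folder
source_file = project_root / "src" / "utils.py"

absolute = source_file.as_posix()                             # '/home/user/tests/myproject/src/utils.py'
relative = source_file.relative_to(project_root).as_posix()   # 'src/utils.py'

print("/tests/" in absolute)        # True  -> the source file would be wrongly filtered out
print("/tests/" in "/" + relative)  # False -> correct: no tests directory inside the project itself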
def test_filter_functions_typescript_project_in_tests_folder():
    """Test TypeScript-like project structure inside a folder named 'tests'.

    This simulates the n8n project structure:
    /home/user/tests/n8n/packages/workflow/src/utils.ts

    Ensures that TypeScript source files are not incorrectly filtered
    when the parent directory happens to be named 'tests'.
    """
    with tempfile.TemporaryDirectory() as outer_temp_dir_str:
        outer_temp_dir = Path(outer_temp_dir_str)

        # Simulate: /home/user/tests/n8n
        tests_folder = outer_temp_dir / "tests"
        tests_folder.mkdir()
        n8n_project = tests_folder / "n8n"
        n8n_project.mkdir()

        # Simulate: packages/workflow/src/utils.py (using .py for testing)
        packages_dir = n8n_project / "packages"
        packages_dir.mkdir()
        workflow_dir = packages_dir / "workflow"
        workflow_dir.mkdir()
        src_dir = workflow_dir / "src"
        src_dir.mkdir()

        # Source file deep in the monorepo structure
        utils_file = src_dir / "utils.py"
        with utils_file.open("w") as f:
            f.write("""
def deep_copy(source):
    \"\"\"Create a deep copy of the source object.\"\"\"
    if source is None:
        return None
    return source.copy() if hasattr(source, 'copy') else source

def is_object_empty(obj):
    \"\"\"Check if an object is empty.\"\"\"
    return len(obj) == 0 if obj else True
""")

        # Create test directory inside the package (simulating packages/workflow/test/)
        test_dir = workflow_dir / "test"
        test_dir.mkdir()
        test_file = test_dir / "utils.test.py"
        with test_file.open("w") as f:
            f.write("""
def test_deep_copy():
    return True

def test_is_object_empty():
    return True
""")

        # Discover functions
        all_functions = {}
        for file_path in [utils_file, test_file]:
            discovered = find_all_functions_in_file(file_path)
            all_functions.update(discovered)

        # Test with module_root = packages (typical TypeScript monorepo setup)
        with unittest.mock.patch(
            "codeflash.discovery.functions_to_optimize.get_blocklisted_functions", return_value={}
        ):
            filtered, count = filter_functions(
                all_functions,
                tests_root=packages_dir,  # Overlapping with module_root
                ignore_paths=[],
                project_root=n8n_project,  # /outer/tests/n8n
                module_root=packages_dir,  # /outer/tests/n8n/packages
            )

        # Strict check: only the source file should remain
        assert set(filtered.keys()) == {utils_file}, (
            f"Expected only {utils_file} but got {set(filtered.keys())}.\n"
            f"Source files in /outer/tests/n8n/packages/workflow/src/ were incorrectly filtered.\n"
            f"The /tests/ pattern in the parent path should not affect filtering."
        )

        # Verify the correct functions are present
        filtered_functions = sorted([fn.function_name for fn in filtered.get(utils_file, [])])
        assert filtered_functions == ["deep_copy", "is_object_empty"], (
            f"Expected ['deep_copy', 'is_object_empty'], got {filtered_functions}"
        )

        # Strict check: exactly 2 functions
        assert count == 2, f"Expected exactly 2 functions, got {count}"