Merge branch 'main' into fix/js-jest30-loop-runner

This commit is contained in:
Sarthak Agarwal 2026-02-09 21:06:06 +05:30 committed by GitHub
commit ce13a6d534
9 changed files with 473 additions and 6 deletions

View file

@@ -901,6 +901,115 @@ def validate_and_fix_import_style(test_code: str, source_file_path: Path, functi
return test_code
def fix_import_path_for_test_location(
test_code: str, source_file_path: Path, test_file_path: Path, module_root: Path
) -> str:
"""Fix import paths in generated test code to be relative to test file location.
The AI may generate tests with import paths that are relative to the module root
(e.g., 'apps/web/app/file') instead of relative to where the test file is located
(e.g., '../../app/file'). This function fixes such imports.
Args:
test_code: The generated test code.
source_file_path: Absolute path to the source file being tested.
test_file_path: Absolute path to where the test file will be written.
module_root: Root directory of the module/project.
Returns:
Test code with corrected import paths.
"""
import os
# Calculate the correct relative import path from test file to source file
test_dir = test_file_path.parent
try:
correct_rel_path = os.path.relpath(source_file_path, test_dir)
correct_rel_path = correct_rel_path.replace("\\", "/")
# Remove file extension for JS/TS imports
for ext in [".tsx", ".ts", ".jsx", ".js", ".mjs", ".cjs"]:
if correct_rel_path.endswith(ext):
correct_rel_path = correct_rel_path[: -len(ext)]
break
# Ensure it starts with ./ or ../
if not correct_rel_path.startswith("."):
correct_rel_path = "./" + correct_rel_path
except ValueError:
# Can't compute relative path (different drives on Windows)
return test_code
# Try to compute what incorrect path the AI might have generated
# The AI often uses module_root-relative paths like 'apps/web/app/...'
try:
source_rel_to_module = os.path.relpath(source_file_path, module_root)
source_rel_to_module = source_rel_to_module.replace("\\", "/")
# Remove extension
for ext in [".tsx", ".ts", ".jsx", ".js", ".mjs", ".cjs"]:
if source_rel_to_module.endswith(ext):
source_rel_to_module = source_rel_to_module[: -len(ext)]
break
except ValueError:
return test_code
# Also check for project root-relative paths (including module_root in path)
try:
project_root = module_root.parent if module_root.name in ["src", "lib", "app", "web", "apps"] else module_root
source_rel_to_project = os.path.relpath(source_file_path, project_root)
source_rel_to_project = source_rel_to_project.replace("\\", "/")
for ext in [".tsx", ".ts", ".jsx", ".js", ".mjs", ".cjs"]:
if source_rel_to_project.endswith(ext):
source_rel_to_project = source_rel_to_project[: -len(ext)]
break
except ValueError:
source_rel_to_project = None
# Source file name (for matching module paths that end with the file name)
source_name = source_file_path.stem
# Patterns to find import statements
# ESM: import { func } from 'path' or import func from 'path'
esm_import_pattern = re.compile(r"(import\s+(?:{[^}]+}|\w+)\s+from\s+['\"])([^'\"]+)(['\"])")
# CommonJS: const { func } = require('path') or const func = require('path')
cjs_require_pattern = re.compile(
r"((?:const|let|var)\s+(?:{[^}]+}|\w+)\s*=\s*require\s*\(\s*['\"])([^'\"]+)(['\"])"
)
def should_fix_path(import_path: str) -> bool:
"""Check if this import path looks like it should point to our source file."""
# Skip relative imports that already look correct
if import_path.startswith(("./", "../")):
return False
# Skip package imports (no path separators or start with @)
if "/" not in import_path and "\\" not in import_path:
return False
if import_path.startswith("@") and "/" in import_path:
# Could be an alias like @/utils - skip these
return False
# Check if it looks like it points to our source file
if import_path == source_rel_to_module:
return True
if source_rel_to_project and import_path == source_rel_to_project:
return True
if import_path.endswith((source_name, "/" + source_name)):
return True
return False
def fix_import(match: re.Match[str]) -> str:
"""Replace incorrect import path with correct relative path."""
prefix = match.group(1)
import_path = match.group(2)
suffix = match.group(3)
if should_fix_path(import_path):
logger.debug(f"Fixing import path: {import_path} -> {correct_rel_path}")
return f"{prefix}{correct_rel_path}{suffix}"
return match.group(0)
test_code = esm_import_pattern.sub(fix_import, test_code)
return cjs_require_pattern.sub(fix_import, test_code)
def get_instrumented_test_path(original_path: Path, mode: str) -> Path:
"""Generate path for instrumented test file.

View file

@@ -175,6 +175,19 @@ def parse_jest_test_xml(
logger.debug(f"Found {marker_count} timing start markers in Jest stdout")
else:
logger.debug(f"No timing start markers found in Jest stdout (len={len(global_stdout)})")
# Check for END markers with duration (perf test markers)
end_marker_count = len(jest_end_pattern.findall(global_stdout))
if end_marker_count > 0:
logger.debug(
f"[PERF-DEBUG] Found {end_marker_count} END timing markers with duration in Jest stdout"
)
# Sample a few markers to verify loop indices
end_samples = list(jest_end_pattern.finditer(global_stdout))[:5]
for sample in end_samples:
groups = sample.groups()
logger.debug(f"[PERF-DEBUG] Sample END marker: loopIndex={groups[3]}, duration={groups[5]}")
else:
logger.debug("[PERF-DEBUG] No END markers with duration found in Jest stdout")
except (AttributeError, UnicodeDecodeError):
global_stdout = ""
@@ -197,6 +210,14 @@ def parse_jest_test_xml(
key = match.groups()[:5]
end_matches_dict[key] = match
# Debug: log suite-level END marker parsing for perf tests
if end_matches_dict:
# Get unique loop indices from the parsed END markers
loop_indices = sorted({int(k[3]) if k[3].isdigit() else 1 for k in end_matches_dict})
logger.debug(
f"[PERF-DEBUG] Suite {suite_count}: parsed {len(end_matches_dict)} END markers from suite_stdout, loop_index range: {min(loop_indices)}-{max(loop_indices)}"
)
# Also collect timing markers from testcase-level system-out (Vitest puts output at testcase level)
for tc in suite:
tc_system_out = tc._elem.find("system-out") # noqa: SLF001
@@ -327,6 +348,13 @@ def parse_jest_test_xml(
sanitized_test_name = re.sub(r"[!#: ()\[\]{}|\\/*?^$.+\-]", "_", test_name)
matching_starts = [m for m in start_matches if sanitized_test_name in m.group(2)]
# Debug: log which branch we're taking
logger.debug(
f"[FLOW-DEBUG] Testcase '{test_name[:50]}': "
f"total_start_matches={len(start_matches)}, matching_starts={len(matching_starts)}, "
f"total_end_matches={len(end_matches_dict)}"
)
# For performance tests (capturePerf), there are no START markers - only END markers with duration
# Check for END markers directly if no START markers found
matching_ends_direct = []
@@ -337,6 +365,28 @@ def parse_jest_test_xml(
# end_key is (module, testName, funcName, loopIndex, invocationId)
if len(end_key) >= 2 and sanitized_test_name in end_key[1]:
matching_ends_direct.append(end_match)
# Debug: log matching results for perf tests
if matching_ends_direct:
loop_indices = [int(m.groups()[3]) if m.groups()[3].isdigit() else 1 for m in matching_ends_direct]
logger.debug(
f"[PERF-MATCH] Testcase '{test_name[:40]}': matched {len(matching_ends_direct)} END markers, "
f"loop_index range: {min(loop_indices)}-{max(loop_indices)}"
)
elif end_matches_dict:
# No matches but we have END markers - check why
sample_keys = list(end_matches_dict.keys())[:3]
logger.debug(
f"[PERF-MISMATCH] Testcase '{test_name[:40]}': no matches found. "
f"sanitized_test_name='{sanitized_test_name[:50]}', "
f"sample end_keys={[k[1][:30] if len(k) >= 2 else k for k in sample_keys]}"
)
# Log if we're skipping the matching_ends_direct branch
if matching_starts and end_matches_dict:
logger.debug(
f"[FLOW-SKIP] Testcase '{test_name[:40]}': has {len(matching_starts)} START markers, "
f"skipping {len(end_matches_dict)} END markers (behavior test mode)"
)
if not matching_starts and not matching_ends_direct:
# No timing markers found - use JUnit XML time attribute as fallback
@@ -373,11 +423,13 @@ def parse_jest_test_xml(
)
elif matching_ends_direct:
# Performance test format: process END markers directly (no START markers)
loop_indices_found = []
for end_match in matching_ends_direct:
groups = end_match.groups()
# groups: (module, testName, funcName, loopIndex, invocationId, durationNs)
func_name = groups[2]
loop_index = int(groups[3]) if groups[3].isdigit() else 1
loop_indices_found.append(loop_index)
line_id = groups[4]
try:
runtime = int(groups[5])
@@ -403,6 +455,12 @@ def parse_jest_test_xml(
stdout="",
)
)
if loop_indices_found:
logger.debug(
f"[LOOP-DEBUG] Testcase '{test_name}': processed {len(matching_ends_direct)} END markers, "
f"loop_index range: {min(loop_indices_found)}-{max(loop_indices_found)}, "
f"total results so far: {len(test_results.test_results)}"
)
else:
# Process each timing marker
for match in matching_starts:
@@ -454,5 +512,19 @@ def parse_jest_test_xml(
f"Jest XML parsing complete: {len(test_results.test_results)} results "
f"from {suite_count} suites, {testcase_count} testcases"
)
# Debug: show loop_index distribution for perf analysis
if test_results.test_results:
loop_indices = [r.loop_index for r in test_results.test_results]
unique_loop_indices = sorted(set(loop_indices))
min_idx, max_idx = min(unique_loop_indices), max(unique_loop_indices)
logger.debug(
f"[LOOP-SUMMARY] Results loop_index: min={min_idx}, max={max_idx}, "
f"unique_count={len(unique_loop_indices)}, total_results={len(loop_indices)}"
)
if max_idx == 1 and len(loop_indices) > 1:
logger.warning(
f"[LOOP-WARNING] All {len(loop_indices)} results have loop_index=1. "
"Perf test markers may not have been parsed correctly."
)
return test_results
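Note: a small sketch of the marker format the END-marker parsing above assumes, inferred from the regexes and group comments in this diff: six colon-separated fields (module, testName, funcName, loopIndex, invocationId, durationNs) wrapped in '!######'...'######!' delimiters. The stdout sample below is fabricated for illustration.

import re

end_pattern = re.compile(
    r"!######([^:]+):([^:]+):([^:]+):(\d+):([^:]+):(\d+)######!"
)

# Two fabricated loop iterations of the same test invocation.
stdout = (
    "!######tests/format.test:formats dates:formatDate:1:4_2:183420######!\n"
    "!######tests/format.test:formats dates:formatDate:2:4_2:171095######!\n"
)

for m in end_pattern.finditer(stdout):
    module, test_name, func_name, loop_index, invocation_id, duration_ns = m.groups()
    # loop_index distinguishes repeated runs; durationNs is the measured runtime.
    print(f"{func_name} loop={loop_index} runtime_ns={duration_ns}")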

View file

@@ -2201,6 +2201,7 @@ class JavaScriptSupport:
from codeflash.languages.test_framework import get_js_test_framework_or_default
framework = test_framework or get_js_test_framework_or_default()
logger.debug("run_benchmarking_tests called with framework=%s", framework)
# Use JS-specific high max_loops - actual loop count is limited by target_duration
effective_max_loops = self.JS_BENCHMARKING_MAX_LOOPS
@@ -2208,6 +2209,7 @@ class JavaScriptSupport:
if framework == "vitest":
from codeflash.languages.javascript.vitest_runner import run_vitest_benchmarking_tests
logger.debug("Dispatching to run_vitest_benchmarking_tests")
return run_vitest_benchmarking_tests(
test_paths=test_paths,
test_env=test_env,

View file

@@ -192,7 +192,7 @@ def _ensure_codeflash_vitest_config(project_root: Path) -> Path | None:
logger.debug("Detected vitest workspace configuration - skipping custom config")
return None
codeflash_config_path = project_root / "codeflash.vitest.config.js"
codeflash_config_path = project_root / "codeflash.vitest.config.mjs"
# If already exists, use it
if codeflash_config_path.exists():
@@ -281,7 +281,7 @@ def _build_vitest_behavioral_command(
# For monorepos with restrictive vitest configs (e.g., include: test/**/*.test.ts),
# we need to create a custom config that allows all test patterns.
# This is done by creating a codeflash.vitest.config.js file.
# This is done by creating a codeflash.vitest.config.mjs file.
if project_root:
codeflash_vitest_config = _ensure_codeflash_vitest_config(project_root)
if codeflash_vitest_config:
@@ -520,6 +520,9 @@ def run_vitest_benchmarking_tests(
) -> tuple[Path, subprocess.CompletedProcess]:
"""Run Vitest benchmarking tests with external looping from Python.
NOTE: This function MUST use benchmarking_file_path (perf tests with capturePerf),
NOT instrumented_behavior_file_path (behavior tests with capture).
Uses external process-level looping to run tests multiple times and
collect timing data. This matches the Python pytest approach where
looping is controlled externally for simplicity.
@@ -544,6 +547,26 @@ def run_vitest_benchmarking_tests(
# Get performance test files
test_files = [Path(file.benchmarking_file_path) for file in test_paths.test_files if file.benchmarking_file_path]
# Log test file selection
total_test_files = len(test_paths.test_files)
perf_test_files = len(test_files)
logger.debug(
f"Vitest benchmark test file selection: {perf_test_files}/{total_test_files} have benchmarking_file_path"
)
if perf_test_files == 0:
logger.warning("No perf test files found! Cannot run benchmarking tests.")
for tf in test_paths.test_files:
logger.warning(
f"Test file: behavior={tf.instrumented_behavior_file_path}, perf={tf.benchmarking_file_path}"
)
elif perf_test_files < total_test_files:
for tf in test_paths.test_files:
if not tf.benchmarking_file_path:
logger.warning(f"Missing benchmarking_file_path: behavior={tf.instrumented_behavior_file_path}")
else:
for tf in test_files[:3]: # Log first 3 perf test files
logger.debug(f"Using perf test file: {tf}")
# Use provided project_root, or detect it as fallback
if project_root is None and test_files:
project_root = _find_vitest_project_root(test_files[0])
@@ -574,14 +597,25 @@ def run_vitest_benchmarking_tests(
vitest_env["CODEFLASH_PERF_STABILITY_CHECK"] = "true" if stability_check else "false"
vitest_env["CODEFLASH_LOOP_INDEX"] = "1"
# Set test module for marker identification (use first test file as reference)
if test_files:
test_module_path = str(
test_files[0].relative_to(effective_cwd)
if test_files[0].is_relative_to(effective_cwd)
else test_files[0].name
)
vitest_env["CODEFLASH_TEST_MODULE"] = test_module_path
logger.debug(f"[VITEST-BENCH] Set CODEFLASH_TEST_MODULE={test_module_path}")
# Total timeout for the entire benchmark run
total_timeout = max(120, (target_duration_ms // 1000) + 60, timeout or 120)
logger.debug(f"Running Vitest benchmarking tests: {' '.join(vitest_cmd)}")
logger.debug(f"[VITEST-BENCH] Running Vitest benchmarking tests: {' '.join(vitest_cmd)}")
logger.debug(
f"Vitest benchmarking config: min_loops={min_loops}, max_loops={max_loops}, "
f"[VITEST-BENCH] Config: min_loops={min_loops}, max_loops={max_loops}, "
f"target_duration={target_duration_ms}ms, stability_check={stability_check}"
)
logger.debug(f"[VITEST-BENCH] Environment: CODEFLASH_PERF_LOOP_COUNT={vitest_env.get('CODEFLASH_PERF_LOOP_COUNT')}")
total_start_time = time.time()
@@ -606,7 +640,27 @@ def run_vitest_benchmarking_tests(
result = subprocess.CompletedProcess(args=vitest_cmd, returncode=-1, stdout="", stderr="Vitest not found")
wall_clock_seconds = time.time() - total_start_time
logger.debug(f"Vitest benchmarking completed in {wall_clock_seconds:.2f}s")
logger.debug(f"[VITEST-BENCH] Completed in {wall_clock_seconds:.2f}s, returncode={result.returncode}")
# Debug: Check for END markers with duration (perf test format)
if result.stdout:
import re
perf_end_pattern = re.compile(r"!######[^:]+:[^:]+:[^:]+:(\d+):[^:]+:(\d+)######!")
perf_matches = list(perf_end_pattern.finditer(result.stdout))
if perf_matches:
loop_indices = [int(m.group(1)) for m in perf_matches]
logger.debug(
f"[VITEST-BENCH] Found {len(perf_matches)} perf END markers in stdout, "
f"loop_index range: {min(loop_indices)}-{max(loop_indices)}"
)
else:
logger.debug(f"[VITEST-BENCH] No perf END markers found in stdout (len={len(result.stdout)})")
# Check if there are behavior END markers instead
behavior_end_pattern = re.compile(r"!######[^:]+:[^:]+:[^:]+:\d+:[^#]+######!")
behavior_matches = list(behavior_end_pattern.finditer(result.stdout))
if behavior_matches:
logger.debug(f"[VITEST-BENCH] Found {len(behavior_matches)} behavior END markers instead (no duration)")
return result_file_path, result
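Note: the distinction the logging above relies on, as a short sketch: a perf END marker carries a trailing numeric duration field, while a behavior END marker does not. Both marker strings below are fabricated examples shaped to match the two regexes in this hunk.

import re

perf_end = re.compile(r"!######[^:]+:[^:]+:[^:]+:(\d+):[^:]+:(\d+)######!")
behavior_end = re.compile(r"!######[^:]+:[^:]+:[^:]+:\d+:[^#]+######!")

perf_marker = "!######tests/foo.test:adds:add:3:7_1:91022######!"   # loopIndex=3, duration=91022ns
behavior_marker = "!######tests/foo.test:adds:add:1:7_1######!"      # no duration field

assert perf_end.search(perf_marker) and not perf_end.search(behavior_marker)
assert behavior_end.search(behavior_marker)  # behavior pattern is broader, so check perf first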

View file

@@ -2368,6 +2368,12 @@ class FunctionOptimizer:
)
console.rule()
with progress_bar("Running performance benchmarks..."):
logger.debug(
f"[BENCHMARK-START] Starting benchmarking tests with {len(self.test_files.test_files)} test files"
)
for idx, tf in enumerate(self.test_files.test_files):
logger.debug(f"[BENCHMARK-FILES] Test file {idx}: perf_file={tf.benchmarking_file_path}")
if self.function_to_optimize.is_async and is_python():
from codeflash.code_utils.instrument_existing_tests import add_async_decorator_to_function
@@ -2385,6 +2391,7 @@ class FunctionOptimizer:
enable_coverage=False,
code_context=code_context,
)
logger.debug(f"[BENCHMARK-DONE] Got {len(benchmarking_results.test_results)} benchmark results")
finally:
if self.function_to_optimize.is_async:
self.write_code_and_helpers(

View file

@@ -325,6 +325,7 @@ def run_benchmarking_tests(
pytest_max_loops: int = 100_000,
js_project_root: Path | None = None,
) -> tuple[Path, subprocess.CompletedProcess]:
logger.debug(f"run_benchmarking_tests called: framework={test_framework}, num_files={len(test_paths.test_files)}")
# Check if there's a language support for this test framework that implements run_benchmarking_tests
language_support = get_language_support_by_framework(test_framework)
if language_support is not None and hasattr(language_support, "run_benchmarking_tests"):
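Note: the dispatch this hunk logs, reduced to a standalone sketch; the support class below is a simplified stand-in, not the project's actual language-support implementation.

from pathlib import Path

class _JsSupport:
    # Simplified stand-in for a language support object that owns benchmarking.
    def run_benchmarking_tests(self, test_paths, **kwargs):
        return Path("results.xml"), None

_SUPPORTS = {"jest": _JsSupport(), "vitest": _JsSupport()}

def get_language_support_by_framework(framework):
    return _SUPPORTS.get(framework)

def run_benchmarking_tests(test_paths, test_framework, **kwargs):
    # Prefer a framework-specific runner when one implements benchmarking...
    support = get_language_support_by_framework(test_framework)
    if support is not None and hasattr(support, "run_benchmarking_tests"):
        return support.run_benchmarking_tests(test_paths, **kwargs)
    # ...otherwise fall through to the default (pytest-style) runner.
    raise NotImplementedError(f"no benchmarking runner for {test_framework}")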

View file

@@ -79,8 +79,12 @@ def generate_tests(
if is_javascript():
from codeflash.languages.javascript.instrument import (
TestingMode,
fix_import_path_for_test_location,
fix_imports_inside_test_blocks,
fix_jest_mock_paths,
instrument_generated_js_test,
validate_and_fix_import_style,
)
@@ -91,12 +95,19 @@ def generate_tests(
source_file = Path(function_to_optimize.file_path)
# Fix import statements that appear inside test blocks (invalid JS syntax)
generated_test_source = fix_imports_inside_test_blocks(generated_test_source)
# Fix relative paths in jest.mock() calls
generated_test_source = fix_jest_mock_paths(
generated_test_source, test_path, source_file, test_cfg.tests_project_rootdir
)
# Fix import paths to be relative to test file location
# AI may generate imports like 'apps/web/app/file' instead of '../../app/file'
generated_test_source = fix_import_path_for_test_location(
generated_test_source, source_file, test_path, module_path
)
# Validate and fix import styles (default vs named exports)
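Note: with both branches' fixes kept, the post-processing of a generated JS test reads as a simple pipeline; the sketch below chains hypothetical no-op stand-ins in the same order (the real functions are imported from codeflash.languages.javascript.instrument above).

from pathlib import Path

# Hypothetical no-op stand-ins, signatures mirroring the calls above.
def fix_imports_inside_test_blocks(src: str) -> str: return src
def fix_jest_mock_paths(src: str, test_path: Path, source_file: Path, root: Path) -> str: return src
def fix_import_path_for_test_location(src: str, source_file: Path, test_path: Path, module_root: Path) -> str: return src
def validate_and_fix_import_style(src: str, source_file: Path, function_name: str) -> str: return src

def postprocess(src: str, source_file: Path, test_path: Path, module_root: Path, function_name: str) -> str:
    src = fix_imports_inside_test_blocks(src)                                          # drop imports nested in test blocks
    src = fix_jest_mock_paths(src, test_path, source_file, module_root)                # correct jest.mock() relative paths
    src = fix_import_path_for_test_location(src, source_file, test_path, module_root)  # re-root module imports
    return validate_and_fix_import_style(src, source_file, function_name)              # default vs named exports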

View file

@@ -992,7 +992,7 @@ function setTestName(name) {
resetInvocationCounters();
}
// Jest lifecycle hooks - these run automatically when this module is imported
// Jest/Vitest lifecycle hooks - these run automatically when this module is imported
if (typeof beforeEach !== 'undefined') {
beforeEach(() => {
// Get current test name and path from Jest's expect state
@@ -1007,6 +1007,17 @@ if (typeof beforeEach !== 'undefined') {
}
// Reset invocation counters for each test
resetInvocationCounters();
// For Vitest (no external loop-runner), reset perf state for each test
// so each test gets its own time budget for internal looping.
// For Jest with loop-runner, CODEFLASH_PERF_CURRENT_BATCH is set,
// and we want shared state across the test file.
const hasExternalLoopRunner = process.env.CODEFLASH_PERF_CURRENT_BATCH !== undefined;
if (!hasExternalLoopRunner) {
resetPerfState();
// Also reset invocation loop counts so each test starts fresh
sharedPerfState.invocationLoopCounts = {};
}
});
}

View file

@@ -950,3 +950,203 @@ def test_filter_functions_non_overlapping_tests_root():
# Strict check: exactly 2 functions remaining
assert count == 2, f"Expected exactly 2 functions, got {count}"
def test_filter_functions_project_inside_tests_folder():
"""Test that source files are not filtered when project is inside a folder named 'tests'.
This is a critical regression test for projects located at paths like:
- /home/user/tests/myproject/
- /Users/dev/tests/n8n/
The fix ensures that directory pattern matching (e.g., /tests/) is only checked
on the relative path from project_root, not on the full absolute path.
"""
with tempfile.TemporaryDirectory() as outer_temp_dir_str:
outer_temp_dir = Path(outer_temp_dir_str)
# Create a "tests" folder to simulate /home/user/tests/
tests_parent_folder = outer_temp_dir / "tests"
tests_parent_folder.mkdir()
# Create project inside the "tests" folder - simulates /home/user/tests/myproject/
project_dir = tests_parent_folder / "myproject"
project_dir.mkdir()
# Create source file inside the project
src_dir = project_dir / "src"
src_dir.mkdir()
source_file = src_dir / "utils.py"
with source_file.open("w") as f:
f.write("""
def deep_copy(obj):
\"\"\"Deep copy an object.\"\"\"
import copy
return copy.deepcopy(obj)
def compare_values(a, b):
\"\"\"Compare two values.\"\"\"
return a == b
""")
# Create another source file directly in project root
root_source_file = project_dir / "main.py"
with root_source_file.open("w") as f:
f.write("""
def main():
\"\"\"Main entry point.\"\"\"
return 0
""")
# Create actual test files that should be filtered
project_tests_dir = project_dir / "test"
project_tests_dir.mkdir()
test_file = project_tests_dir / "test_utils.py"
with test_file.open("w") as f:
f.write("""
def test_deep_copy():
return True
""")
# Discover functions
all_functions = {}
for file_path in [source_file, root_source_file, test_file]:
discovered = find_all_functions_in_file(file_path)
all_functions.update(discovered)
# Test: project at /outer/tests/myproject with tests_root overlapping
# This simulates: /home/user/tests/n8n with tests_root = /home/user/tests/n8n
with unittest.mock.patch(
"codeflash.discovery.functions_to_optimize.get_blocklisted_functions", return_value={}
):
filtered, count = filter_functions(
all_functions,
tests_root=project_dir, # Same as project_root (overlapping)
ignore_paths=[],
project_root=project_dir, # /outer/tests/myproject
module_root=project_dir,
)
# Strict check: source files should NOT be filtered even though
# the full path contains "/tests/" in the parent directory
expected_files = {source_file, root_source_file}
actual_files = set(filtered.keys())
assert actual_files == expected_files, (
f"Source files were incorrectly filtered when project is inside 'tests' folder.\n"
f"Expected files: {expected_files}\n"
f"Got files: {actual_files}\n"
f"Project path: {project_dir}\n"
f"This indicates the /tests/ pattern matched the parent directory path."
)
# Verify the correct functions are present
source_functions = sorted([fn.function_name for fn in filtered.get(source_file, [])])
assert source_functions == ["compare_values", "deep_copy"], (
f"Expected ['compare_values', 'deep_copy'], got {source_functions}"
)
root_functions = [fn.function_name for fn in filtered.get(root_source_file, [])]
assert root_functions == ["main"], (
f"Expected ['main'], got {root_functions}"
)
# Strict check: exactly 3 functions (2 from utils.py + 1 from main.py)
assert count == 3, (
f"Expected exactly 3 functions, got {count}. "
f"Some source files may have been incorrectly filtered."
)
# Verify test file was properly filtered (should not be in results)
assert test_file not in filtered, (
f"Test file {test_file} should have been filtered but wasn't"
)
def test_filter_functions_typescript_project_in_tests_folder():
"""Test TypeScript-like project structure inside a folder named 'tests'.
This simulates the n8n project structure:
/home/user/tests/n8n/packages/workflow/src/utils.ts
Ensures that TypeScript source files are not incorrectly filtered
when the parent directory happens to be named 'tests'.
"""
with tempfile.TemporaryDirectory() as outer_temp_dir_str:
outer_temp_dir = Path(outer_temp_dir_str)
# Simulate: /home/user/tests/n8n
tests_folder = outer_temp_dir / "tests"
tests_folder.mkdir()
n8n_project = tests_folder / "n8n"
n8n_project.mkdir()
# Simulate: packages/workflow/src/utils.py (using .py for testing)
packages_dir = n8n_project / "packages"
packages_dir.mkdir()
workflow_dir = packages_dir / "workflow"
workflow_dir.mkdir()
src_dir = workflow_dir / "src"
src_dir.mkdir()
# Source file deep in the monorepo structure
utils_file = src_dir / "utils.py"
with utils_file.open("w") as f:
f.write("""
def deep_copy(source):
\"\"\"Create a deep copy of the source object.\"\"\"
if source is None:
return None
return source.copy() if hasattr(source, 'copy') else source
def is_object_empty(obj):
\"\"\"Check if an object is empty.\"\"\"
return len(obj) == 0 if obj else True
""")
# Create test directory inside the package (simulating packages/workflow/test/)
test_dir = workflow_dir / "test"
test_dir.mkdir()
test_file = test_dir / "utils.test.py"
with test_file.open("w") as f:
f.write("""
def test_deep_copy():
return True
def test_is_object_empty():
return True
""")
# Discover functions
all_functions = {}
for file_path in [utils_file, test_file]:
discovered = find_all_functions_in_file(file_path)
all_functions.update(discovered)
# Test with module_root = packages (typical TypeScript monorepo setup)
with unittest.mock.patch(
"codeflash.discovery.functions_to_optimize.get_blocklisted_functions", return_value={}
):
filtered, count = filter_functions(
all_functions,
tests_root=packages_dir, # Overlapping with module_root
ignore_paths=[],
project_root=n8n_project, # /outer/tests/n8n
module_root=packages_dir, # /outer/tests/n8n/packages
)
# Strict check: only the source file should remain
assert set(filtered.keys()) == {utils_file}, (
f"Expected only {utils_file} but got {set(filtered.keys())}.\n"
f"Source files in /outer/tests/n8n/packages/workflow/src/ were incorrectly filtered.\n"
f"The /tests/ pattern in the parent path should not affect filtering."
)
# Verify the correct functions are present
filtered_functions = sorted([fn.function_name for fn in filtered.get(utils_file, [])])
assert filtered_functions == ["deep_copy", "is_object_empty"], (
f"Expected ['deep_copy', 'is_object_empty'], got {filtered_functions}"
)
# Strict check: exactly 2 functions
assert count == 2, f"Expected exactly 2 functions, got {count}"
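Note: a minimal sketch of the behavior these two regression tests pin down, using an illustrative helper (not the actual filter_functions implementation): the test-directory pattern is matched against the path relative to project_root, so a parent directory that merely happens to be named 'tests' does not exclude source files.

from pathlib import Path

def looks_like_test_file(file_path: Path, project_root: Path) -> bool:
    # Match the test-directory pattern only on the project-relative path,
    # never on the absolute path, so /home/user/tests/myproject/src/utils.py survives.
    rel_parts = file_path.relative_to(project_root).parts
    return any(part in ("test", "tests") for part in rel_parts[:-1])

project_root = Path("/home/user/tests/myproject")
assert not looks_like_test_file(project_root / "src" / "utils.py", project_root)
assert looks_like_test_file(project_root / "test" / "test_utils.py", project_root)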