fix: deduplicate test count calls, guard None, and log effort escalation

Build test_count_cache once before ranking instead of calling
existing_unit_test_count twice per function (2N calls). Guard against a
None function_to_tests and add debug logging when effort is escalated
from medium to high.
This commit is contained in:
Kevin Turcios 2026-03-16 14:41:55 -06:00
parent 1d1d183075
commit 2cafadb980
2 changed files with 43 additions and 36 deletions

View file

@ -332,7 +332,7 @@ class Optimizer:
file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]],
trace_file_path: Path | None,
call_graph: DependencyResolver | None = None,
function_to_tests: dict[str, set[FunctionCalledInTest]] | None = None,
test_count_cache: dict[tuple[Path, str], int] | None = None,
) -> list[tuple[Path, FunctionToOptimize]]:
"""Rank all functions globally across all files based on trace data.
@ -355,7 +355,7 @@ class Optimizer:
# If no trace file, rank by dependency count if call graph is available
if not trace_file_path or not trace_file_path.exists():
if call_graph is not None:
return self.rank_by_dependency_count(all_functions, call_graph, function_to_tests=function_to_tests)
return self.rank_by_dependency_count(all_functions, call_graph, test_count_cache=test_count_cache)
logger.debug("No trace file available, using original function order")
return all_functions
@ -391,15 +391,9 @@ class Optimizer:
(file_path, func, ranker.get_function_addressable_time(func), rank_index)
)
if function_to_tests:
from codeflash.discovery.discover_unit_tests import existing_unit_test_count
if test_count_cache:
ranked_with_metadata.sort(
key=lambda item: (
-item[2],
-existing_unit_test_count(item[1], self.args.project_root, function_to_tests),
item[3],
)
key=lambda item: (-item[2], -test_count_cache.get((item[0], item[1].qualified_name), 0), item[3])
)
globally_ranked = [
@ -427,7 +421,7 @@ class Optimizer:
self,
all_functions: list[tuple[Path, FunctionToOptimize]],
call_graph: DependencyResolver,
function_to_tests: dict[str, set[FunctionCalledInTest]] | None = None,
test_count_cache: dict[tuple[Path, str], int] | None = None,
) -> list[tuple[Path, FunctionToOptimize]]:
file_to_qns: dict[Path, set[str]] = defaultdict(set)
for file_path, func in all_functions:
@ -435,14 +429,12 @@ class Optimizer:
callee_counts = call_graph.count_callees_per_function(dict(file_to_qns))
self._cached_callee_counts = callee_counts
if function_to_tests:
from codeflash.discovery.discover_unit_tests import existing_unit_test_count
if test_count_cache:
ranked = sorted(
enumerate(all_functions),
key=lambda x: (
-callee_counts.get((x[1][0], x[1][1].qualified_name), 0),
-existing_unit_test_count(x[1][1], self.args.project_root, function_to_tests),
-test_count_cache.get((x[1][0], x[1][1].qualified_name), 0),
x[0],
),
)
@ -531,9 +523,21 @@ class Optimizer:
if self.args.all and not self.args.subagent:
self.functions_checkpoint = CodeflashRunCheckpoint(self.args.module_root)
# Pre-compute test counts once for ranking and logging
if function_to_tests:
from codeflash.discovery.discover_unit_tests import existing_unit_test_count
test_count_cache: dict[tuple[Path, str], int] = {
(fp, fn.qualified_name): existing_unit_test_count(fn, self.args.project_root, function_to_tests)
for fp, fns in file_to_funcs_to_optimize.items()
for fn in fns
}
else:
test_count_cache: dict[tuple[Path, str], int] = {}
# GLOBAL RANKING: Rank all functions together before optimizing
globally_ranked_functions = self.rank_all_functions_globally(
file_to_funcs_to_optimize, trace_file_path, call_graph=resolver, function_to_tests=function_to_tests
file_to_funcs_to_optimize, trace_file_path, call_graph=resolver, test_count_cache=test_count_cache
)
# Cache for module preparation (avoid re-parsing same files)
prepared_modules: dict[Path, tuple[dict[Path, ValidCode], ast.Module | None]] = {}
@ -546,14 +550,6 @@ class Optimizer:
file_to_qns[fp].add(fn.qualified_name)
callee_counts = resolver.count_callees_per_function(dict(file_to_qns))
from codeflash.discovery.discover_unit_tests import existing_unit_test_count
# Pre-compute test counts for logging (already computed during ranking, avoid re-filtering)
test_count_cache: dict[tuple[Path, str], int] = {
(fp, fn.qualified_name): existing_unit_test_count(fn, self.args.project_root, function_to_tests)
for fp, fn in globally_ranked_functions
}
# Optimize functions in globally ranked order
for i, (original_module_path, function_to_optimize) in enumerate(globally_ranked_functions):
# Prepare module if not already cached
@ -578,6 +574,10 @@ class Optimizer:
effort_override: str | None = None
if i < HIGH_EFFORT_TOP_N and self.args.effort == EffortLevel.MEDIUM.value:
effort_override = EffortLevel.HIGH.value
logger.debug(
f"Escalating effort for {function_to_optimize.qualified_name} from medium to high"
f" (top {HIGH_EFFORT_TOP_N} ranked)"
)
logger.info(
f"Optimizing function {function_iterator_count} of {len(globally_ranked_functions)}: "

View file

@ -26,6 +26,15 @@ def make_test(test_type: TestType, test_name: str = "test_something") -> Functio
)
def build_test_count_cache(
    funcs: list[FunctionToOptimize], project_root: Path, function_to_tests: dict[str, set[FunctionCalledInTest]]
) -> dict[tuple[Path, str], int]:
    """Precompute the existing-unit-test count for each function.

    Returns a mapping keyed by (file_path, qualified_name) so ranking code
    can look up test counts without re-running test discovery per function.
    """
    cache: dict[tuple[Path, str], int] = {}
    for fn in funcs:
        cache[(fn.file_path, fn.qualified_name)] = existing_unit_test_count(fn, project_root, function_to_tests)
    return cache
def make_optimizer(project_root: Path) -> Optimizer:
def _noop_display_global_ranking(*_args: object, **_kwargs: object) -> None:
return None
@ -177,7 +186,9 @@ def test_trace_ranking_keeps_addressable_time_primary_over_test_count(project_ro
with patch("codeflash.benchmarking.function_ranker.FunctionRanker", FakeRanker):
ranked = optimizer.rank_all_functions_globally(
{project_root / "mod.py": funcs}, trace_file, function_to_tests=function_to_tests
{project_root / "mod.py": funcs},
trace_file,
test_count_cache=build_test_count_cache(funcs, project_root, function_to_tests),
)
assert [func.function_name for _, func in ranked] == ["foo", "bar", "baz"]
@ -214,7 +225,9 @@ def test_trace_ranking_uses_test_count_as_tiebreaker(project_root: Path, tmp_pat
with patch("codeflash.benchmarking.function_ranker.FunctionRanker", FakeRanker):
ranked = optimizer.rank_all_functions_globally(
{project_root / "mod.py": funcs}, trace_file, function_to_tests=function_to_tests
{project_root / "mod.py": funcs},
trace_file,
test_count_cache=build_test_count_cache(funcs, project_root, function_to_tests),
)
assert [func.function_name for _, func in ranked] == ["bar", "foo", "baz"]
@ -233,15 +246,12 @@ def test_dependency_count_ranking_keeps_callee_count_primary(project_root: Path)
class FakeResolver:
def count_callees_per_function(self, _mapping: dict[Path, set[str]]) -> dict[tuple[Path, str], int]:
return {
(project_root / "mod.py", "foo"): 5,
(project_root / "mod.py", "bar"): 1,
}
return {(project_root / "mod.py", "foo"): 5, (project_root / "mod.py", "bar"): 1}
ranked = optimizer.rank_by_dependency_count(
[(project_root / "mod.py", funcs[0]), (project_root / "mod.py", funcs[1])],
FakeResolver(),
function_to_tests=function_to_tests,
test_count_cache=build_test_count_cache(funcs, project_root, function_to_tests),
)
assert [func.function_name for _, func in ranked] == ["foo", "bar"]
@ -263,15 +273,12 @@ def test_dependency_count_ranking_uses_test_count_as_tiebreaker(project_root: Pa
class FakeResolver:
def count_callees_per_function(self, _mapping: dict[Path, set[str]]) -> dict[tuple[Path, str], int]:
return {
(project_root / "mod.py", "foo"): 2,
(project_root / "mod.py", "bar"): 2,
}
return {(project_root / "mod.py", "foo"): 2, (project_root / "mod.py", "bar"): 2}
ranked = optimizer.rank_by_dependency_count(
[(project_root / "mod.py", funcs[0]), (project_root / "mod.py", funcs[1])],
FakeResolver(),
function_to_tests=function_to_tests,
test_count_cache=build_test_count_cache(funcs, project_root, function_to_tests),
)
assert [func.function_name for _, func in ranked] == ["bar", "foo"]