From cee12fe430d170f0c8b0091647f6909098dc8eb7 Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Sun, 15 Mar 2026 23:29:35 -0600 Subject: [PATCH] fix ranking boost ordering and statement helper extraction --- codeflash/languages/python/reference_graph.py | 18 ++- codeflash/optimization/optimizer.py | 20 ++- tests/test_call_graph.py | 20 +++ tests/test_ranking_boost.py | 141 ++++++++++++++++++ 4 files changed, 191 insertions(+), 8 deletions(-) diff --git a/codeflash/languages/python/reference_graph.py b/codeflash/languages/python/reference_graph.py index f01185bae..949d1c10b 100644 --- a/codeflash/languages/python/reference_graph.py +++ b/codeflash/languages/python/reference_graph.py @@ -42,7 +42,7 @@ def _init_index_worker(project_root: str) -> None: def _resolve_definitions(ref: Name) -> list[Name]: try: inferred = ref.infer() - valid = [d for d in inferred if d.type in ("function", "class")] + valid = [d for d in inferred if d.type in ("function", "class", "statement")] if valid: return valid except Exception: @@ -69,7 +69,7 @@ def _is_valid_definition(definition: Name, caller_qualified_name: str, project_r if not definition.full_name or not definition.full_name.startswith(definition.module_name): return False - if definition.type not in ("function", "class"): + if definition.type not in ("function", "class", "statement"): return False try: @@ -164,6 +164,20 @@ def _analyze_file(file_path: Path, jedi_project: object, project_root_str: str) definition.get_line_code(), ) ) + elif definition.type == "statement": + callee_qn = get_qualified_name(definition.module_name, definition.full_name) + if len(callee_qn.split(".")) > 2: + continue + edges.add( + ( + *edge_base, + callee_qn, + definition.full_name, + definition.name, + definition.type, + definition.get_line_code(), + ) + ) except Exception: continue diff --git a/codeflash/optimization/optimizer.py b/codeflash/optimization/optimizer.py index 81e280c30..deb0911ba 100644 --- a/codeflash/optimization/optimizer.py +++ b/codeflash/optimization/optimizer.py @@ -382,22 +382,30 @@ class Optimizer: # Use a tuple of unique identifiers as the key key: tuple[Path, str, int | None] = (func.file_path, func.qualified_name, func.starting_line) func_to_file_map[key] = file_path - globally_ranked = [] - for func in ranked_functions: + ranked_with_metadata: list[tuple[Path, FunctionToOptimize, float, int]] = [] + for rank_index, func in enumerate(ranked_functions): key = (func.file_path, func.qualified_name, func.starting_line) file_path = func_to_file_map.get(key) if file_path: - globally_ranked.append((file_path, func)) + ranked_with_metadata.append( + (file_path, func, ranker.get_function_addressable_time(func), rank_index) + ) if function_to_tests: from codeflash.discovery.discover_unit_tests import existing_unit_test_count - globally_ranked.sort( + ranked_with_metadata.sort( key=lambda item: ( - 0 if existing_unit_test_count(item[1], self.args.project_root, function_to_tests) > 0 else 1 + -item[2], + -existing_unit_test_count(item[1], self.args.project_root, function_to_tests), + item[3], ) ) + globally_ranked = [ + (file_path, func) for file_path, func, _addressable_time, _rank_index in ranked_with_metadata + ] + console.rule() logger.info( f"Globally ranked {len(ranked_functions)} functions by addressable time " @@ -433,8 +441,8 @@ class Optimizer: ranked = sorted( enumerate(all_functions), key=lambda x: ( - -existing_unit_test_count(x[1][1], self.args.project_root, function_to_tests), -callee_counts.get((x[1][0], x[1][1].qualified_name), 0), + -existing_unit_test_count(x[1][1], self.args.project_root, function_to_tests), x[0], ), ) diff --git a/tests/test_call_graph.py b/tests/test_call_graph.py index a717e20bb..8b038d13e 100644 --- a/tests/test_call_graph.py +++ b/tests/test_call_graph.py @@ -430,6 +430,26 @@ def caller(): finally: cg.close() + def test_get_callees_includes_statement_dependencies(self, project: Path, db_path: Path) -> None: + write_file( + project, + "mod.py", + """\ +X = 1 + +def caller(): + return X + 1 +""", + ) + cg = ReferenceGraph(project, db_path=db_path) + try: + _, function_sources = cg.get_callees({project / "mod.py": {"caller"}}) + assert [(source.qualified_name, source.definition_type) for source in function_sources] == [ + ("X", "statement") + ] + finally: + cg.close() + # --------------------------------------------------------------------------- # CalleeMetadata unit tests diff --git a/tests/test_ranking_boost.py b/tests/test_ranking_boost.py index 38f7720c3..01938f30d 100644 --- a/tests/test_ranking_boost.py +++ b/tests/test_ranking_boost.py @@ -1,6 +1,8 @@ from __future__ import annotations +from argparse import Namespace from pathlib import Path +from unittest.mock import patch import pytest @@ -8,6 +10,7 @@ from codeflash.discovery.discover_unit_tests import existing_unit_test_count from codeflash.models.function_types import FunctionToOptimize from codeflash.models.models import CodePosition, FunctionCalledInTest, TestsInFile from codeflash.models.test_type import TestType +from codeflash.optimization.optimizer import Optimizer def make_func(name: str, project_root: Path) -> FunctionToOptimize: @@ -23,6 +26,16 @@ def make_test(test_type: TestType, test_name: str = "test_something") -> Functio ) +def make_optimizer(project_root: Path) -> Optimizer: + def _noop_display_global_ranking(*_args: object, **_kwargs: object) -> None: + return None + + optimizer = Optimizer.__new__(Optimizer) + optimizer.args = Namespace(project_root=project_root) + optimizer.display_global_ranking = _noop_display_global_ranking + return optimizer + + @pytest.fixture def project_root(tmp_path: Path) -> Path: root = tmp_path / "project" @@ -134,3 +147,131 @@ def test_parametrized_tests_deduplication(project_root: Path) -> None: } } assert existing_unit_test_count(func, project_root, tests) == 2 + + +def test_trace_ranking_keeps_addressable_time_primary_over_test_count(project_root: Path, tmp_path: Path) -> None: + optimizer = make_optimizer(project_root) + funcs = [make_func(name, project_root) for name in ("foo", "bar", "baz")] + trace_file = tmp_path / "trace.db" + trace_file.touch() + + ranked_functions = [funcs[0], funcs[1], funcs[2]] + addressable_times = {"foo": 100.0, "bar": 20.0, "baz": 5.0} + function_to_tests: dict[str, set[FunctionCalledInTest]] = { + funcs[1].qualified_name_with_modules_from_root(project_root): { + make_test(TestType.EXISTING_UNIT_TEST, "test_one"), + make_test(TestType.EXISTING_UNIT_TEST, "test_two"), + make_test(TestType.EXISTING_UNIT_TEST, "test_three"), + } + } + + class FakeRanker: + def __init__(self, _trace_file: Path) -> None: + pass + + def rank_functions(self, _functions: list[FunctionToOptimize]) -> list[FunctionToOptimize]: + return ranked_functions + + def get_function_addressable_time(self, function: FunctionToOptimize) -> float: + return addressable_times[function.function_name] + + with patch("codeflash.benchmarking.function_ranker.FunctionRanker", FakeRanker): + ranked = optimizer.rank_all_functions_globally( + {project_root / "mod.py": funcs}, trace_file, function_to_tests=function_to_tests + ) + + assert [func.function_name for _, func in ranked] == ["foo", "bar", "baz"] + + +def test_trace_ranking_uses_test_count_as_tiebreaker(project_root: Path, tmp_path: Path) -> None: + optimizer = make_optimizer(project_root) + funcs = [make_func(name, project_root) for name in ("foo", "bar", "baz")] + trace_file = tmp_path / "trace.db" + trace_file.touch() + + ranked_functions = [funcs[0], funcs[1], funcs[2]] + addressable_times = {"foo": 100.0, "bar": 100.0, "baz": 5.0} + function_to_tests: dict[str, set[FunctionCalledInTest]] = { + funcs[0].qualified_name_with_modules_from_root(project_root): { + make_test(TestType.EXISTING_UNIT_TEST, "test_one") + }, + funcs[1].qualified_name_with_modules_from_root(project_root): { + make_test(TestType.EXISTING_UNIT_TEST, "test_one"), + make_test(TestType.EXISTING_UNIT_TEST, "test_two"), + make_test(TestType.EXISTING_UNIT_TEST, "test_three"), + }, + } + + class FakeRanker: + def __init__(self, _trace_file: Path) -> None: + pass + + def rank_functions(self, _functions: list[FunctionToOptimize]) -> list[FunctionToOptimize]: + return ranked_functions + + def get_function_addressable_time(self, function: FunctionToOptimize) -> float: + return addressable_times[function.function_name] + + with patch("codeflash.benchmarking.function_ranker.FunctionRanker", FakeRanker): + ranked = optimizer.rank_all_functions_globally( + {project_root / "mod.py": funcs}, trace_file, function_to_tests=function_to_tests + ) + + assert [func.function_name for _, func in ranked] == ["bar", "foo", "baz"] + + +def test_dependency_count_ranking_keeps_callee_count_primary(project_root: Path) -> None: + optimizer = make_optimizer(project_root) + funcs = [make_func(name, project_root) for name in ("foo", "bar")] + function_to_tests: dict[str, set[FunctionCalledInTest]] = { + funcs[1].qualified_name_with_modules_from_root(project_root): { + make_test(TestType.EXISTING_UNIT_TEST, "test_one"), + make_test(TestType.EXISTING_UNIT_TEST, "test_two"), + make_test(TestType.EXISTING_UNIT_TEST, "test_three"), + } + } + + class FakeResolver: + def count_callees_per_function(self, _mapping: dict[Path, set[str]]) -> dict[tuple[Path, str], int]: + return { + (project_root / "mod.py", "foo"): 5, + (project_root / "mod.py", "bar"): 1, + } + + ranked = optimizer.rank_by_dependency_count( + [(project_root / "mod.py", funcs[0]), (project_root / "mod.py", funcs[1])], + FakeResolver(), + function_to_tests=function_to_tests, + ) + + assert [func.function_name for _, func in ranked] == ["foo", "bar"] + + +def test_dependency_count_ranking_uses_test_count_as_tiebreaker(project_root: Path) -> None: + optimizer = make_optimizer(project_root) + funcs = [make_func(name, project_root) for name in ("foo", "bar")] + function_to_tests: dict[str, set[FunctionCalledInTest]] = { + funcs[0].qualified_name_with_modules_from_root(project_root): { + make_test(TestType.EXISTING_UNIT_TEST, "test_one") + }, + funcs[1].qualified_name_with_modules_from_root(project_root): { + make_test(TestType.EXISTING_UNIT_TEST, "test_one"), + make_test(TestType.EXISTING_UNIT_TEST, "test_two"), + make_test(TestType.EXISTING_UNIT_TEST, "test_three"), + }, + } + + class FakeResolver: + def count_callees_per_function(self, _mapping: dict[Path, set[str]]) -> dict[tuple[Path, str], int]: + return { + (project_root / "mod.py", "foo"): 2, + (project_root / "mod.py", "bar"): 2, + } + + ranked = optimizer.rank_by_dependency_count( + [(project_root / "mod.py", funcs[0]), (project_root / "mod.py", funcs[1])], + FakeResolver(), + function_to_tests=function_to_tests, + ) + + assert [func.function_name for _, func in ranked] == ["bar", "foo"]