mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
fix ranking boost ordering and statement helper extraction
This commit is contained in:
parent
ae188ec632
commit
cee12fe430
4 changed files with 191 additions and 8 deletions
|
|
@ -42,7 +42,7 @@ def _init_index_worker(project_root: str) -> None:
|
|||
def _resolve_definitions(ref: Name) -> list[Name]:
|
||||
try:
|
||||
inferred = ref.infer()
|
||||
valid = [d for d in inferred if d.type in ("function", "class")]
|
||||
valid = [d for d in inferred if d.type in ("function", "class", "statement")]
|
||||
if valid:
|
||||
return valid
|
||||
except Exception:
|
||||
|
|
@ -69,7 +69,7 @@ def _is_valid_definition(definition: Name, caller_qualified_name: str, project_r
|
|||
if not definition.full_name or not definition.full_name.startswith(definition.module_name):
|
||||
return False
|
||||
|
||||
if definition.type not in ("function", "class"):
|
||||
if definition.type not in ("function", "class", "statement"):
|
||||
return False
|
||||
|
||||
try:
|
||||
|
|
@ -164,6 +164,20 @@ def _analyze_file(file_path: Path, jedi_project: object, project_root_str: str)
|
|||
definition.get_line_code(),
|
||||
)
|
||||
)
|
||||
elif definition.type == "statement":
|
||||
callee_qn = get_qualified_name(definition.module_name, definition.full_name)
|
||||
if len(callee_qn.split(".")) > 2:
|
||||
continue
|
||||
edges.add(
|
||||
(
|
||||
*edge_base,
|
||||
callee_qn,
|
||||
definition.full_name,
|
||||
definition.name,
|
||||
definition.type,
|
||||
definition.get_line_code(),
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
|
|
|
|||
|
|
@ -382,22 +382,30 @@ class Optimizer:
|
|||
# Use a tuple of unique identifiers as the key
|
||||
key: tuple[Path, str, int | None] = (func.file_path, func.qualified_name, func.starting_line)
|
||||
func_to_file_map[key] = file_path
|
||||
globally_ranked = []
|
||||
for func in ranked_functions:
|
||||
ranked_with_metadata: list[tuple[Path, FunctionToOptimize, float, int]] = []
|
||||
for rank_index, func in enumerate(ranked_functions):
|
||||
key = (func.file_path, func.qualified_name, func.starting_line)
|
||||
file_path = func_to_file_map.get(key)
|
||||
if file_path:
|
||||
globally_ranked.append((file_path, func))
|
||||
ranked_with_metadata.append(
|
||||
(file_path, func, ranker.get_function_addressable_time(func), rank_index)
|
||||
)
|
||||
|
||||
if function_to_tests:
|
||||
from codeflash.discovery.discover_unit_tests import existing_unit_test_count
|
||||
|
||||
globally_ranked.sort(
|
||||
ranked_with_metadata.sort(
|
||||
key=lambda item: (
|
||||
0 if existing_unit_test_count(item[1], self.args.project_root, function_to_tests) > 0 else 1
|
||||
-item[2],
|
||||
-existing_unit_test_count(item[1], self.args.project_root, function_to_tests),
|
||||
item[3],
|
||||
)
|
||||
)
|
||||
|
||||
globally_ranked = [
|
||||
(file_path, func) for file_path, func, _addressable_time, _rank_index in ranked_with_metadata
|
||||
]
|
||||
|
||||
console.rule()
|
||||
logger.info(
|
||||
f"Globally ranked {len(ranked_functions)} functions by addressable time "
|
||||
|
|
@ -433,8 +441,8 @@ class Optimizer:
|
|||
ranked = sorted(
|
||||
enumerate(all_functions),
|
||||
key=lambda x: (
|
||||
-existing_unit_test_count(x[1][1], self.args.project_root, function_to_tests),
|
||||
-callee_counts.get((x[1][0], x[1][1].qualified_name), 0),
|
||||
-existing_unit_test_count(x[1][1], self.args.project_root, function_to_tests),
|
||||
x[0],
|
||||
),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -430,6 +430,26 @@ def caller():
|
|||
finally:
|
||||
cg.close()
|
||||
|
||||
def test_get_callees_includes_statement_dependencies(self, project: Path, db_path: Path) -> None:
|
||||
write_file(
|
||||
project,
|
||||
"mod.py",
|
||||
"""\
|
||||
X = 1
|
||||
|
||||
def caller():
|
||||
return X + 1
|
||||
""",
|
||||
)
|
||||
cg = ReferenceGraph(project, db_path=db_path)
|
||||
try:
|
||||
_, function_sources = cg.get_callees({project / "mod.py": {"caller"}})
|
||||
assert [(source.qualified_name, source.definition_type) for source in function_sources] == [
|
||||
("X", "statement")
|
||||
]
|
||||
finally:
|
||||
cg.close()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CalleeMetadata unit tests
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from argparse import Namespace
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
|
@ -8,6 +10,7 @@ from codeflash.discovery.discover_unit_tests import existing_unit_test_count
|
|||
from codeflash.models.function_types import FunctionToOptimize
|
||||
from codeflash.models.models import CodePosition, FunctionCalledInTest, TestsInFile
|
||||
from codeflash.models.test_type import TestType
|
||||
from codeflash.optimization.optimizer import Optimizer
|
||||
|
||||
|
||||
def make_func(name: str, project_root: Path) -> FunctionToOptimize:
|
||||
|
|
@ -23,6 +26,16 @@ def make_test(test_type: TestType, test_name: str = "test_something") -> Functio
|
|||
)
|
||||
|
||||
|
||||
def make_optimizer(project_root: Path) -> Optimizer:
|
||||
def _noop_display_global_ranking(*_args: object, **_kwargs: object) -> None:
|
||||
return None
|
||||
|
||||
optimizer = Optimizer.__new__(Optimizer)
|
||||
optimizer.args = Namespace(project_root=project_root)
|
||||
optimizer.display_global_ranking = _noop_display_global_ranking
|
||||
return optimizer
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def project_root(tmp_path: Path) -> Path:
|
||||
root = tmp_path / "project"
|
||||
|
|
@ -134,3 +147,131 @@ def test_parametrized_tests_deduplication(project_root: Path) -> None:
|
|||
}
|
||||
}
|
||||
assert existing_unit_test_count(func, project_root, tests) == 2
|
||||
|
||||
|
||||
def test_trace_ranking_keeps_addressable_time_primary_over_test_count(project_root: Path, tmp_path: Path) -> None:
|
||||
optimizer = make_optimizer(project_root)
|
||||
funcs = [make_func(name, project_root) for name in ("foo", "bar", "baz")]
|
||||
trace_file = tmp_path / "trace.db"
|
||||
trace_file.touch()
|
||||
|
||||
ranked_functions = [funcs[0], funcs[1], funcs[2]]
|
||||
addressable_times = {"foo": 100.0, "bar": 20.0, "baz": 5.0}
|
||||
function_to_tests: dict[str, set[FunctionCalledInTest]] = {
|
||||
funcs[1].qualified_name_with_modules_from_root(project_root): {
|
||||
make_test(TestType.EXISTING_UNIT_TEST, "test_one"),
|
||||
make_test(TestType.EXISTING_UNIT_TEST, "test_two"),
|
||||
make_test(TestType.EXISTING_UNIT_TEST, "test_three"),
|
||||
}
|
||||
}
|
||||
|
||||
class FakeRanker:
|
||||
def __init__(self, _trace_file: Path) -> None:
|
||||
pass
|
||||
|
||||
def rank_functions(self, _functions: list[FunctionToOptimize]) -> list[FunctionToOptimize]:
|
||||
return ranked_functions
|
||||
|
||||
def get_function_addressable_time(self, function: FunctionToOptimize) -> float:
|
||||
return addressable_times[function.function_name]
|
||||
|
||||
with patch("codeflash.benchmarking.function_ranker.FunctionRanker", FakeRanker):
|
||||
ranked = optimizer.rank_all_functions_globally(
|
||||
{project_root / "mod.py": funcs}, trace_file, function_to_tests=function_to_tests
|
||||
)
|
||||
|
||||
assert [func.function_name for _, func in ranked] == ["foo", "bar", "baz"]
|
||||
|
||||
|
||||
def test_trace_ranking_uses_test_count_as_tiebreaker(project_root: Path, tmp_path: Path) -> None:
|
||||
optimizer = make_optimizer(project_root)
|
||||
funcs = [make_func(name, project_root) for name in ("foo", "bar", "baz")]
|
||||
trace_file = tmp_path / "trace.db"
|
||||
trace_file.touch()
|
||||
|
||||
ranked_functions = [funcs[0], funcs[1], funcs[2]]
|
||||
addressable_times = {"foo": 100.0, "bar": 100.0, "baz": 5.0}
|
||||
function_to_tests: dict[str, set[FunctionCalledInTest]] = {
|
||||
funcs[0].qualified_name_with_modules_from_root(project_root): {
|
||||
make_test(TestType.EXISTING_UNIT_TEST, "test_one")
|
||||
},
|
||||
funcs[1].qualified_name_with_modules_from_root(project_root): {
|
||||
make_test(TestType.EXISTING_UNIT_TEST, "test_one"),
|
||||
make_test(TestType.EXISTING_UNIT_TEST, "test_two"),
|
||||
make_test(TestType.EXISTING_UNIT_TEST, "test_three"),
|
||||
},
|
||||
}
|
||||
|
||||
class FakeRanker:
|
||||
def __init__(self, _trace_file: Path) -> None:
|
||||
pass
|
||||
|
||||
def rank_functions(self, _functions: list[FunctionToOptimize]) -> list[FunctionToOptimize]:
|
||||
return ranked_functions
|
||||
|
||||
def get_function_addressable_time(self, function: FunctionToOptimize) -> float:
|
||||
return addressable_times[function.function_name]
|
||||
|
||||
with patch("codeflash.benchmarking.function_ranker.FunctionRanker", FakeRanker):
|
||||
ranked = optimizer.rank_all_functions_globally(
|
||||
{project_root / "mod.py": funcs}, trace_file, function_to_tests=function_to_tests
|
||||
)
|
||||
|
||||
assert [func.function_name for _, func in ranked] == ["bar", "foo", "baz"]
|
||||
|
||||
|
||||
def test_dependency_count_ranking_keeps_callee_count_primary(project_root: Path) -> None:
|
||||
optimizer = make_optimizer(project_root)
|
||||
funcs = [make_func(name, project_root) for name in ("foo", "bar")]
|
||||
function_to_tests: dict[str, set[FunctionCalledInTest]] = {
|
||||
funcs[1].qualified_name_with_modules_from_root(project_root): {
|
||||
make_test(TestType.EXISTING_UNIT_TEST, "test_one"),
|
||||
make_test(TestType.EXISTING_UNIT_TEST, "test_two"),
|
||||
make_test(TestType.EXISTING_UNIT_TEST, "test_three"),
|
||||
}
|
||||
}
|
||||
|
||||
class FakeResolver:
|
||||
def count_callees_per_function(self, _mapping: dict[Path, set[str]]) -> dict[tuple[Path, str], int]:
|
||||
return {
|
||||
(project_root / "mod.py", "foo"): 5,
|
||||
(project_root / "mod.py", "bar"): 1,
|
||||
}
|
||||
|
||||
ranked = optimizer.rank_by_dependency_count(
|
||||
[(project_root / "mod.py", funcs[0]), (project_root / "mod.py", funcs[1])],
|
||||
FakeResolver(),
|
||||
function_to_tests=function_to_tests,
|
||||
)
|
||||
|
||||
assert [func.function_name for _, func in ranked] == ["foo", "bar"]
|
||||
|
||||
|
||||
def test_dependency_count_ranking_uses_test_count_as_tiebreaker(project_root: Path) -> None:
|
||||
optimizer = make_optimizer(project_root)
|
||||
funcs = [make_func(name, project_root) for name in ("foo", "bar")]
|
||||
function_to_tests: dict[str, set[FunctionCalledInTest]] = {
|
||||
funcs[0].qualified_name_with_modules_from_root(project_root): {
|
||||
make_test(TestType.EXISTING_UNIT_TEST, "test_one")
|
||||
},
|
||||
funcs[1].qualified_name_with_modules_from_root(project_root): {
|
||||
make_test(TestType.EXISTING_UNIT_TEST, "test_one"),
|
||||
make_test(TestType.EXISTING_UNIT_TEST, "test_two"),
|
||||
make_test(TestType.EXISTING_UNIT_TEST, "test_three"),
|
||||
},
|
||||
}
|
||||
|
||||
class FakeResolver:
|
||||
def count_callees_per_function(self, _mapping: dict[Path, set[str]]) -> dict[tuple[Path, str], int]:
|
||||
return {
|
||||
(project_root / "mod.py", "foo"): 2,
|
||||
(project_root / "mod.py", "bar"): 2,
|
||||
}
|
||||
|
||||
ranked = optimizer.rank_by_dependency_count(
|
||||
[(project_root / "mod.py", funcs[0]), (project_root / "mod.py", funcs[1])],
|
||||
FakeResolver(),
|
||||
function_to_tests=function_to_tests,
|
||||
)
|
||||
|
||||
assert [func.function_name for _, func in ranked] == ["bar", "foo"]
|
||||
|
|
|
|||
Loading…
Reference in a new issue