allow predict to be included

This commit is contained in:
Kevin Turcios 2025-06-30 18:02:48 -07:00
parent 7f9a609890
commit eb9e0c6558
3 changed files with 998 additions and 524 deletions

View file

@ -1,4 +1,5 @@
from concurrent.futures import ThreadPoolExecutor
from time import sleep
def funcA(number):
@ -46,12 +47,20 @@ class AlexNet:
class SimpleModel:
@staticmethod
def predict(data):
return [x * 2 for x in data]
result = []
sleep(10)
for i in range(500):
for x in data:
computation = 0
computation += x * i ** 2
result.append(computation)
return result
@classmethod
def create_default(cls):
return cls()
def test_models():
model = AlexNet(num_classes=10)
input_data = [1, 2, 3, 4, 5]

View file

@ -0,0 +1,138 @@
import pytest
from pathlib import Path
from unittest.mock import patch
from codeflash.benchmarking.function_ranker import FunctionRanker
from codeflash.discovery.functions_to_optimize import FunctionToOptimize, find_all_functions_in_file
from codeflash.models.models import FunctionParent
@pytest.fixture
def trace_file():
return Path(__file__).parent.parent / "code_to_optimize/code_directories/simple_tracer_e2e/codeflash.trace.sqlite3"
@pytest.fixture
def workload_functions():
workloads_file = Path(__file__).parent.parent / "code_to_optimize/code_directories/simple_tracer_e2e/workload.py"
functions_dict = find_all_functions_in_file(workloads_file)
all_functions = []
for functions_list in functions_dict.values():
all_functions.extend(functions_list)
return all_functions
@pytest.fixture
def function_ranker(trace_file):
return FunctionRanker(trace_file)
def test_function_ranker_initialization(trace_file):
ranker = FunctionRanker(trace_file)
assert ranker.trace_file_path == trace_file
assert ranker._profile_stats is not None
assert isinstance(ranker._function_stats, dict)
def test_load_function_stats(function_ranker):
assert len(function_ranker._function_stats) > 0
# Check that funcA is loaded with expected structure
func_a_key = None
for key, stats in function_ranker._function_stats.items():
if stats["function_name"] == "funcA":
func_a_key = key
break
assert func_a_key is not None
func_a_stats = function_ranker._function_stats[func_a_key]
# Verify funcA stats structure
expected_keys = {
"filename", "function_name", "qualified_name", "class_name",
"line_number", "call_count", "own_time_ns", "cumulative_time_ns",
"time_in_callees_ns", "ttx_score"
}
assert set(func_a_stats.keys()) == expected_keys
# Verify funcA specific values
assert func_a_stats["function_name"] == "funcA"
assert func_a_stats["call_count"] == 1
assert func_a_stats["own_time_ns"] == 27000
assert func_a_stats["cumulative_time_ns"] == 1629000
def test_get_function_ttx_score(function_ranker, workload_functions):
func_a = None
for func in workload_functions:
if func.function_name == "funcA":
func_a = func
break
assert func_a is not None
ttx_score = function_ranker.get_function_ttx_score(func_a)
# Expected ttX score: own_time + (time_in_callees * call_count)
# = 27000 + ((1629000 - 27000) * 1) = 1629000
assert ttx_score == 1629000
def test_rank_functions(function_ranker, workload_functions):
ranked_functions = function_ranker.rank_functions(workload_functions)
assert len(ranked_functions) == len(workload_functions)
# Verify functions are sorted by ttX score in descending order
for i in range(len(ranked_functions) - 1):
current_score = function_ranker.get_function_ttx_score(ranked_functions[i])
next_score = function_ranker.get_function_ttx_score(ranked_functions[i + 1])
assert current_score >= next_score
def test_rerank_and_filter_functions(function_ranker, workload_functions):
filtered_ranked = function_ranker.rerank_and_filter_functions(workload_functions)
# Should filter out functions below importance threshold
assert len(filtered_ranked) <= len(workload_functions)
# funcA should pass the importance threshold (0.33% > 0.1%)
func_a_in_results = any(f.function_name == "funcA" for f in filtered_ranked)
assert func_a_in_results
def test_get_function_stats_summary(function_ranker, workload_functions):
func_a = None
for func in workload_functions:
if func.function_name == "funcA":
func_a = func
break
assert func_a is not None
stats = function_ranker.get_function_stats_summary(func_a)
assert stats is not None
assert stats["function_name"] == "funcA"
assert stats["own_time_ns"] == 27000
assert stats["cumulative_time_ns"] == 1629000
assert stats["ttx_score"] == 1629000
def test_importance_calculation(function_ranker):
total_program_time = sum(
s["own_time_ns"] for s in function_ranker._function_stats.values()
if s.get("own_time_ns", 0) > 0
)
func_a_stats = None
for stats in function_ranker._function_stats.values():
if stats["function_name"] == "funcA":
func_a_stats = stats
break
assert func_a_stats is not None
importance = func_a_stats["own_time_ns"] / total_program_time
# funcA importance should be approximately 0.33% (27000/8242000)
assert abs(importance - 0.00327) < 0.001

1373
uv.lock

File diff suppressed because it is too large Load diff