mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
feat: add agentic multi-agent optimization pipeline
Implement a new agentic workflow for CodeFlash with a --agentic CLI flag. The pipeline uses specialized agents (Discovery, Analysis, Generation, Verification, Selection, Integration) orchestrated by an AgentCoordinator. New components: - codeflash/agents/: Agent framework with BaseAgent and specialized agents - codeflash/state/: SQLite-based state management for persistence - codeflash/optimization/agentic_optimizer.py: Entry point for agentic mode Usage: codeflash --agentic --file path/to/file.py Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
parent
6fc2c177f2
commit
84d954336f
20 changed files with 2970 additions and 2 deletions
25
codeflash/agents/__init__.py
Normal file
25
codeflash/agents/__init__.py
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from codeflash.agents.analysis_agent import AnalysisAgent
|
||||
from codeflash.agents.base import AgentMessage, AgentResult, AgentState, AgentTask, BaseAgent
|
||||
from codeflash.agents.coordinator import AgentCoordinator
|
||||
from codeflash.agents.discovery_agent import DiscoveryAgent
|
||||
from codeflash.agents.generation_agent import GenerationAgent
|
||||
from codeflash.agents.integration_agent import IntegrationAgent
|
||||
from codeflash.agents.selection_agent import SelectionAgent
|
||||
from codeflash.agents.verification_agent import VerificationAgent
|
||||
|
||||
# Public, re-exported API of the codeflash.agents package (alphabetical).
__all__ = [
    "AgentCoordinator",
    "AgentMessage",
    "AgentResult",
    "AgentState",
    "AgentTask",
    "AnalysisAgent",
    "BaseAgent",
    "DiscoveryAgent",
    "GenerationAgent",
    "IntegrationAgent",
    "SelectionAgent",
    "VerificationAgent",
]
|
||||
147
codeflash/agents/analysis_agent.py
Normal file
147
codeflash/agents/analysis_agent.py
Normal file
|
|
@ -0,0 +1,147 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from codeflash.agents.base import AgentResult, AgentTask, BaseAgent, wrap_error, wrap_result
|
||||
from codeflash.cli_cmds.console import logger
|
||||
from codeflash.either import Result
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from codeflash.agents.coordinator import AgentCoordinator
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.models.models import CodeOptimizationContext
|
||||
|
||||
|
||||
class AnalysisAgent(BaseAgent):
    """Agent that extracts optimization context and scores code complexity."""

    def __init__(self, coordinator: AgentCoordinator | None = None) -> None:
        super().__init__(agent_id="analysis", coordinator=coordinator)

    def process(self, task: AgentTask) -> Result[AgentResult[Any], str]:
        """Route the task to the handler registered for its task_type."""
        handlers = {
            "extract_context": self._extract_context,
            "assess_complexity": self._assess_complexity,
            "check_numerical": self._check_numerical,
        }
        handler = handlers.get(task.task_type)
        if handler is None:
            return wrap_error(f"Unknown task type: {task.task_type}", task.task_id, self.agent_id)
        return handler(task)

    def _extract_context(self, task: AgentTask) -> Result[AgentResult[Any], str]:
        """Extract the code-optimization context for the function in the payload.

        Requires payload keys ``function_to_optimize`` and ``project_root``.
        """
        try:
            from codeflash.context.code_context_extractor import get_code_optimization_context

            function_to_optimize: FunctionToOptimize = task.payload["function_to_optimize"]
            project_root: Path = task.payload["project_root"]

            code_context = get_code_optimization_context(function_to_optimize, project_root)

            logger.debug(f"AnalysisAgent: Extracted context for {function_to_optimize.qualified_name}")
            return wrap_result(
                {"code_context": code_context, "function_to_optimize": function_to_optimize},
                task.task_id,
                self.agent_id,
            )

        except Exception as e:
            logger.exception(f"AnalysisAgent context extraction failed: {e}")
            return wrap_error(str(e), task.task_id, self.agent_id)

    def _assess_complexity(self, task: AgentTask) -> Result[AgentResult[Any], str]:
        """Compute token/helper/file complexity metrics for an extracted context."""
        try:
            code_context: CodeOptimizationContext = task.payload["code_context"]
            metrics = self._analyze_code_complexity(code_context)
            return wrap_result(
                {"complexity": metrics, "code_context": code_context},
                task.task_id,
                self.agent_id,
            )

        except Exception as e:
            return wrap_error(str(e), task.task_id, self.agent_id)

    def _check_numerical(self, task: AgentTask) -> Result[AgentResult[Any], str]:
        """Classify whether the payload's source code is numerical code."""
        try:
            from codeflash.code_utils.code_extractor import is_numerical_code

            source_code: str = task.payload["source_code"]
            return wrap_result(
                {"is_numerical": is_numerical_code(source_code)},
                task.task_id,
                self.agent_id,
            )

        except Exception as e:
            return wrap_error(str(e), task.task_id, self.agent_id)

    def _analyze_code_complexity(self, code_context: CodeOptimizationContext) -> dict[str, Any]:
        """Measure token counts and structural size of the optimization context."""
        from codeflash.code_utils.code_utils import encoded_tokens_len

        read_writable_tokens = encoded_tokens_len(code_context.read_writable_code.markdown)
        read_only_source = code_context.read_only_context_code
        read_only_tokens = encoded_tokens_len(read_only_source) if read_only_source else 0

        total_tokens = read_writable_tokens + read_only_tokens
        helper_count = len(code_context.helper_functions)
        file_count = len(code_context.read_writable_code.code_strings)

        return {
            "total_tokens": total_tokens,
            "read_writable_tokens": read_writable_tokens,
            "read_only_tokens": read_only_tokens,
            "helper_count": helper_count,
            "file_count": file_count,
            "complexity_level": self._determine_complexity_level(total_tokens, helper_count, file_count),
        }

    def _determine_complexity_level(self, total_tokens: int, helper_count: int, file_count: int) -> str:
        """Bucket the metrics into "high"/"medium"/"low" using fixed thresholds."""
        # Ordered from strictest to loosest; the first tier whose any-threshold
        # is exceeded wins.
        tiers = (
            ("high", 10000, 10, 5),
            ("medium", 5000, 5, 2),
        )
        for label, token_cap, helper_cap, file_cap in tiers:
            if total_tokens > token_cap or helper_count > helper_cap or file_count > file_cap:
                return label
        return "low"
||||
|
||||
|
||||
def create_context_extraction_task(
    function_to_optimize: Any,
    project_root: Path,
) -> AgentTask:
    """Build the high-priority task that asks the analysis agent for code context."""
    payload = {
        "function_to_optimize": function_to_optimize,
        "project_root": project_root,
    }
    return AgentTask.create(task_type="extract_context", payload=payload, priority=8)
|
||||
|
||||
|
||||
def create_complexity_assessment_task(code_context: Any) -> AgentTask:
    """Build the task that asks the analysis agent to score a context's complexity."""
    return AgentTask.create(
        task_type="assess_complexity",
        payload={"code_context": code_context},
        priority=7,
    )
|
||||
|
||||
|
||||
def create_numerical_check_task(source_code: str) -> AgentTask:
    """Build the task that asks the analysis agent whether *source_code* is numerical."""
    return AgentTask.create(
        task_type="check_numerical",
        payload={"source_code": source_code},
        priority=7,
    )
|
||||
193
codeflash/agents/base.py
Normal file
193
codeflash/agents/base.py
Normal file
|
|
@ -0,0 +1,193 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import time
import uuid
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import Enum
from queue import Queue
from typing import TYPE_CHECKING, Any, Generic, TypeVar

from codeflash.either import Failure, Result, Success
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from codeflash.agents.coordinator import AgentCoordinator
|
||||
|
||||
|
||||
class AgentState(Enum):
    """Lifecycle states an agent moves through while executing tasks."""

    IDLE = "idle"  # registered, not currently processing
    RUNNING = "running"  # executing a task (set in BaseAgent.execute)
    WAITING = "waiting"  # declared but never set in this module -- TODO confirm intended use
    COMPLETED = "completed"  # last task's process() returned success
    FAILED = "failed"  # last task failed or raised
|
||||
|
||||
|
||||
@dataclass
class AgentTask:
    """A unit of work routed to an agent, orderable by priority."""

    task_id: str
    task_type: str
    payload: dict[str, Any]
    priority: int = 0
    parent_task_id: str | None = None

    @classmethod
    def create(
        cls,
        task_type: str,
        payload: dict[str, Any],
        priority: int = 0,
        parent_task_id: str | None = None,
    ) -> AgentTask:
        """Construct a task with a freshly generated unique id."""
        new_task_id = str(uuid.uuid4())
        return cls(new_task_id, task_type, payload, priority, parent_task_id)

    def __lt__(self, other: AgentTask) -> bool:
        # Deliberately inverted: queue.PriorityQueue is a min-heap, so a task
        # with a numerically HIGHER priority must compare as "less than" to be
        # dequeued first.
        return self.priority > other.priority
|
||||
|
||||
|
||||
@dataclass
class AgentMessage:
    """An asynchronous message delivered to an agent's inbox.

    Created via :meth:`create`, which generates a unique ``message_id``;
    ``timestamp`` is captured per-instance at construction time.
    """

    message_id: str
    sender_id: str
    message_type: str
    payload: dict[str, Any]
    # Fix: the original used `lambda: __import__("time").time()` -- an obscure
    # runtime-import hack. A plain module-level `import time` plus `time.time`
    # as the factory is equivalent and idiomatic.
    timestamp: float = field(default_factory=time.time)

    @classmethod
    def create(
        cls,
        sender_id: str,
        message_type: str,
        payload: dict[str, Any],
    ) -> AgentMessage:
        """Construct a message with a freshly generated unique id."""
        return cls(
            message_id=str(uuid.uuid4()),
            sender_id=sender_id,
            message_type=message_type,
            payload=payload,
        )
|
||||
|
||||
|
||||
T = TypeVar("T")


@dataclass
class AgentResult(Generic[T]):
    """Outcome of one task execution, tagged with the task and agent ids."""

    task_id: str
    agent_id: str
    success: bool
    data: T | None = None
    error: str | None = None
    metadata: dict[str, Any] = field(default_factory=dict)

    @classmethod
    def success_result(
        cls,
        task_id: str,
        agent_id: str,
        data: T,
        metadata: dict[str, Any] | None = None,
    ) -> AgentResult[T]:
        """Build a successful result carrying *data*."""
        extra = {} if metadata is None else metadata
        return cls(task_id=task_id, agent_id=agent_id, success=True, data=data, metadata=extra)

    @classmethod
    def failure_result(
        cls,
        task_id: str,
        agent_id: str,
        error: str,
        metadata: dict[str, Any] | None = None,
    ) -> AgentResult[T]:
        """Build a failed result carrying *error*."""
        extra = {} if metadata is None else metadata
        return cls(task_id=task_id, agent_id=agent_id, success=False, error=error, metadata=extra)
|
||||
|
||||
|
||||
class BaseAgent(ABC):
    """Shared agent machinery: lifecycle state, message inbox, and a safe execute wrapper."""

    def __init__(self, agent_id: str, coordinator: AgentCoordinator | None = None) -> None:
        self.agent_id = agent_id
        self.state = AgentState.IDLE
        self.coordinator = coordinator
        self.inbox: Queue[AgentMessage] = Queue()
        self.current_task: AgentTask | None = None

    @property
    def name(self) -> str:
        """Human-readable agent name: the concrete class name."""
        return type(self).__name__

    def set_coordinator(self, coordinator: AgentCoordinator) -> None:
        """Attach (or replace) the coordinator after construction."""
        self.coordinator = coordinator

    @abstractmethod
    def process(self, task: AgentTask) -> Result[AgentResult[Any], str]:
        """Handle one task; implemented by every concrete agent."""

    def execute(self, task: AgentTask) -> Result[AgentResult[Any], str]:
        """Run :meth:`process`, tracking lifecycle state and converting crashes to Failure."""
        self.state = AgentState.RUNNING
        self.current_task = task

        try:
            outcome = self.process(task)
            self.state = AgentState.COMPLETED if outcome.is_successful() else AgentState.FAILED
            return outcome
        except Exception as e:
            self.state = AgentState.FAILED
            self.handle_failure(str(e))
            return Failure(f"Agent {self.agent_id} failed: {e}")
        finally:
            # The task reference is cleared whether we succeeded, failed, or raised.
            self.current_task = None

    def handle_failure(self, error: str) -> None:
        """Persist a state snapshot and notify the coordinator, when one is attached."""
        if not self.coordinator:
            return
        self.coordinator.persist_agent_state(self.agent_id, self.get_state_snapshot())
        self.coordinator.notify_failure(self.agent_id, error)

    def get_state_snapshot(self) -> dict[str, Any]:
        """Serializable snapshot of this agent's current runtime state."""
        active = self.current_task
        return {
            "agent_id": self.agent_id,
            "state": self.state.value,
            "current_task": None if active is None else active.task_id,
            "inbox_size": self.inbox.qsize(),
        }

    def receive_message(self, message: AgentMessage) -> None:
        """Enqueue *message* for this agent to consume later."""
        self.inbox.put(message)

    def has_pending_messages(self) -> bool:
        """Whether at least one message is waiting in the inbox."""
        return not self.inbox.empty()

    def get_next_message(self) -> AgentMessage | None:
        """Pop the oldest pending message, or None when the inbox is empty."""
        return None if self.inbox.empty() else self.inbox.get_nowait()

    def reset(self) -> None:
        """Return to IDLE, dropping the active task and draining the inbox."""
        self.state = AgentState.IDLE
        self.current_task = None
        while not self.inbox.empty():
            self.inbox.get_nowait()
|
||||
|
||||
|
||||
def wrap_result(data: Any, task_id: str, agent_id: str) -> Result[AgentResult[Any], str]:
    """Package *data* as a successful AgentResult wrapped in Success."""
    return Success(AgentResult.success_result(task_id=task_id, agent_id=agent_id, data=data))


def wrap_error(error: str, task_id: str, agent_id: str) -> Result[AgentResult[Any], str]:
    """Package *error* as a Failure carrying only the error string.

    NOTE(review): ``task_id`` and ``agent_id`` are accepted for signature
    symmetry with :func:`wrap_result` but are currently discarded — the
    Failure payload is just the message. Confirm whether callers need the
    ids attached (e.g. via ``AgentResult.failure_result``).
    """
    return Failure(error)
|
||||
284
codeflash/agents/coordinator.py
Normal file
284
codeflash/agents/coordinator.py
Normal file
|
|
@ -0,0 +1,284 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from queue import PriorityQueue
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from codeflash.agents.base import AgentResult, AgentTask, BaseAgent
|
||||
from codeflash.cli_cmds.console import console, logger
|
||||
from codeflash.either import Failure, Result, Success
|
||||
from codeflash.state.models import AgentStateSnapshot, OptimizationAttempt, PipelineState
|
||||
from codeflash.state.store import StateStore
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from argparse import Namespace
|
||||
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.models.models import BestOptimization, CodeOptimizationContext
|
||||
from codeflash.verification.verification_utils import TestConfig
|
||||
|
||||
|
||||
class AgentCoordinator:
    """Orchestrates the specialized agents and drives the optimization pipeline.

    Owns the agent registry, a shared priority task queue, and the
    SQLite-backed state store used to persist pipeline / attempt / agent state.
    """

    def __init__(self, state_store: StateStore | None = None) -> None:
        self.agents: dict[str, BaseAgent] = {}
        self.state_store = state_store or StateStore()
        self.task_queue: PriorityQueue[AgentTask] = PriorityQueue()
        # Lazily flipped by register_agents(); pipeline entry points call it on demand.
        self._initialized = False

    def register_agents(self) -> None:
        """Instantiate and register every specialized agent.

        Imports are local to avoid circular imports between the agents package
        and this coordinator module.
        """
        from codeflash.agents.analysis_agent import AnalysisAgent
        from codeflash.agents.discovery_agent import DiscoveryAgent
        from codeflash.agents.generation_agent import GenerationAgent
        from codeflash.agents.integration_agent import IntegrationAgent
        from codeflash.agents.selection_agent import SelectionAgent
        from codeflash.agents.verification_agent import VerificationAgent

        self.agents = {
            "discovery": DiscoveryAgent(coordinator=self, state_store=self.state_store),
            "analysis": AnalysisAgent(coordinator=self),
            "generation": GenerationAgent(coordinator=self),
            "verification": VerificationAgent(coordinator=self),
            "selection": SelectionAgent(coordinator=self),
            "integration": IntegrationAgent(coordinator=self),
        }
        self._initialized = True

    def get_agent(self, agent_id: str) -> BaseAgent | None:
        """Look up a registered agent by id, or None when unknown."""
        return self.agents.get(agent_id)

    def submit_task(self, task: AgentTask) -> None:
        """Queue a task for later execution (highest priority dequeued first)."""
        self.task_queue.put(task)

    def execute_task(self, agent_id: str, task: AgentTask) -> Result[AgentResult[Any], str]:
        """Execute *task* synchronously on the named agent."""
        agent = self.agents.get(agent_id)
        if agent is None:
            return Failure(f"Agent {agent_id} not found")

        logger.debug(f"Coordinator: Executing task {task.task_type} on agent {agent_id}")
        return agent.execute(task)

    def persist_agent_state(self, agent_id: str, state_snapshot: dict[str, Any]) -> None:
        """Store an agent's state snapshot; silently ignores unknown agent ids."""
        agent = self.agents.get(agent_id)
        if agent is None:
            return

        snapshot = AgentStateSnapshot.create(
            agent_id=agent_id,
            agent_type=agent.name,
            state=state_snapshot.get("state", "unknown"),
            current_task_id=state_snapshot.get("current_task"),
            context=state_snapshot,
        )
        self.state_store.persist_agent_state(snapshot)

    def notify_failure(self, agent_id: str, error: str) -> None:
        """Record an agent failure; currently log-only."""
        logger.warning(f"Agent {agent_id} failed: {error}")

    def run_optimization_pipeline(
        self,
        function_to_optimize: FunctionToOptimize,
        test_cfg: TestConfig,
        args: Namespace,
        function_to_tests: dict[str, set[Any]] | None = None,
    ) -> Result[BestOptimization, str]:
        """Run analysis -> generation -> verification/selection -> integration.

        Persists pipeline and attempt state along the way. Returns the best
        optimization found, or a Failure describing the first fatal stage error.
        Integration (PR creation) failures are logged but do not fail the run.
        """
        if not self._initialized:
            self.register_agents()

        pipeline_id = str(uuid.uuid4())
        trace_id = str(uuid.uuid4())

        pipeline_state = PipelineState.create(
            pipeline_id=pipeline_id,
            function_qualified_name=function_to_optimize.qualified_name,
            file_path=function_to_optimize.file_path,
        )
        self.state_store.persist_pipeline_state(pipeline_state)

        attempt = OptimizationAttempt.create(
            attempt_id=pipeline_id,
            function_qualified_name=function_to_optimize.qualified_name,
            file_path=function_to_optimize.file_path,
        )
        self.state_store.persist_optimization_attempt(attempt.mark_in_progress())

        logger.info(f"Starting agentic optimization pipeline for {function_to_optimize.qualified_name}")
        console.rule()

        try:
            context_result = self._run_analysis_stage(function_to_optimize, args.project_root)
            if context_result.is_failure():
                return self._handle_pipeline_failure(attempt, context_result.failure())
            code_context = context_result.unwrap()

            candidates_result = self._run_generation_stage(code_context, trace_id, function_to_optimize)
            if candidates_result.is_failure():
                return self._handle_pipeline_failure(attempt, candidates_result.failure())
            candidates = candidates_result.unwrap()

            if not candidates:
                return self._handle_pipeline_failure(attempt, "No optimization candidates generated")

            best_result = self._run_verification_and_selection_stage(
                candidates=candidates,
                code_context=code_context,
                function_to_optimize=function_to_optimize,
                test_cfg=test_cfg,
                args=args,
                function_to_tests=function_to_tests,
                trace_id=trace_id,
            )
            if best_result.is_failure():
                return self._handle_pipeline_failure(attempt, best_result.failure())

            best_optimization = best_result.unwrap()

            # Best-effort: a failed integration stage is logged, not fatal.
            integration_result = self._run_integration_stage(
                best_optimization=best_optimization,
                function_to_optimize=function_to_optimize,
                args=args,
            )
            if integration_result.is_failure():
                logger.warning(f"Integration stage failed: {integration_result.failure()}")

            speedup = 1.0
            # BUG FIX: original_runtime was only bound inside the branch below,
            # but it is read unconditionally by mark_completed() -- a NameError
            # whenever the optimized runtime is not positive. Default it to 0.
            original_runtime = 0
            if best_optimization.runtime > 0:
                original_runtime = best_optimization.winning_benchmarking_test_results.get_best_runtime() or 0
                if original_runtime > 0:
                    speedup = original_runtime / best_optimization.runtime

            completed_attempt = attempt.mark_completed(
                speedup=speedup,
                original_runtime_ns=original_runtime,
                optimized_runtime_ns=best_optimization.runtime,
                pr_url=integration_result.unwrap().get("pr_url") if integration_result.is_successful() else None,
            )
            self.state_store.persist_optimization_attempt(completed_attempt)

            logger.info(f"Agentic optimization complete for {function_to_optimize.qualified_name}")
            return Success(best_optimization)

        except Exception as e:
            logger.exception(f"Pipeline failed: {e}")
            return self._handle_pipeline_failure(attempt, str(e))

    def _run_analysis_stage(
        self,
        function_to_optimize: FunctionToOptimize,
        project_root: Path,
    ) -> Result[CodeOptimizationContext, str]:
        """Extract the code-optimization context via the analysis agent."""
        from codeflash.agents.analysis_agent import create_context_extraction_task

        task = create_context_extraction_task(function_to_optimize, project_root)
        result = self.execute_task("analysis", task)

        if result.is_failure():
            return Failure(result.failure())

        agent_result = result.unwrap()
        if not agent_result.success:
            return Failure(agent_result.error or "Analysis failed")

        return Success(agent_result.data["code_context"])

    def _run_generation_stage(
        self,
        code_context: CodeOptimizationContext,
        trace_id: str,
        function_to_optimize: FunctionToOptimize,
    ) -> Result[list[Any], str]:
        """Generate optimization candidates via the generation agent."""
        from codeflash.agents.generation_agent import create_generation_task

        language = function_to_optimize.language or "python"
        is_async = function_to_optimize.is_async

        task = create_generation_task(
            code_context=code_context,
            trace_id=trace_id,
            is_async=is_async,
            language=language,
        )
        result = self.execute_task("generation", task)

        if result.is_failure():
            return Failure(result.failure())

        agent_result = result.unwrap()
        if not agent_result.success:
            return Failure(agent_result.error or "Generation failed")

        return Success(agent_result.data["candidates"])

    def _run_verification_and_selection_stage(
        self,
        candidates: list[Any],
        code_context: CodeOptimizationContext,
        function_to_optimize: FunctionToOptimize,
        test_cfg: TestConfig,
        args: Namespace,
        function_to_tests: dict[str, set[Any]] | None,
        trace_id: str,
    ) -> Result[BestOptimization, str]:
        """Verify and rank candidates by delegating to FunctionOptimizer.

        NOTE(review): `candidates` and `trace_id` are accepted but not passed
        to FunctionOptimizer here -- optimize_function() appears to run its own
        generation/verification internally. Confirm whether the pre-generated
        candidates should be forwarded.
        """
        from codeflash.optimization.function_optimizer import FunctionOptimizer

        function_optimizer = FunctionOptimizer(
            function_to_optimize=function_to_optimize,
            test_cfg=test_cfg,
            function_to_optimize_source_code=code_context.read_writable_code.code_strings[0].code
            if code_context.read_writable_code.code_strings
            else "",
            function_to_tests=function_to_tests,
            function_to_optimize_ast=None,
            aiservice_client=self.agents["generation"].aiservice_client
            if "generation" in self.agents
            else None,
            args=args,
        )

        try:
            return function_optimizer.optimize_function()
        finally:
            # Always release the optimizer's worker pool and temp artifacts.
            function_optimizer.executor.shutdown(wait=True)
            function_optimizer.cleanup_generated_files()

    def _run_integration_stage(
        self,
        best_optimization: BestOptimization,
        function_to_optimize: FunctionToOptimize,
        args: Namespace,
    ) -> Result[dict[str, Any], str]:
        """Create a PR for the winning optimization, unless args.no_pr is set."""
        from codeflash.agents.integration_agent import create_pr_task

        if getattr(args, "no_pr", False):
            return Success({"pr_created": False, "reason": "PR creation disabled"})

        task = create_pr_task(
            best_optimization=best_optimization,
            function_to_optimize=function_to_optimize,
            args=args,
        )
        result = self.execute_task("integration", task)

        if result.is_failure():
            return Failure(result.failure())

        agent_result = result.unwrap()
        if not agent_result.success:
            return Failure(agent_result.error or "Integration failed")

        return Success(agent_result.data)

    def _handle_pipeline_failure(
        self,
        attempt: OptimizationAttempt,
        error: str,
    ) -> Result[BestOptimization, str]:
        """Persist the attempt as failed and propagate the error as a Failure."""
        failed_attempt = attempt.mark_failed(error)
        self.state_store.persist_optimization_attempt(failed_attempt)
        return Failure(error)

    def reset_all_agents(self) -> None:
        """Reset every agent to IDLE and drain the shared task queue."""
        for agent in self.agents.values():
            agent.reset()
        while not self.task_queue.empty():
            self.task_queue.get_nowait()
|
||||
187
codeflash/agents/discovery_agent.py
Normal file
187
codeflash/agents/discovery_agent.py
Normal file
|
|
@ -0,0 +1,187 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from codeflash.agents.base import AgentResult, AgentTask, BaseAgent, wrap_error, wrap_result
|
||||
from codeflash.cli_cmds.console import logger
|
||||
from codeflash.either import Result
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from codeflash.agents.coordinator import AgentCoordinator
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.state.store import StateStore
|
||||
from codeflash.verification.verification_utils import TestConfig
|
||||
|
||||
|
||||
class DiscoveryAgent(BaseAgent):
    """Finds functions to optimize and their tests, filtering by recorded history."""

    def __init__(
        self,
        coordinator: AgentCoordinator | None = None,
        state_store: StateStore | None = None,
    ) -> None:
        super().__init__(agent_id="discovery", coordinator=coordinator)
        # Optional persistence backend; when present, discovered functions are
        # filtered against previously recorded optimization attempts.
        self.state_store = state_store

    def process(self, task: AgentTask) -> Result[AgentResult[Any], str]:
        """Dispatch the task to the handler matching its task_type."""
        task_type = task.task_type
        if task_type == "discover_functions":
            return self._discover_functions(task)
        if task_type == "discover_tests":
            return self._discover_tests(task)
        if task_type == "filter_functions":
            return self._filter_functions(task)
        return wrap_error(f"Unknown task type: {task_type}", task.task_id, self.agent_id)

    def _discover_functions(self, task: AgentTask) -> Result[AgentResult[Any], str]:
        """Locate candidate functions via get_functions_to_optimize.

        Required payload keys: test_cfg, project_root, module_root.
        Optional: optimize_all, replay_test, file, function, ignore_paths,
        previous_checkpoint_functions.
        """
        payload = task.payload
        try:
            from codeflash.discovery.functions_to_optimize import get_functions_to_optimize

            test_cfg: TestConfig = payload["test_cfg"]
            optimize_all = payload.get("optimize_all")
            replay_test = payload.get("replay_test")
            file_path = payload.get("file")
            function_name = payload.get("function")
            ignore_paths = payload.get("ignore_paths", [])
            project_root = payload["project_root"]
            module_root = payload["module_root"]
            previous_checkpoint_functions = payload.get("previous_checkpoint_functions")

            file_to_funcs, num_functions, trace_file = get_functions_to_optimize(
                optimize_all=optimize_all,
                replay_test=replay_test,
                file=file_path,
                only_get_this_function=function_name,
                test_cfg=test_cfg,
                ignore_paths=ignore_paths,
                project_root=project_root,
                module_root=module_root,
                previous_checkpoint_functions=previous_checkpoint_functions,
            )

            # Drop functions the history says to skip, then recompute the count
            # so it matches the filtered mapping.
            if self.state_store:
                file_to_funcs = self._filter_by_history(file_to_funcs)
                num_functions = sum(len(funcs) for funcs in file_to_funcs.values())

            result_data = {
                "file_to_funcs": file_to_funcs,
                "num_functions": num_functions,
                "trace_file": trace_file,
            }

            logger.info(f"DiscoveryAgent: Found {num_functions} functions to optimize")
            return wrap_result(result_data, task.task_id, self.agent_id)

        except Exception as e:
            logger.exception(f"DiscoveryAgent failed: {e}")
            return wrap_error(str(e), task.task_id, self.agent_id)

    def _discover_tests(self, task: AgentTask) -> Result[AgentResult[Any], str]:
        """Discover unit tests covering the functions found earlier.

        Required payload keys: test_cfg, file_to_funcs.
        """
        payload = task.payload
        try:
            from codeflash.discovery.discover_unit_tests import discover_unit_tests

            test_cfg: TestConfig = payload["test_cfg"]
            file_to_funcs: dict[Path, list[FunctionToOptimize]] = payload["file_to_funcs"]

            function_to_tests, num_tests, num_replay_tests = discover_unit_tests(
                test_cfg, file_to_funcs_to_optimize=file_to_funcs
            )

            result_data = {
                "function_to_tests": function_to_tests,
                "num_tests": num_tests,
                "num_replay_tests": num_replay_tests,
            }

            logger.info(f"DiscoveryAgent: Discovered {num_tests} tests and {num_replay_tests} replay tests")
            return wrap_result(result_data, task.task_id, self.agent_id)

        except Exception as e:
            logger.exception(f"DiscoveryAgent test discovery failed: {e}")
            return wrap_error(str(e), task.task_id, self.agent_id)

    def _filter_functions(self, task: AgentTask) -> Result[AgentResult[Any], str]:
        """Apply history-based filtering to an existing file->functions mapping."""
        payload = task.payload
        try:
            file_to_funcs: dict[Path, list[FunctionToOptimize]] = payload["file_to_funcs"]
            filtered = self._filter_by_history(file_to_funcs)
            num_functions = sum(len(funcs) for funcs in filtered.values())

            result_data = {
                "file_to_funcs": filtered,
                "num_functions": num_functions,
            }

            return wrap_result(result_data, task.task_id, self.agent_id)

        except Exception as e:
            return wrap_error(str(e), task.task_id, self.agent_id)

    def _filter_by_history(
        self, file_to_funcs: dict[Path, list[FunctionToOptimize]]
    ) -> dict[Path, list[FunctionToOptimize]]:
        """Remove functions the optimization history marks as skippable.

        Returns the mapping unchanged when no state store is configured. Files
        whose functions were all skipped are dropped from the result entirely.
        """
        if not self.state_store:
            return file_to_funcs

        from codeflash.state.history import OptimizationHistory

        history = OptimizationHistory(self.state_store)
        filtered: dict[Path, list[FunctionToOptimize]] = {}

        for file_path, functions in file_to_funcs.items():
            kept_functions = []
            for func in functions:
                should_skip, reason = history.should_skip_function(func.qualified_name)
                if should_skip:
                    logger.debug(f"Skipping {func.qualified_name}: {reason}")
                    continue
                kept_functions.append(func)

            if kept_functions:
                filtered[file_path] = kept_functions

        return filtered
|
||||
|
||||
|
||||
def create_discovery_task(
    test_cfg: Any,
    project_root: Path,
    module_root: Path,
    optimize_all: Path | str | None = None,
    replay_test: list[Path] | None = None,
    file: Path | None = None,
    function: str | None = None,
    ignore_paths: list[Path] | None = None,
    previous_checkpoint_functions: dict[str, dict[str, str]] | None = None,
) -> AgentTask:
    """Build the top-priority task asking the discovery agent to find functions."""
    payload = {
        "test_cfg": test_cfg,
        "project_root": project_root,
        "module_root": module_root,
        "optimize_all": optimize_all,
        "replay_test": replay_test,
        "file": file,
        "function": function,
        # Normalize a missing ignore list to an empty one for the handler.
        "ignore_paths": [] if ignore_paths is None else ignore_paths,
        "previous_checkpoint_functions": previous_checkpoint_functions,
    }
    return AgentTask.create(task_type="discover_functions", payload=payload, priority=10)
|
||||
|
||||
|
||||
def create_test_discovery_task(
    test_cfg: Any,
    file_to_funcs: dict[Path, list[Any]],
) -> AgentTask:
    """Build the task asking the discovery agent to find tests for *file_to_funcs*."""
    payload = {"test_cfg": test_cfg, "file_to_funcs": file_to_funcs}
    return AgentTask.create(task_type="discover_tests", payload=payload, priority=9)
|
||||
197
codeflash/agents/generation_agent.py
Normal file
197
codeflash/agents/generation_agent.py
Normal file
|
|
@ -0,0 +1,197 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from codeflash.agents.base import AgentResult, AgentTask, BaseAgent, wrap_error, wrap_result
|
||||
from codeflash.cli_cmds.console import logger
|
||||
from codeflash.either import Result
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from codeflash.agents.coordinator import AgentCoordinator
|
||||
from codeflash.api.aiservice import AiServiceClient
|
||||
from codeflash.models.models import CodeOptimizationContext
|
||||
|
||||
|
||||
class GenerationAgent(BaseAgent):
    """Agent that produces optimization candidates via the AI service.

    Handles four task types: ``generate_candidates``, ``refine_candidate``,
    ``repair_candidate``, and ``adaptive_optimize``. Every handler converts
    any exception into a wrapped error result so the coordinator can keep
    processing other tasks.
    """

    def __init__(
        self,
        coordinator: AgentCoordinator | None = None,
        aiservice_client: AiServiceClient | None = None,
    ) -> None:
        """Register this agent under the fixed id "generation".

        ``aiservice_client`` may be injected (useful for testing); when None,
        a client is created lazily on first use by the property below.
        """
        super().__init__(agent_id="generation", coordinator=coordinator)
        self._aiservice_client = aiservice_client

    @property
    def aiservice_client(self) -> AiServiceClient:
        """Return the AI service client, constructing and caching it on first access.

        The import is deferred so that an injected client avoids the import cost.
        """
        if self._aiservice_client is None:
            from codeflash.api.aiservice import AiServiceClient

            self._aiservice_client = AiServiceClient()
        return self._aiservice_client

    def process(self, task: AgentTask) -> Result[AgentResult[Any], str]:
        """Dispatch ``task`` to the handler matching its ``task_type``.

        An unknown task type yields a wrapped error instead of raising.
        """
        task_type = task.task_type
        if task_type == "generate_candidates":
            return self._generate_candidates(task)
        if task_type == "refine_candidate":
            return self._refine_candidate(task)
        if task_type == "repair_candidate":
            return self._repair_candidate(task)
        if task_type == "adaptive_optimize":
            return self._adaptive_optimize(task)
        return wrap_error(f"Unknown task type: {task_type}", task.task_id, self.agent_id)

    def _generate_candidates(self, task: AgentTask) -> Result[AgentResult[Any], str]:
        """Ask the AI service for optimization candidates.

        Expects payload keys: ``code_context``, ``trace_id``; optional
        ``is_async``, ``n_candidates`` (default 5), ``is_numerical``,
        ``language`` (default "python"). Returns the candidate list plus a count.
        """
        payload = task.payload
        try:
            code_context: CodeOptimizationContext = payload["code_context"]
            trace_id: str = payload["trace_id"]
            is_async: bool = payload.get("is_async", False)
            n_candidates: int = payload.get("n_candidates", 5)
            is_numerical: bool | None = payload.get("is_numerical")
            language: str = payload.get("language", "python")

            # The writable-code markdown is what the AI service optimizes;
            # read-only context is sent along as supporting dependency code.
            source_code = code_context.read_writable_code.markdown
            dependency_code = code_context.read_only_context_code or ""

            candidates = self.aiservice_client.optimize_code(
                source_code=source_code,
                dependency_code=dependency_code,
                trace_id=trace_id,
                language=language,
                is_async=is_async,
                n_candidates=n_candidates,
                is_numerical_code=is_numerical,
            )

            result_data = {
                "candidates": candidates,
                "candidate_count": len(candidates),
            }

            logger.info(f"GenerationAgent: Generated {len(candidates)} optimization candidates")
            return wrap_result(result_data, task.task_id, self.agent_id)

        # Broad catch by design: any failure becomes a wrapped error so the
        # coordinator can continue with other work.
        except Exception as e:
            logger.exception(f"GenerationAgent candidate generation failed: {e}")
            return wrap_error(str(e), task.task_id, self.agent_id)

    def _refine_candidate(self, task: AgentTask) -> Result[AgentResult[Any], str]:
        """Ask the AI service to refine a single candidate.

        Expects payload keys ``refiner_request`` and ``trace_id``; echoes the
        refined candidates and the originating optimization id.
        """
        payload = task.payload
        try:
            # Imported here to keep module import light; only the type hint needs it.
            from codeflash.models.models import AIServiceRefinerRequest

            refiner_request: AIServiceRefinerRequest = payload["refiner_request"]
            trace_id: str = payload["trace_id"]

            refined_candidates = self.aiservice_client.refine_code(refiner_request, trace_id)

            result_data = {
                "refined_candidates": refined_candidates,
                "original_optimization_id": refiner_request.optimization_id,
            }

            logger.debug(f"GenerationAgent: Refined candidate {refiner_request.optimization_id}")
            return wrap_result(result_data, task.task_id, self.agent_id)

        except Exception as e:
            logger.exception(f"GenerationAgent refinement failed: {e}")
            return wrap_error(str(e), task.task_id, self.agent_id)

    def _repair_candidate(self, task: AgentTask) -> Result[AgentResult[Any], str]:
        """Ask the AI service to repair a failing candidate.

        Expects payload keys ``repair_request`` and ``trace_id``; echoes the
        repaired candidate and the originating optimization id.
        """
        payload = task.payload
        try:
            from codeflash.models.models import AIServiceCodeRepairRequest

            repair_request: AIServiceCodeRepairRequest = payload["repair_request"]
            trace_id: str = payload["trace_id"]

            repaired_candidate = self.aiservice_client.repair_code(repair_request, trace_id)

            result_data = {
                "repaired_candidate": repaired_candidate,
                "original_optimization_id": repair_request.optimization_id,
            }

            logger.debug(f"GenerationAgent: Repaired candidate {repair_request.optimization_id}")
            return wrap_result(result_data, task.task_id, self.agent_id)

        except Exception as e:
            logger.exception(f"GenerationAgent repair failed: {e}")
            return wrap_error(str(e), task.task_id, self.agent_id)

    def _adaptive_optimize(self, task: AgentTask) -> Result[AgentResult[Any], str]:
        """Run the AI service's adaptive optimization flow.

        Expects payload keys ``adaptive_request`` and ``trace_id``.
        """
        payload = task.payload
        try:
            from codeflash.models.models import AIServiceAdaptiveOptimizeRequest

            adaptive_request: AIServiceAdaptiveOptimizeRequest = payload["adaptive_request"]
            trace_id: str = payload["trace_id"]

            adaptive_candidates = self.aiservice_client.adaptive_optimize(adaptive_request, trace_id)

            result_data = {
                "adaptive_candidates": adaptive_candidates,
            }

            logger.debug("GenerationAgent: Generated adaptive optimization candidates")
            return wrap_result(result_data, task.task_id, self.agent_id)

        except Exception as e:
            logger.exception(f"GenerationAgent adaptive optimization failed: {e}")
            return wrap_error(str(e), task.task_id, self.agent_id)
|
||||
|
||||
|
||||
def create_generation_task(
    code_context: Any,
    trace_id: str,
    is_async: bool = False,
    n_candidates: int = 5,
    is_numerical: bool | None = None,
    language: str = "python",
) -> AgentTask:
    """Build the task that asks the GenerationAgent for optimization candidates."""
    payload = {
        "code_context": code_context,
        "trace_id": trace_id,
        "is_async": is_async,
        "n_candidates": n_candidates,
        "is_numerical": is_numerical,
        "language": language,
    }
    return AgentTask.create(task_type="generate_candidates", payload=payload, priority=7)
|
||||
|
||||
|
||||
def create_refinement_task(refiner_request: Any, trace_id: str) -> AgentTask:
    """Build the task that asks the GenerationAgent to refine one candidate."""
    return AgentTask.create(
        task_type="refine_candidate",
        payload=dict(refiner_request=refiner_request, trace_id=trace_id),
        priority=5,
    )
|
||||
|
||||
|
||||
def create_repair_task(repair_request: Any, trace_id: str) -> AgentTask:
    """Build the task that asks the GenerationAgent to repair a failing candidate."""
    return AgentTask.create(
        task_type="repair_candidate",
        payload=dict(repair_request=repair_request, trace_id=trace_id),
        priority=5,
    )
|
||||
|
||||
|
||||
def create_adaptive_optimization_task(adaptive_request: Any, trace_id: str) -> AgentTask:
    """Build the task that triggers the adaptive optimization flow."""
    return AgentTask.create(
        task_type="adaptive_optimize",
        payload=dict(adaptive_request=adaptive_request, trace_id=trace_id),
        priority=4,
    )
|
||||
158
codeflash/agents/integration_agent.py
Normal file
158
codeflash/agents/integration_agent.py
Normal file
|
|
@ -0,0 +1,158 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from codeflash.agents.base import AgentResult, AgentTask, BaseAgent, wrap_error, wrap_result
|
||||
from codeflash.cli_cmds.console import logger
|
||||
from codeflash.either import Result
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from argparse import Namespace
|
||||
|
||||
from codeflash.agents.coordinator import AgentCoordinator
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.models.models import BestOptimization
|
||||
|
||||
|
||||
class IntegrationAgent(BaseAgent):
    """Agent that lands accepted optimizations: PRs, file writes, explanations.

    Handles three task types: ``create_pr``, ``apply_optimization``, and
    ``generate_explanation``. Every handler converts exceptions into wrapped
    errors so the coordinator can keep processing other tasks.
    """

    def __init__(self, coordinator: AgentCoordinator | None = None) -> None:
        """Register this agent under the fixed id "integration"."""
        super().__init__(agent_id="integration", coordinator=coordinator)

    def process(self, task: AgentTask) -> Result[AgentResult[Any], str]:
        """Dispatch ``task`` to the handler matching its ``task_type``.

        An unknown task type yields a wrapped error instead of raising.
        """
        task_type = task.task_type
        if task_type == "create_pr":
            return self._create_pr(task)
        if task_type == "apply_optimization":
            return self._apply_optimization(task)
        if task_type == "generate_explanation":
            return self._generate_explanation(task)
        return wrap_error(f"Unknown task type: {task_type}", task.task_id, self.agent_id)

    def _create_pr(self, task: AgentTask) -> Result[AgentResult[Any], str]:
        """Offer to open a PR for the best optimization via ``check_create_pr``.

        Expects payload keys ``best_optimization``, ``function_to_optimize``,
        and ``args``. A falsy ``pr_result`` means PR creation was skipped —
        presumably by user/config choice inside check_create_pr; verify there.
        """
        payload = task.payload
        try:
            from codeflash.result.create_pr import check_create_pr

            best_optimization: BestOptimization = payload["best_optimization"]
            function_to_optimize: FunctionToOptimize = payload["function_to_optimize"]
            args: Namespace = payload["args"]

            pr_result = check_create_pr(
                best_optimization=best_optimization,
                function_to_optimize=function_to_optimize,
                args=args,
            )

            result_data = {
                "pr_result": pr_result,
                "function_name": function_to_optimize.qualified_name,
            }

            if pr_result:
                logger.info(f"IntegrationAgent: Created PR for {function_to_optimize.qualified_name}")
            else:
                logger.debug(f"IntegrationAgent: PR creation skipped for {function_to_optimize.qualified_name}")

            return wrap_result(result_data, task.task_id, self.agent_id)

        except Exception as e:
            logger.exception(f"IntegrationAgent PR creation failed: {e}")
            return wrap_error(str(e), task.task_id, self.agent_id)

    def _apply_optimization(self, task: AgentTask) -> Result[AgentResult[Any], str]:
        """Write the winning optimization's code back to the source files.

        Overwrites each file in the optimization's read-writable code context
        in place (no backup is taken here) and reports the files modified.
        """
        payload = task.payload
        try:
            best_optimization: BestOptimization = payload["best_optimization"]
            function_to_optimize: FunctionToOptimize = payload["function_to_optimize"]

            read_writable_code = best_optimization.code_context.read_writable_code
            files_modified = []

            for code_string in read_writable_code.code_strings:
                file_path = code_string.file_path
                file_path.write_text(code_string.code, encoding="utf-8")
                files_modified.append(file_path)

            result_data = {
                "files_modified": files_modified,
                "function_name": function_to_optimize.qualified_name,
            }

            logger.info(f"IntegrationAgent: Applied optimization to {len(files_modified)} files")
            return wrap_result(result_data, task.task_id, self.agent_id)

        except Exception as e:
            logger.exception(f"IntegrationAgent optimization application failed: {e}")
            return wrap_error(str(e), task.task_id, self.agent_id)

    def _generate_explanation(self, task: AgentTask) -> Result[AgentResult[Any], str]:
        """Build a human-readable Explanation for the winning optimization.

        NOTE(review): the keyword arguments passed to ``Explanation`` here are
        assumed to match its constructor — confirm against
        codeflash/result/explanation.py. ``summary`` is probed defensively
        with hasattr since not every Explanation may define it.
        """
        payload = task.payload
        try:
            from codeflash.result.explanation import Explanation

            best_optimization: BestOptimization = payload["best_optimization"]
            function_to_optimize: FunctionToOptimize = payload["function_to_optimize"]
            original_runtime: int = payload["original_runtime"]

            explanation = Explanation(
                best_optimization=best_optimization,
                function_to_optimize=function_to_optimize,
                original_runtime=original_runtime,
            )

            result_data = {
                "explanation": explanation,
                "summary": explanation.summary if hasattr(explanation, "summary") else None,
            }

            return wrap_result(result_data, task.task_id, self.agent_id)

        except Exception as e:
            logger.exception(f"IntegrationAgent explanation generation failed: {e}")
            return wrap_error(str(e), task.task_id, self.agent_id)
|
||||
|
||||
|
||||
def create_pr_task(
    best_optimization: Any,
    function_to_optimize: Any,
    args: Any,
) -> AgentTask:
    """Build the task that asks the IntegrationAgent to open a pull request."""
    payload = {
        "best_optimization": best_optimization,
        "function_to_optimize": function_to_optimize,
        "args": args,
    }
    return AgentTask.create(task_type="create_pr", payload=payload, priority=3)
|
||||
|
||||
|
||||
def create_apply_optimization_task(
    best_optimization: Any,
    function_to_optimize: Any,
) -> AgentTask:
    """Build the task that writes the winning optimization back to disk."""
    payload = dict(
        best_optimization=best_optimization,
        function_to_optimize=function_to_optimize,
    )
    return AgentTask.create(task_type="apply_optimization", payload=payload, priority=3)
|
||||
|
||||
|
||||
def create_explanation_task(
    best_optimization: Any,
    function_to_optimize: Any,
    original_runtime: int,
) -> AgentTask:
    """Build the task that generates a human-readable explanation of the win."""
    payload = {
        "best_optimization": best_optimization,
        "function_to_optimize": function_to_optimize,
        "original_runtime": original_runtime,
    }
    return AgentTask.create(task_type="generate_explanation", payload=payload, priority=2)
|
||||
207
codeflash/agents/selection_agent.py
Normal file
207
codeflash/agents/selection_agent.py
Normal file
|
|
@ -0,0 +1,207 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from codeflash.agents.base import AgentResult, AgentTask, BaseAgent, wrap_error, wrap_result
|
||||
from codeflash.cli_cmds.console import logger
|
||||
from codeflash.either import Result
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from codeflash.agents.coordinator import AgentCoordinator
|
||||
from codeflash.models.models import ConcurrencyMetrics, OptimizedCandidateResult
|
||||
|
||||
|
||||
class SelectionAgent(BaseAgent):
    """Agent that judges, ranks, and selects optimization candidates.

    Handles three task types: ``evaluate_candidate`` (run the critics on one
    candidate), ``select_best`` (pick the fastest acceptable candidate), and
    ``rank_candidates`` (order all candidates by performance gain). All
    handlers convert exceptions into wrapped errors.
    """

    def __init__(self, coordinator: AgentCoordinator | None = None) -> None:
        """Register this agent under the fixed id "selection"."""
        super().__init__(agent_id="selection", coordinator=coordinator)

    def process(self, task: AgentTask) -> Result[AgentResult[Any], str]:
        """Dispatch ``task`` to the handler matching its ``task_type``.

        An unknown task type yields a wrapped error instead of raising.
        """
        task_type = task.task_type
        if task_type == "evaluate_candidate":
            return self._evaluate_candidate(task)
        if task_type == "select_best":
            return self._select_best(task)
        if task_type == "rank_candidates":
            return self._rank_candidates(task)
        return wrap_error(f"Unknown task type: {task_type}", task.task_id, self.agent_id)

    def _evaluate_candidate(self, task: AgentTask) -> Result[AgentResult[Any], str]:
        """Run the three acceptance critics on one candidate result.

        A candidate is acceptable only if it passes the test-quantity, speedup,
        and coverage critics. Coverage is treated as passing when no original
        coverage data was supplied.
        """
        payload = task.payload
        try:
            from codeflash.result.critic import coverage_critic, quantity_of_tests_critic, speedup_critic

            candidate_result: OptimizedCandidateResult = payload["candidate_result"]
            original_runtime: int = payload["original_runtime"]
            best_runtime_until_now: int | None = payload.get("best_runtime_until_now")
            original_async_throughput: int | None = payload.get("original_async_throughput")
            best_throughput_until_now: int | None = payload.get("best_throughput_until_now")
            original_concurrency_metrics: ConcurrencyMetrics | None = payload.get("original_concurrency_metrics")
            best_concurrency_ratio_until_now: float | None = payload.get("best_concurrency_ratio_until_now")
            original_coverage = payload.get("original_coverage")

            passes_quantity_check = quantity_of_tests_critic(candidate_result)
            passes_speedup_check = speedup_critic(
                candidate_result,
                original_runtime,
                best_runtime_until_now,
                original_async_throughput=original_async_throughput,
                best_throughput_until_now=best_throughput_until_now,
                original_concurrency_metrics=original_concurrency_metrics,
                best_concurrency_ratio_until_now=best_concurrency_ratio_until_now,
            )
            # No coverage data means the coverage critic cannot reject the candidate.
            passes_coverage_check = coverage_critic(original_coverage) if original_coverage else True

            is_acceptable = passes_quantity_check and passes_speedup_check and passes_coverage_check

            result_data = {
                "candidate_result": candidate_result,
                "is_acceptable": is_acceptable,
                "passes_quantity_check": passes_quantity_check,
                "passes_speedup_check": passes_speedup_check,
                "passes_coverage_check": passes_coverage_check,
            }

            logger.debug(
                f"SelectionAgent: Evaluated candidate, acceptable={is_acceptable} "
                f"(qty={passes_quantity_check}, speed={passes_speedup_check}, cov={passes_coverage_check})"
            )
            return wrap_result(result_data, task.task_id, self.agent_id)

        except Exception as e:
            logger.exception(f"SelectionAgent evaluation failed: {e}")
            return wrap_error(str(e), task.task_id, self.agent_id)

    def _select_best(self, task: AgentTask) -> Result[AgentResult[Any], str]:
        """Pick the lowest-runtime candidate among those flagged acceptable.

        Returns a ``best_candidate`` of None (with a reason) when no candidates
        were provided or none were acceptable. Candidates missing a "runtime"
        key sort last via the float("inf") default.
        """
        payload = task.payload
        try:
            from codeflash.result.critic import get_acceptance_reason, performance_gain

            candidates: list[dict[str, Any]] = payload["candidates"]
            original_runtime: int = payload["original_runtime"]
            original_async_throughput: int | None = payload.get("original_async_throughput")
            original_concurrency_metrics: ConcurrencyMetrics | None = payload.get("original_concurrency_metrics")

            if not candidates:
                return wrap_result({"best_candidate": None, "reason": "No candidates provided"}, task.task_id, self.agent_id)

            acceptable_candidates = [c for c in candidates if c.get("is_acceptable", False)]

            if not acceptable_candidates:
                return wrap_result({"best_candidate": None, "reason": "No acceptable candidates"}, task.task_id, self.agent_id)

            best = min(acceptable_candidates, key=lambda c: c.get("runtime", float("inf")))
            best_runtime = best.get("runtime", original_runtime)

            perf_gain = performance_gain(original_runtime_ns=original_runtime, optimized_runtime_ns=best_runtime)
            # Guard against a zero runtime to avoid ZeroDivisionError.
            speedup = (original_runtime / best_runtime) if best_runtime > 0 else 0.0

            acceptance_reason = get_acceptance_reason(
                original_runtime_ns=original_runtime,
                optimized_runtime_ns=best_runtime,
                original_async_throughput=original_async_throughput,
                optimized_async_throughput=best.get("async_throughput"),
                original_concurrency_metrics=original_concurrency_metrics,
                optimized_concurrency_metrics=best.get("concurrency_metrics"),
            )

            result_data = {
                "best_candidate": best,
                "performance_gain": perf_gain,
                "speedup": speedup,
                "acceptance_reason": acceptance_reason.value,
            }

            logger.info(f"SelectionAgent: Selected best candidate with {speedup:.2f}x speedup")
            return wrap_result(result_data, task.task_id, self.agent_id)

        except Exception as e:
            logger.exception(f"SelectionAgent selection failed: {e}")
            return wrap_error(str(e), task.task_id, self.agent_id)

    def _rank_candidates(self, task: AgentTask) -> Result[AgentResult[Any], str]:
        """Annotate every candidate with performance gain/speedup and sort descending.

        Unlike ``_select_best``, this ranks all candidates regardless of
        acceptability; missing runtimes default to infinity (gain/speedup 0.0).
        """
        payload = task.payload
        try:
            from codeflash.result.critic import performance_gain

            candidates: list[dict[str, Any]] = payload["candidates"]
            original_runtime: int = payload["original_runtime"]

            ranked = []
            for candidate in candidates:
                runtime = candidate.get("runtime", float("inf"))
                perf = performance_gain(original_runtime_ns=original_runtime, optimized_runtime_ns=runtime)
                ranked.append({
                    **candidate,
                    "performance_gain": perf,
                    "speedup": (original_runtime / runtime) if runtime > 0 else 0.0,
                })

            ranked.sort(key=lambda c: c.get("performance_gain", 0), reverse=True)

            result_data = {
                "ranked_candidates": ranked,
            }

            return wrap_result(result_data, task.task_id, self.agent_id)

        except Exception as e:
            logger.exception(f"SelectionAgent ranking failed: {e}")
            return wrap_error(str(e), task.task_id, self.agent_id)
|
||||
|
||||
|
||||
def create_evaluation_task(
    candidate_result: Any,
    original_runtime: int,
    best_runtime_until_now: int | None = None,
    original_async_throughput: int | None = None,
    best_throughput_until_now: int | None = None,
    original_concurrency_metrics: Any | None = None,
    best_concurrency_ratio_until_now: float | None = None,
    original_coverage: Any | None = None,
) -> AgentTask:
    """Build the task that runs the acceptance critics on one candidate."""
    payload = {
        "candidate_result": candidate_result,
        "original_runtime": original_runtime,
        "best_runtime_until_now": best_runtime_until_now,
        "original_async_throughput": original_async_throughput,
        "best_throughput_until_now": best_throughput_until_now,
        "original_concurrency_metrics": original_concurrency_metrics,
        "best_concurrency_ratio_until_now": best_concurrency_ratio_until_now,
        "original_coverage": original_coverage,
    }
    return AgentTask.create(task_type="evaluate_candidate", payload=payload, priority=5)
|
||||
|
||||
|
||||
def create_selection_task(
    candidates: list[dict[str, Any]],
    original_runtime: int,
    original_async_throughput: int | None = None,
    original_concurrency_metrics: Any | None = None,
) -> AgentTask:
    """Build the task that picks the best acceptable candidate."""
    payload = {
        "candidates": candidates,
        "original_runtime": original_runtime,
        "original_async_throughput": original_async_throughput,
        "original_concurrency_metrics": original_concurrency_metrics,
    }
    return AgentTask.create(task_type="select_best", payload=payload, priority=4)
|
||||
|
||||
|
||||
def create_ranking_task(
    candidates: list[dict[str, Any]],
    original_runtime: int,
) -> AgentTask:
    """Build the task that orders all candidates by performance gain."""
    payload = dict(candidates=candidates, original_runtime=original_runtime)
    return AgentTask.create(task_type="rank_candidates", payload=payload, priority=4)
|
||||
282
codeflash/agents/verification_agent.py
Normal file
282
codeflash/agents/verification_agent.py
Normal file
|
|
@ -0,0 +1,282 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from codeflash.agents.base import AgentResult, AgentTask, BaseAgent, wrap_error, wrap_result
|
||||
from codeflash.cli_cmds.console import logger
|
||||
from codeflash.either import Result
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from codeflash.agents.coordinator import AgentCoordinator
|
||||
from codeflash.models.models import OptimizedCandidate, TestFiles
|
||||
from codeflash.verification.verification_utils import TestConfig
|
||||
|
||||
|
||||
class VerificationAgent(BaseAgent):
    """Agent that generates tests, establishes baselines, and verifies candidates.

    Handles five task types: ``generate_tests``, ``run_behavioral_tests``,
    ``run_benchmark_tests``, ``establish_baseline``, and ``verify_candidate``.
    All handlers convert exceptions into wrapped errors so the coordinator can
    keep processing other tasks.
    """

    def __init__(self, coordinator: AgentCoordinator | None = None) -> None:
        """Register this agent under the fixed id "verification"."""
        super().__init__(agent_id="verification", coordinator=coordinator)

    def process(self, task: AgentTask) -> Result[AgentResult[Any], str]:
        """Dispatch ``task`` to the handler matching its ``task_type``.

        An unknown task type yields a wrapped error instead of raising.
        """
        task_type = task.task_type
        if task_type == "generate_tests":
            return self._generate_tests(task)
        if task_type == "run_behavioral_tests":
            return self._run_behavioral_tests(task)
        if task_type == "run_benchmark_tests":
            return self._run_benchmark_tests(task)
        if task_type == "establish_baseline":
            return self._establish_baseline(task)
        if task_type == "verify_candidate":
            return self._verify_candidate(task)
        return wrap_error(f"Unknown task type: {task_type}", task.task_id, self.agent_id)

    def _generate_tests(self, task: AgentTask) -> Result[AgentResult[Any], str]:
        """Generate regression tests for the target function via the verifier.

        Expects payload keys: ``code_context``, ``function_to_optimize``,
        ``test_cfg``, ``project_root``, ``trace_id``.
        """
        payload = task.payload
        try:
            from codeflash.verification.verifier import generate_tests

            code_context = payload["code_context"]
            function_to_optimize = payload["function_to_optimize"]
            test_cfg: TestConfig = payload["test_cfg"]
            project_root: Path = payload["project_root"]
            trace_id: str = payload["trace_id"]

            generated_tests = generate_tests(
                code_context=code_context,
                function_to_optimize=function_to_optimize,
                test_cfg=test_cfg,
                project_root_path=project_root,
                trace_id=trace_id,
            )

            result_data = {
                "generated_tests": generated_tests,
            }

            logger.debug(f"VerificationAgent: Generated tests for {function_to_optimize.qualified_name}")
            return wrap_result(result_data, task.task_id, self.agent_id)

        except Exception as e:
            logger.exception(f"VerificationAgent test generation failed: {e}")
            return wrap_error(str(e), task.task_id, self.agent_id)

    def _run_behavioral_tests(self, task: AgentTask) -> Result[AgentResult[Any], str]:
        """Run behavioral (correctness) tests; default timeout 60s."""
        payload = task.payload
        try:
            from codeflash.verification.test_runner import run_behavioral_tests

            test_files: TestFiles = payload["test_files"]
            test_cfg: TestConfig = payload["test_cfg"]
            project_root: Path = payload["project_root"]
            test_timeout: int = payload.get("test_timeout", 60)

            test_output = run_behavioral_tests(
                test_files=test_files,
                test_cfg=test_cfg,
                project_root_path=project_root,
                test_timeout=test_timeout,
            )

            result_data = {
                "test_output": test_output,
            }

            return wrap_result(result_data, task.task_id, self.agent_id)

        except Exception as e:
            logger.exception(f"VerificationAgent behavioral tests failed: {e}")
            return wrap_error(str(e), task.task_id, self.agent_id)

    def _run_benchmark_tests(self, task: AgentTask) -> Result[AgentResult[Any], str]:
        """Run benchmarking (timing) tests; default timeout 300s (longer than behavioral)."""
        payload = task.payload
        try:
            from codeflash.verification.test_runner import run_benchmarking_tests

            test_files: TestFiles = payload["test_files"]
            test_cfg: TestConfig = payload["test_cfg"]
            project_root: Path = payload["project_root"]
            test_timeout: int = payload.get("test_timeout", 300)

            benchmark_output = run_benchmarking_tests(
                test_files=test_files,
                test_cfg=test_cfg,
                project_root_path=project_root,
                test_timeout=test_timeout,
            )

            result_data = {
                "benchmark_output": benchmark_output,
            }

            return wrap_result(result_data, task.task_id, self.agent_id)

        except Exception as e:
            logger.exception(f"VerificationAgent benchmark tests failed: {e}")
            return wrap_error(str(e), task.task_id, self.agent_id)

    def _establish_baseline(self, task: AgentTask) -> Result[AgentResult[Any], str]:
        """Record the original code's behavior and timing as the comparison baseline.

        Runs behavioral tests then benchmarking tests (each parsed via
        ``parse_test_results``) against the unmodified code, returning both
        parsed results and raw outputs.
        """
        payload = task.payload
        try:
            from codeflash.verification.parse_test_output import parse_test_results
            from codeflash.verification.test_runner import run_behavioral_tests, run_benchmarking_tests

            test_files: TestFiles = payload["test_files"]
            test_cfg: TestConfig = payload["test_cfg"]
            project_root: Path = payload["project_root"]
            function_qualified_name: str = payload["function_qualified_name"]

            behavior_output = run_behavioral_tests(
                test_files=test_files,
                test_cfg=test_cfg,
                project_root_path=project_root,
            )

            behavior_results = parse_test_results(
                test_output=behavior_output,
                function_qualified_name=function_qualified_name,
            )

            benchmark_output = run_benchmarking_tests(
                test_files=test_files,
                test_cfg=test_cfg,
                project_root_path=project_root,
            )

            benchmark_results = parse_test_results(
                test_output=benchmark_output,
                function_qualified_name=function_qualified_name,
            )

            result_data = {
                "behavior_results": behavior_results,
                "benchmark_results": benchmark_results,
                "behavior_output": behavior_output,
                "benchmark_output": benchmark_output,
            }

            logger.debug(f"VerificationAgent: Established baseline for {function_qualified_name}")
            return wrap_result(result_data, task.task_id, self.agent_id)

        except Exception as e:
            logger.exception(f"VerificationAgent baseline establishment failed: {e}")
            return wrap_error(str(e), task.task_id, self.agent_id)

    def _verify_candidate(self, task: AgentTask) -> Result[AgentResult[Any], str]:
        """Verify a candidate: behavioral equivalence first, benchmarks only if equivalent.

        Assumes the candidate's code has already been written into the project
        before this task runs — the test runner only gets paths, not the
        candidate source (TODO confirm against the coordinator's sequencing).
        Benchmarks are skipped for non-equivalent candidates, so
        ``benchmark_results``/``benchmark_output`` may be None.
        """
        payload = task.payload
        try:
            from codeflash.verification.equivalence import compare_test_results
            from codeflash.verification.parse_test_output import parse_test_results
            from codeflash.verification.test_runner import run_behavioral_tests, run_benchmarking_tests

            candidate: OptimizedCandidate = payload["candidate"]
            test_files: TestFiles = payload["test_files"]
            test_cfg: TestConfig = payload["test_cfg"]
            project_root: Path = payload["project_root"]
            function_qualified_name: str = payload["function_qualified_name"]
            original_behavior_results = payload["original_behavior_results"]

            behavior_output = run_behavioral_tests(
                test_files=test_files,
                test_cfg=test_cfg,
                project_root_path=project_root,
            )

            behavior_results = parse_test_results(
                test_output=behavior_output,
                function_qualified_name=function_qualified_name,
            )

            is_equivalent, diffs = compare_test_results(
                original_results=original_behavior_results,
                optimized_results=behavior_results,
            )

            benchmark_results = None
            benchmark_output = None
            if is_equivalent:
                # Only time candidates that behave identically; timing a broken
                # candidate would waste the benchmark budget.
                benchmark_output = run_benchmarking_tests(
                    test_files=test_files,
                    test_cfg=test_cfg,
                    project_root_path=project_root,
                )

                benchmark_results = parse_test_results(
                    test_output=benchmark_output,
                    function_qualified_name=function_qualified_name,
                )

            result_data = {
                "candidate": candidate,
                "is_equivalent": is_equivalent,
                "diffs": diffs,
                "behavior_results": behavior_results,
                "benchmark_results": benchmark_results,
            }

            logger.debug(f"VerificationAgent: Verified candidate {candidate.optimization_id}, equivalent={is_equivalent}")
            return wrap_result(result_data, task.task_id, self.agent_id)

        except Exception as e:
            logger.exception(f"VerificationAgent candidate verification failed: {e}")
            return wrap_error(str(e), task.task_id, self.agent_id)
|
||||
|
||||
|
||||
def create_test_generation_task(
    code_context: Any,
    function_to_optimize: Any,
    test_cfg: Any,
    project_root: Path,
    trace_id: str,
) -> AgentTask:
    """Build the task that asks the VerificationAgent to generate tests."""
    payload = {
        "code_context": code_context,
        "function_to_optimize": function_to_optimize,
        "test_cfg": test_cfg,
        "project_root": project_root,
        "trace_id": trace_id,
    }
    return AgentTask.create(task_type="generate_tests", payload=payload, priority=8)
|
||||
|
||||
|
||||
def create_baseline_task(
    test_files: Any,
    test_cfg: Any,
    project_root: Path,
    function_qualified_name: str,
) -> AgentTask:
    """Build the task that records the original code's baseline behavior/timing."""
    payload = dict(
        test_files=test_files,
        test_cfg=test_cfg,
        project_root=project_root,
        function_qualified_name=function_qualified_name,
    )
    return AgentTask.create(task_type="establish_baseline", payload=payload, priority=7)
|
||||
|
||||
|
||||
def create_verification_task(
    candidate: Any,
    test_files: Any,
    test_cfg: Any,
    project_root: Path,
    function_qualified_name: str,
    original_behavior_results: Any,
) -> AgentTask:
    """Build the task that verifies one candidate against the baseline."""
    payload = {
        "candidate": candidate,
        "test_files": test_files,
        "test_cfg": test_cfg,
        "project_root": project_root,
        "function_qualified_name": function_qualified_name,
        "original_behavior_results": original_behavior_results,
    }
    return AgentTask.create(task_type="verify_candidate", payload=payload, priority=6)
|
||||
|
|
@ -130,6 +130,11 @@ def parse_args() -> Namespace:
|
|||
"--reset-config", action="store_true", help="Remove codeflash configuration from project config file."
|
||||
)
|
||||
parser.add_argument("-y", "--yes", action="store_true", help="Skip confirmation prompts (useful for CI/scripts).")
|
||||
parser.add_argument(
|
||||
"--agentic",
|
||||
action="store_true",
|
||||
help="Use the agentic multi-agent optimization pipeline instead of the default batch pipeline.",
|
||||
)
|
||||
|
||||
args, unknown_args = parser.parse_known_args()
|
||||
sys.argv[:] = [sys.argv[0], *unknown_args]
|
||||
|
|
|
|||
|
|
@ -59,9 +59,14 @@ def main() -> None:
|
|||
init_sentry(enabled=not args.disable_telemetry, exclude_errors=True)
|
||||
posthog_cf.initialize_posthog(enabled=not args.disable_telemetry)
|
||||
|
||||
from codeflash.optimization import optimizer
|
||||
if getattr(args, "agentic", False):
|
||||
from codeflash.optimization.agentic_optimizer import run_agentic_with_args
|
||||
|
||||
optimizer.run_with_args(args)
|
||||
run_agentic_with_args(args)
|
||||
else:
|
||||
from codeflash.optimization import optimizer
|
||||
|
||||
optimizer.run_with_args(args)
|
||||
|
||||
|
||||
def _handle_config_loading(args: Namespace) -> Namespace | None:
|
||||
|
|
|
|||
203
codeflash/optimization/agentic_optimizer.py
Normal file
203
codeflash/optimization/agentic_optimizer.py
Normal file
|
|
@ -0,0 +1,203 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import tempfile
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from codeflash.agents.coordinator import AgentCoordinator
|
||||
from codeflash.cli_cmds.console import console, logger
|
||||
from codeflash.code_utils import env_utils
|
||||
from codeflash.code_utils.code_utils import cleanup_paths
|
||||
from codeflash.either import is_successful
|
||||
from codeflash.languages import is_javascript, set_current_language
|
||||
from codeflash.state.store import StateStore
|
||||
from codeflash.telemetry.posthog_cf import ph
|
||||
from codeflash.verification.verification_utils import TestConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from argparse import Namespace
|
||||
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
|
||||
|
||||
class AgenticOptimizer:
    """Drives the multi-agent optimization pipeline for one CLI invocation.

    Discovers candidate functions, ranks them by traced impact when a trace is
    available, and hands each one to the AgentCoordinator's pipeline.
    """

    def __init__(self, args: Namespace) -> None:
        self.args = args
        self.state_store = StateStore()
        self.coordinator = AgentCoordinator(state_store=self.state_store)

        pytest_cmd = args.pytest_cmd if hasattr(args, "pytest_cmd") and args.pytest_cmd else "pytest"
        benchmarks_root = args.benchmarks_root if "benchmark" in args and "benchmarks_root" in args else None
        self.test_cfg = TestConfig(
            tests_root=args.tests_root,
            tests_project_rootdir=args.test_project_root,
            project_root_path=args.project_root,
            pytest_cmd=pytest_cmd,
            benchmark_tests_root=benchmarks_root,
        )

    def run(self) -> None:
        """Discover, rank, and optimize functions via the agent coordinator."""
        ph("cli-agentic-optimize-run-start")
        logger.info("Running agentic optimizer.")
        console.rule()

        if not env_utils.ensure_codeflash_api_key():
            return

        self.coordinator.register_agents()

        file_to_funcs, num_functions, trace_file_path = self._get_optimizable_functions()

        # Configure language handling from the first file whose functions declare one.
        # NOTE(review): break placement reconstructed from a mangled diff — assumed to
        # stop at the first file that carries a language; confirm against upstream.
        if file_to_funcs:
            for file_path, funcs in file_to_funcs.items():
                if funcs and funcs[0].language:
                    set_current_language(funcs[0].language)
                    self.test_cfg.set_language(funcs[0].language)
                    if is_javascript():
                        self.test_cfg.js_project_root = self._find_js_project_root(file_path)
                    break

        if num_functions == 0:
            logger.info("No functions found to optimize. Exiting…")
            return

        function_to_tests = self._discover_tests(file_to_funcs)

        # Concolic tests are generated into a throwaway directory under tests_root.
        self.test_cfg.concolic_test_root_dir = Path(
            tempfile.mkdtemp(dir=self.args.tests_root, prefix="codeflash_concolic_")
        )

        try:
            optimizations_found = 0
            ranked = self._rank_functions(file_to_funcs, trace_file_path)
            total = len(ranked)

            for position, (file_path, function_to_optimize) in enumerate(ranked, start=1):
                logger.info(
                    f"Optimizing function {position} of {total}: "
                    f"{function_to_optimize.qualified_name} (in {file_path.name})"
                )
                console.rule()

                outcome = self.coordinator.run_optimization_pipeline(
                    function_to_optimize=function_to_optimize,
                    test_cfg=self.test_cfg,
                    args=self.args,
                    function_to_tests=function_to_tests,
                )

                if is_successful(outcome):
                    optimizations_found += 1
                    logger.info(f"Successfully optimized {function_to_optimize.qualified_name}")
                else:
                    logger.warning(f"Failed to optimize {function_to_optimize.qualified_name}: {outcome.failure()}")

            console.rule()
            ph("cli-agentic-optimize-run-finished", {"optimizations_found": optimizations_found})

            if optimizations_found == 0:
                logger.info("No optimizations found.")
            else:
                logger.info(f"Found {optimizations_found} optimization(s).")
        finally:
            # Always remove the concolic scratch directory, even on failure.
            cleanup_paths([self.test_cfg.concolic_test_root_dir])

    def _get_optimizable_functions(self) -> tuple[dict[Path, list[FunctionToOptimize]], int, Path | None]:
        """Delegate to the discovery module to collect candidate functions per file."""
        from codeflash.discovery.functions_to_optimize import get_functions_to_optimize

        return get_functions_to_optimize(
            optimize_all=self.args.all,
            replay_test=self.args.replay_test,
            file=self.args.file,
            only_get_this_function=self.args.function,
            test_cfg=self.test_cfg,
            ignore_paths=self.args.ignore_paths,
            project_root=self.args.project_root,
            module_root=self.args.module_root,
            previous_checkpoint_functions=self.args.previous_checkpoint_functions,
        )

    def _discover_tests(self, file_to_funcs: dict[Path, list[FunctionToOptimize]]) -> dict[str, set]:
        """Find existing unit and replay tests that exercise the candidate functions."""
        from codeflash.discovery.discover_unit_tests import discover_unit_tests

        console.rule()
        started = time.time()
        logger.info("Discovering existing function tests...")

        function_to_tests, num_tests, num_replay_tests = discover_unit_tests(
            self.test_cfg, file_to_funcs_to_optimize=file_to_funcs
        )

        console.rule()
        logger.info(
            f"Discovered {num_tests} existing tests and {num_replay_tests} replay tests "
            f"in {(time.time() - started):.1f}s"
        )
        console.rule()

        return function_to_tests

    def _rank_functions(
        self,
        file_to_funcs: dict[Path, list[FunctionToOptimize]],
        trace_file_path: Path | None,
    ) -> list[tuple[Path, FunctionToOptimize]]:
        """Order candidates by traced performance impact; fall back to discovery order."""
        flattened: list[tuple[Path, FunctionToOptimize]] = [
            (file_path, func) for file_path, funcs in file_to_funcs.items() for func in funcs
        ]

        if not trace_file_path or not trace_file_path.exists():
            return flattened

        try:
            from codeflash.benchmarking.function_ranker import FunctionRanker

            logger.info("Ranking functions by performance impact...")
            ranker = FunctionRanker(trace_file_path)
            ranked_functions = ranker.rank_functions([func for _, func in flattened])

            # Map each function's identity triple back to the file it came from.
            lookup = {
                (func.file_path, func.qualified_name, func.starting_line): file_path
                for file_path, func in flattened
            }

            ordered: list[tuple[Path, FunctionToOptimize]] = []
            for func in ranked_functions:
                owner = lookup.get((func.file_path, func.qualified_name, func.starting_line))
                if owner:
                    ordered.append((owner, func))

            logger.info(f"Ranked {len(ranked_functions)} functions by addressable time")
            return ordered

        except Exception as e:  # ranking is best-effort; never abort the run over it
            logger.warning(f"Could not perform ranking: {e}")
            return flattened

    def _find_js_project_root(self, file_path: Path) -> Path | None:
        """Walk upward from *file_path* to the nearest directory holding a JS/TS project marker."""
        markers = ("package.json", "jest.config.js", "jest.config.ts", "tsconfig.json")
        directory = file_path.parent if file_path.is_file() else file_path
        while directory != directory.parent:  # stop at the filesystem root
            if any((directory / marker).exists() for marker in markers):
                return directory
            directory = directory.parent
        return None
|
||||
|
||||
|
||||
def run_agentic_with_args(args: Namespace) -> None:
    """Entry point for ``codeflash --agentic``: build an AgenticOptimizer and run it.

    A KeyboardInterrupt is translated into a clean SystemExit so the CLI
    terminates without a traceback (``from None`` suppresses exception chaining).
    """
    # The unused `optimizer = None` pre-binding from the original served no purpose
    # (nothing referenced it in the except path) and has been removed.
    try:
        AgenticOptimizer(args).run()
    except KeyboardInterrupt:
        logger.warning("Keyboard interrupt received. Cleaning up and exiting...")
        raise SystemExit from None
|
||||
13
codeflash/state/__init__.py
Normal file
13
codeflash/state/__init__.py
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from codeflash.state.history import OptimizationHistory
|
||||
from codeflash.state.models import AgentStateSnapshot, OptimizationAttempt, OptimizationStatus
|
||||
from codeflash.state.store import StateStore
|
||||
|
||||
__all__ = [
|
||||
"AgentStateSnapshot",
|
||||
"OptimizationAttempt",
|
||||
"OptimizationHistory",
|
||||
"OptimizationStatus",
|
||||
"StateStore",
|
||||
]
|
||||
67
codeflash/state/history.py
Normal file
67
codeflash/state/history.py
Normal file
|
|
@ -0,0 +1,67 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from codeflash.state.models import OptimizationAttempt, OptimizationStatus
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from codeflash.state.store import StateStore
|
||||
|
||||
|
||||
class OptimizationHistory:
    """Query/analysis helpers layered over a StateStore's attempt records."""

    def __init__(self, store: StateStore) -> None:
        self.store = store

    def record_attempt(self, attempt: OptimizationAttempt) -> None:
        """Persist one attempt record."""
        self.store.persist_optimization_attempt(attempt)

    def get_function_attempts(self, qualified_name: str, limit: int = 100) -> list[OptimizationAttempt]:
        """All recorded attempts for one function, newest first."""
        return self.store.get_function_history(qualified_name, limit=limit)

    def get_successful_optimizations(
        self, qualified_name: str | None = None, limit: int = 100
    ) -> list[OptimizationAttempt]:
        """COMPLETED attempts, optionally restricted to one function."""
        return self._attempts_with_status(OptimizationStatus.COMPLETED, qualified_name, limit)

    def get_failed_optimizations(self, qualified_name: str | None = None, limit: int = 100) -> list[OptimizationAttempt]:
        """FAILED attempts, optionally restricted to one function."""
        return self._attempts_with_status(OptimizationStatus.FAILED, qualified_name, limit)

    def _attempts_with_status(
        self, wanted: OptimizationStatus, qualified_name: str | None, limit: int
    ) -> list[OptimizationAttempt]:
        # Per-function queries filter client-side (the store filters only by name);
        # global queries push the status filter into SQL.
        if qualified_name:
            history = self.store.get_function_history(qualified_name, limit=limit)
            return [a for a in history if a.status == wanted]
        return self.store.get_recent_attempts(status=wanted, limit=limit)

    def should_skip_function(self, qualified_name: str, code_hash: str | None = None) -> tuple[bool, str | None]:
        """Decide whether optimizing this function again is worthwhile.

        Returns (skip?, human-readable reason or None). Skips when the exact
        code hash already completed recently, or when the last five attempts
        show three-or-more completions or three-or-more failures.
        """
        if code_hash and self.store.was_function_recently_optimized(qualified_name, code_hash=code_hash):
            return True, "Function with same code hash was recently optimized"

        statuses = [a.status for a in self.store.get_function_history(qualified_name, limit=5)]

        if statuses.count(OptimizationStatus.COMPLETED) >= 3:
            return True, "Function has been successfully optimized multiple times recently"
        if statuses.count(OptimizationStatus.FAILED) >= 3:
            return True, "Function has failed optimization multiple times recently"

        return False, None

    def get_best_speedup(self, qualified_name: str) -> float | None:
        """Largest speedup among the function's COMPLETED attempts, or None."""
        candidates = [
            a.speedup
            for a in self.store.get_function_history(qualified_name)
            if a.status == OptimizationStatus.COMPLETED and a.speedup is not None
        ]
        return max(candidates) if candidates else None

    def get_statistics(self) -> dict[str, int]:
        """Status-bucketed counts over the most recent (up to 1000) attempts."""
        recent = self.store.get_recent_attempts(limit=1000)

        def tally(status: OptimizationStatus) -> int:
            return sum(1 for a in recent if a.status == status)

        return {
            "total_attempts": len(recent),
            "completed": tally(OptimizationStatus.COMPLETED),
            "failed": tally(OptimizationStatus.FAILED),
            "skipped": tally(OptimizationStatus.SKIPPED),
            "in_progress": tally(OptimizationStatus.IN_PROGRESS),
        }

    def cleanup(self, days: int = 30) -> int:
        """Purge records older than *days*; returns how many rows were deleted."""
        return self.store.cleanup_old_records(days=days)
|
||||
172
codeflash/state/models.py
Normal file
172
codeflash/state/models.py
Normal file
|
|
@ -0,0 +1,172 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class OptimizationStatus(str, Enum):
    """Lifecycle states of an optimization attempt.

    Subclasses ``str`` so members compare equal to their string values; the
    ``.value`` strings are what StateStore persists in the SQLite `status` column.
    """

    PENDING = "pending"
    IN_PROGRESS = "in_progress"
    COMPLETED = "completed"
    FAILED = "failed"
    SKIPPED = "skipped"
|
||||
|
||||
|
||||
class OptimizationAttempt(BaseModel):
    """Immutable record of one attempt to optimize a single function.

    State transitions are copy-on-write: each ``mark_*`` helper returns a new
    frozen instance rather than mutating this one.
    """

    model_config = ConfigDict(frozen=True)

    attempt_id: str
    function_qualified_name: str
    file_path: str
    status: OptimizationStatus
    started_at: float
    completed_at: float | None = None
    speedup: float | None = None
    original_runtime_ns: int | None = None
    optimized_runtime_ns: int | None = None
    error_message: str | None = None
    pr_url: str | None = None
    code_hash: str | None = None
    metadata: dict[str, Any] = Field(default_factory=dict)

    @classmethod
    def create(
        cls,
        attempt_id: str,
        function_qualified_name: str,
        file_path: str | Path,
        code_hash: str | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> OptimizationAttempt:
        """Start a fresh attempt in PENDING state, stamped with the current time."""
        return cls(
            attempt_id=attempt_id,
            function_qualified_name=function_qualified_name,
            file_path=str(file_path),
            status=OptimizationStatus.PENDING,
            started_at=time.time(),
            code_hash=code_hash,
            metadata={} if metadata is None else metadata,
        )

    def _evolve(self, **updates: Any) -> OptimizationAttempt:
        """Return a copy carrying the identity fields forward with *updates* applied.

        Only identity/bookkeeping fields are carried over; result fields
        (speedup, runtimes, error_message, pr_url) reset to their defaults
        unless supplied — matching the original transition semantics.
        """
        carried: dict[str, Any] = {
            "attempt_id": self.attempt_id,
            "function_qualified_name": self.function_qualified_name,
            "file_path": self.file_path,
            "started_at": self.started_at,
            "code_hash": self.code_hash,
            "metadata": self.metadata,
        }
        carried.update(updates)
        return OptimizationAttempt(**carried)

    def mark_in_progress(self) -> OptimizationAttempt:
        """Transition PENDING -> IN_PROGRESS."""
        return self._evolve(status=OptimizationStatus.IN_PROGRESS)

    def mark_completed(
        self,
        speedup: float,
        original_runtime_ns: int,
        optimized_runtime_ns: int,
        pr_url: str | None = None,
    ) -> OptimizationAttempt:
        """Transition to COMPLETED with measured results and an optional PR link."""
        return self._evolve(
            status=OptimizationStatus.COMPLETED,
            completed_at=time.time(),
            speedup=speedup,
            original_runtime_ns=original_runtime_ns,
            optimized_runtime_ns=optimized_runtime_ns,
            pr_url=pr_url,
        )

    def mark_failed(self, error_message: str) -> OptimizationAttempt:
        """Transition to FAILED, recording the failure reason."""
        return self._evolve(
            status=OptimizationStatus.FAILED,
            completed_at=time.time(),
            error_message=error_message,
        )

    def mark_skipped(self, reason: str) -> OptimizationAttempt:
        """Transition to SKIPPED; the skip reason is stored in error_message."""
        return self._evolve(
            status=OptimizationStatus.SKIPPED,
            completed_at=time.time(),
            error_message=reason,
        )
|
||||
|
||||
|
||||
class AgentStateSnapshot(BaseModel):
    """Frozen point-in-time view of a single agent's state, for persistence."""

    model_config = ConfigDict(frozen=True)

    agent_id: str
    agent_type: str
    state: str
    current_task_id: str | None = None
    last_updated: float = Field(default_factory=time.time)  # capture time of the snapshot
    context: dict[str, Any] = Field(default_factory=dict)

    @classmethod
    def create(
        cls,
        agent_id: str,
        agent_type: str,
        state: str,
        current_task_id: str | None = None,
        context: dict[str, Any] | None = None,
    ) -> AgentStateSnapshot:
        """Build a snapshot; an omitted context becomes an empty dict."""
        return cls(
            agent_id=agent_id,
            agent_type=agent_type,
            state=state,
            current_task_id=current_task_id,
            context={} if context is None else context,
        )
|
||||
|
||||
|
||||
class PipelineState(BaseModel):
    """Frozen record of one optimization-pipeline run for a single function."""

    model_config = ConfigDict(frozen=True)

    pipeline_id: str
    status: str
    current_stage: str | None = None
    started_at: float
    completed_at: float | None = None
    function_qualified_name: str
    file_path: str
    stages_completed: list[str] = Field(default_factory=list)
    error_message: str | None = None
    metadata: dict[str, Any] = Field(default_factory=dict)

    @classmethod
    def create(
        cls,
        pipeline_id: str,
        function_qualified_name: str,
        file_path: str | Path,
        metadata: dict[str, Any] | None = None,
    ) -> PipelineState:
        """Start a new pipeline record in "pending" status, stamped with now."""
        return cls(
            pipeline_id=pipeline_id,
            status="pending",
            started_at=time.time(),
            function_qualified_name=function_qualified_name,
            file_path=str(file_path),
            metadata={} if metadata is None else metadata,
        )
|
||||
310
codeflash/state/store.py
Normal file
310
codeflash/state/store.py
Normal file
|
|
@ -0,0 +1,310 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import sqlite3
|
||||
import time
|
||||
from collections.abc import Iterator
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from codeflash.state.models import AgentStateSnapshot, OptimizationAttempt, OptimizationStatus, PipelineState
|
||||
|
||||
|
||||
class StateStore:
    """SQLite-backed persistence for the agentic pipeline.

    Holds three record kinds: optimization attempts, per-agent state snapshots,
    and pipeline run states. Every public method opens a short-lived connection
    (30s busy timeout), so instances are safe to share across sequential callers.
    """

    # Idempotent DDL executed once at construction (IF NOT EXISTS throughout).
    _SCHEMA = """
        CREATE TABLE IF NOT EXISTS optimization_attempts (
            attempt_id TEXT PRIMARY KEY,
            function_qualified_name TEXT NOT NULL,
            file_path TEXT NOT NULL,
            status TEXT NOT NULL,
            started_at REAL NOT NULL,
            completed_at REAL,
            speedup REAL,
            original_runtime_ns INTEGER,
            optimized_runtime_ns INTEGER,
            error_message TEXT,
            pr_url TEXT,
            code_hash TEXT,
            metadata TEXT
        );
        CREATE INDEX IF NOT EXISTS idx_function_name
            ON optimization_attempts(function_qualified_name);
        CREATE INDEX IF NOT EXISTS idx_status
            ON optimization_attempts(status);
        CREATE TABLE IF NOT EXISTS agent_states (
            agent_id TEXT PRIMARY KEY,
            agent_type TEXT NOT NULL,
            state TEXT NOT NULL,
            current_task_id TEXT,
            last_updated REAL NOT NULL,
            context TEXT
        );
        CREATE TABLE IF NOT EXISTS pipeline_states (
            pipeline_id TEXT PRIMARY KEY,
            status TEXT NOT NULL,
            current_stage TEXT,
            started_at REAL NOT NULL,
            completed_at REAL,
            function_qualified_name TEXT NOT NULL,
            file_path TEXT NOT NULL,
            stages_completed TEXT,
            error_message TEXT,
            metadata TEXT
        );
    """

    def __init__(self, storage_path: Path | None = None) -> None:
        if storage_path is None:
            # Imported lazily so this module stays importable without the CLI stack.
            from codeflash.code_utils.compat import codeflash_temp_dir

            storage_path = codeflash_temp_dir
        self.storage_path = storage_path
        self.storage_path.mkdir(parents=True, exist_ok=True)
        self.db_path = self.storage_path / "codeflash_agent.db"
        self._init_database()

    def _init_database(self) -> None:
        """Create tables and indexes if they do not already exist."""
        with self._get_connection() as conn:
            conn.executescript(self._SCHEMA)
            conn.commit()

    @contextmanager
    def _get_connection(self) -> Iterator[sqlite3.Connection]:
        """Yield a Row-factory connection that is always closed afterwards."""
        conn = sqlite3.connect(str(self.db_path), timeout=30.0)
        conn.row_factory = sqlite3.Row
        try:
            yield conn
        finally:
            conn.close()

    def persist_optimization_attempt(self, attempt: OptimizationAttempt) -> None:
        """Insert or overwrite the row for *attempt* (upsert by attempt_id)."""
        row = (
            attempt.attempt_id,
            attempt.function_qualified_name,
            attempt.file_path,
            attempt.status.value,
            attempt.started_at,
            attempt.completed_at,
            attempt.speedup,
            attempt.original_runtime_ns,
            attempt.optimized_runtime_ns,
            attempt.error_message,
            attempt.pr_url,
            attempt.code_hash,
            json.dumps(attempt.metadata),
        )
        with self._get_connection() as conn:
            conn.execute(
                """
                INSERT OR REPLACE INTO optimization_attempts
                (attempt_id, function_qualified_name, file_path, status, started_at,
                 completed_at, speedup, original_runtime_ns, optimized_runtime_ns,
                 error_message, pr_url, code_hash, metadata)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                """,
                row,
            )
            conn.commit()

    def get_optimization_attempt(self, attempt_id: str) -> OptimizationAttempt | None:
        """Return the attempt with *attempt_id*, or None when absent."""
        with self._get_connection() as conn:
            row = conn.execute(
                "SELECT * FROM optimization_attempts WHERE attempt_id = ?", (attempt_id,)
            ).fetchone()
        return None if row is None else self._row_to_optimization_attempt(row)

    def get_function_history(self, qualified_name: str, limit: int = 100) -> list[OptimizationAttempt]:
        """Return up to *limit* attempts for one function, newest first."""
        with self._get_connection() as conn:
            rows = conn.execute(
                """
                SELECT * FROM optimization_attempts
                WHERE function_qualified_name = ?
                ORDER BY started_at DESC
                LIMIT ?
                """,
                (qualified_name, limit),
            ).fetchall()
        return [self._row_to_optimization_attempt(row) for row in rows]

    def get_recent_attempts(
        self, since_timestamp: float | None = None, status: OptimizationStatus | None = None, limit: int = 100
    ) -> list[OptimizationAttempt]:
        """Return recent attempts, optionally filtered by start time and status."""
        clauses = ["1=1"]
        params: list[Any] = []

        if since_timestamp is not None:
            clauses.append("started_at >= ?")
            params.append(since_timestamp)
        if status is not None:
            clauses.append("status = ?")
            params.append(status.value)
        params.append(limit)

        query = (
            "SELECT * FROM optimization_attempts WHERE "
            + " AND ".join(clauses)
            + " ORDER BY started_at DESC LIMIT ?"
        )
        with self._get_connection() as conn:
            rows = conn.execute(query, params).fetchall()
        return [self._row_to_optimization_attempt(row) for row in rows]

    def was_function_recently_optimized(
        self, qualified_name: str, code_hash: str | None = None, within_days: int = 7
    ) -> bool:
        """True if a COMPLETED attempt (optionally matching *code_hash*) started within *within_days*."""
        cutoff_time = time.time() - (within_days * 24 * 60 * 60)

        conditions = ["function_qualified_name = ?"]
        params: list[Any] = [qualified_name]
        if code_hash:
            conditions.append("code_hash = ?")
            params.append(code_hash)
        conditions.extend(["status = ?", "started_at >= ?"])
        params.extend([OptimizationStatus.COMPLETED.value, cutoff_time])

        query = "SELECT COUNT(*) FROM optimization_attempts WHERE " + " AND ".join(conditions)
        with self._get_connection() as conn:
            count = conn.execute(query, params).fetchone()[0]
        return count > 0

    def persist_agent_state(self, snapshot: AgentStateSnapshot) -> None:
        """Insert or overwrite the state row for one agent (upsert by agent_id)."""
        with self._get_connection() as conn:
            conn.execute(
                """
                INSERT OR REPLACE INTO agent_states
                (agent_id, agent_type, state, current_task_id, last_updated, context)
                VALUES (?, ?, ?, ?, ?, ?)
                """,
                (
                    snapshot.agent_id,
                    snapshot.agent_type,
                    snapshot.state,
                    snapshot.current_task_id,
                    snapshot.last_updated,
                    json.dumps(snapshot.context),
                ),
            )
            conn.commit()

    def get_agent_state(self, agent_id: str) -> AgentStateSnapshot | None:
        """Return the persisted snapshot for *agent_id*, or None."""
        with self._get_connection() as conn:
            row = conn.execute("SELECT * FROM agent_states WHERE agent_id = ?", (agent_id,)).fetchone()
        if row is None:
            return None
        return AgentStateSnapshot(
            agent_id=row["agent_id"],
            agent_type=row["agent_type"],
            state=row["state"],
            current_task_id=row["current_task_id"],
            last_updated=row["last_updated"],
            context=json.loads(row["context"]) if row["context"] else {},
        )

    def persist_pipeline_state(self, state: PipelineState) -> None:
        """Insert or overwrite the row for one pipeline run (upsert by pipeline_id)."""
        with self._get_connection() as conn:
            conn.execute(
                """
                INSERT OR REPLACE INTO pipeline_states
                (pipeline_id, status, current_stage, started_at, completed_at,
                 function_qualified_name, file_path, stages_completed, error_message, metadata)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                """,
                (
                    state.pipeline_id,
                    state.status,
                    state.current_stage,
                    state.started_at,
                    state.completed_at,
                    state.function_qualified_name,
                    state.file_path,
                    json.dumps(state.stages_completed),
                    state.error_message,
                    json.dumps(state.metadata),
                ),
            )
            conn.commit()

    def get_pipeline_state(self, pipeline_id: str) -> PipelineState | None:
        """Return the persisted pipeline state for *pipeline_id*, or None."""
        with self._get_connection() as conn:
            row = conn.execute("SELECT * FROM pipeline_states WHERE pipeline_id = ?", (pipeline_id,)).fetchone()
        if row is None:
            return None
        return PipelineState(
            pipeline_id=row["pipeline_id"],
            status=row["status"],
            current_stage=row["current_stage"],
            started_at=row["started_at"],
            completed_at=row["completed_at"],
            function_qualified_name=row["function_qualified_name"],
            file_path=row["file_path"],
            stages_completed=json.loads(row["stages_completed"]) if row["stages_completed"] else [],
            error_message=row["error_message"],
            metadata=json.loads(row["metadata"]) if row["metadata"] else {},
        )

    def cleanup_old_records(self, days: int = 30) -> int:
        """Delete attempts and pipeline rows started more than *days* ago; return rows removed."""
        cutoff_time = time.time() - (days * 24 * 60 * 60)
        deleted = 0
        with self._get_connection() as conn:
            # Table names come from a fixed tuple, not user input.
            for table in ("optimization_attempts", "pipeline_states"):
                cur = conn.execute(f"DELETE FROM {table} WHERE started_at < ?", (cutoff_time,))
                deleted += cur.rowcount
            conn.commit()
        return deleted

    def _row_to_optimization_attempt(self, row: sqlite3.Row) -> OptimizationAttempt:
        """Hydrate an OptimizationAttempt model from a database row."""
        simple_fields = (
            "attempt_id", "function_qualified_name", "file_path", "started_at",
            "completed_at", "speedup", "original_runtime_ns", "optimized_runtime_ns",
            "error_message", "pr_url", "code_hash",
        )
        data: dict[str, Any] = {name: row[name] for name in simple_fields}
        data["status"] = OptimizationStatus(row["status"])
        data["metadata"] = json.loads(row["metadata"]) if row["metadata"] else {}
        return OptimizationAttempt(**data)
|
||||
0
tests/test_agents/__init__.py
Normal file
0
tests/test_agents/__init__.py
Normal file
181
tests/test_agents/test_base.py
Normal file
181
tests/test_agents/test_base.py
Normal file
|
|
@ -0,0 +1,181 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from codeflash.agents.base import (
|
||||
AgentMessage,
|
||||
AgentResult,
|
||||
AgentState,
|
||||
AgentTask,
|
||||
BaseAgent,
|
||||
wrap_error,
|
||||
wrap_result,
|
||||
)
|
||||
from codeflash.either import Result
|
||||
|
||||
|
||||
class MockAgent(BaseAgent):
    """Minimal BaseAgent implementation used to exercise the framework in tests."""

    def __init__(self, agent_id: str = "mock", should_fail: bool = False) -> None:
        super().__init__(agent_id=agent_id, coordinator=None)
        self.should_fail = should_fail  # when True, process() reports a failure
        self.process_called = False  # flipped to True once process() runs

    def process(self, task: AgentTask) -> Result:
        self.process_called = True
        return (
            wrap_error("Mock failure", task.task_id, self.agent_id)
            if self.should_fail
            else wrap_result({"mock_data": "test"}, task.task_id, self.agent_id)
        )
|
||||
|
||||
|
||||
class TestAgentState:
    """The AgentState enum exposes the expected string values."""

    def test_agent_states(self) -> None:
        expected = {
            AgentState.IDLE: "idle",
            AgentState.RUNNING: "running",
            AgentState.WAITING: "waiting",
            AgentState.COMPLETED: "completed",
            AgentState.FAILED: "failed",
        }
        for member, value in expected.items():
            assert member.value == value
|
||||
|
||||
|
||||
class TestAgentTask:
    """AgentTask construction, parenting, and priority ordering."""

    def test_create_task(self) -> None:
        created = AgentTask.create(task_type="test_task", payload={"key": "value"}, priority=5)
        assert created.task_type == "test_task"
        assert created.payload == {"key": "value"}
        assert created.priority == 5
        assert created.task_id is not None
        assert created.parent_task_id is None

    def test_task_with_parent(self) -> None:
        child = AgentTask.create(task_type="child_task", payload={}, parent_task_id="parent-123")
        assert child.parent_task_id == "parent-123"

    def test_task_comparison(self) -> None:
        # Higher-priority tasks sort first (compare "less than" lower-priority ones).
        urgent = AgentTask.create("high", {}, priority=10)
        relaxed = AgentTask.create("low", {}, priority=1)
        assert urgent < relaxed
|
||||
|
||||
|
||||
class TestAgentMessage:
    """AgentMessage.create populates identity, payload, and timestamp."""

    def test_create_message(self) -> None:
        msg = AgentMessage.create(sender_id="agent-1", message_type="notification", payload={"status": "ready"})
        assert msg.sender_id == "agent-1"
        assert msg.message_type == "notification"
        assert msg.payload == {"status": "ready"}
        assert msg.message_id is not None
        assert msg.timestamp > 0
|
||||
|
||||
|
||||
class TestAgentResult:
    """AgentResult success/failure constructors and metadata passthrough."""

    def test_success_result(self) -> None:
        ok = AgentResult.success_result(task_id="task-1", agent_id="agent-1", data={"result": "success"})
        assert ok.success is True
        assert ok.data == {"result": "success"}
        assert ok.error is None

    def test_failure_result(self) -> None:
        bad = AgentResult.failure_result(task_id="task-1", agent_id="agent-1", error="Something went wrong")
        assert bad.success is False
        assert bad.data is None
        assert bad.error == "Something went wrong"

    def test_result_with_metadata(self) -> None:
        annotated = AgentResult.success_result(
            task_id="task-1", agent_id="agent-1", data={}, metadata={"duration_ms": 100}
        )
        assert annotated.metadata == {"duration_ms": 100}
|
||||
|
||||
|
||||
class TestBaseAgent:
    """Lifecycle, messaging, and snapshot tests for the BaseAgent contract."""

    def test_agent_initialization(self) -> None:
        """A fresh agent starts idle with no coordinator or current task."""
        agent = MockAgent("test-agent")
        assert agent.agent_id == "test-agent"
        assert agent.state == AgentState.IDLE
        assert agent.coordinator is None
        assert agent.current_task is None

    def test_agent_execute_success(self) -> None:
        """execute() on a healthy agent succeeds and ends in COMPLETED."""
        agent = MockAgent("test-agent")
        outcome = agent.execute(AgentTask.create("test", {"data": "value"}))
        assert outcome.is_successful()
        assert agent.state == AgentState.COMPLETED
        # The mock records that its processing hook actually ran.
        assert agent.process_called

    def test_agent_execute_failure(self) -> None:
        """execute() on a failing agent reports failure and ends in FAILED."""
        agent = MockAgent("test-agent", should_fail=True)
        outcome = agent.execute(AgentTask.create("test", {}))
        assert outcome.is_failure()
        assert agent.state == AgentState.FAILED

    def test_agent_messaging(self) -> None:
        """Messages are queued on receive and drained by get_next_message()."""
        agent = MockAgent()
        assert not agent.has_pending_messages()

        note = AgentMessage.create("sender", "test", {})
        agent.receive_message(note)
        assert agent.has_pending_messages()

        assert agent.get_next_message() == note
        assert not agent.has_pending_messages()

    def test_agent_reset(self) -> None:
        """reset() returns the agent to IDLE and clears pending messages."""
        agent = MockAgent()
        agent.state = AgentState.RUNNING
        agent.receive_message(AgentMessage.create("sender", "test", {}))

        agent.reset()

        assert agent.state == AgentState.IDLE
        assert not agent.has_pending_messages()

    def test_get_state_snapshot(self) -> None:
        """get_state_snapshot() reports id, lowercase state name, and current task."""
        snapshot = MockAgent("snapshot-agent").get_state_snapshot()
        assert snapshot["agent_id"] == "snapshot-agent"
        assert snapshot["state"] == "idle"
        assert snapshot["current_task"] is None
|
||||
|
||||
|
||||
class TestWrapFunctions:
    """Tests for the Result-wrapping convenience helpers."""

    def test_wrap_result(self) -> None:
        """wrap_result() yields a successful Result whose AgentResult holds the data."""
        wrapped = wrap_result({"data": "test"}, "task-1", "agent-1")
        assert wrapped.is_successful()
        assert wrapped.unwrap().data == {"data": "test"}

    def test_wrap_error(self) -> None:
        """wrap_error() yields a failed Result carrying the error string."""
        wrapped = wrap_error("error message", "task-1", "agent-1")
        assert wrapped.is_failure()
        assert wrapped.failure() == "error message"
|
||||
# ===== new file: tests/test_agents/test_coordinator.py (112 lines added) =====
|
|||
from __future__ import annotations
|
||||
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from codeflash.agents.base import AgentResult, AgentTask
|
||||
from codeflash.agents.coordinator import AgentCoordinator
|
||||
from codeflash.either import Failure, Success
|
||||
from codeflash.state.store import StateStore
|
||||
|
||||
|
||||
class TestAgentCoordinator:
    """Tests for AgentCoordinator registration, task routing, and persistence."""

    @pytest.fixture
    def coordinator(self) -> AgentCoordinator:
        """Yield a coordinator backed by a StateStore in a throwaway directory."""
        with tempfile.TemporaryDirectory() as tmpdir:
            store = StateStore(Path(tmpdir))
            coord = AgentCoordinator(state_store=store)
            yield coord

    def test_coordinator_initialization(self, coordinator: AgentCoordinator) -> None:
        """A new coordinator has a store but no registered agents yet."""
        assert coordinator.agents == {}
        assert coordinator.state_store is not None
        assert not coordinator._initialized

    def test_register_agents(self, coordinator: AgentCoordinator) -> None:
        """register_agents() installs every pipeline agent and flags init done."""
        coordinator.register_agents()

        assert "discovery" in coordinator.agents
        assert "analysis" in coordinator.agents
        assert "generation" in coordinator.agents
        assert "verification" in coordinator.agents
        assert "selection" in coordinator.agents
        assert "integration" in coordinator.agents
        assert coordinator._initialized

    def test_get_agent(self, coordinator: AgentCoordinator) -> None:
        """get_agent() returns registered agents by id and None for unknown ids."""
        coordinator.register_agents()

        agent = coordinator.get_agent("discovery")
        assert agent is not None
        assert agent.agent_id == "discovery"

        assert coordinator.get_agent("nonexistent") is None

    def test_submit_task(self, coordinator: AgentCoordinator) -> None:
        """submit_task() enqueues work on the coordinator's task queue."""
        coordinator.submit_task(AgentTask.create("test", {"data": "value"}))
        assert not coordinator.task_queue.empty()

    def test_execute_task_agent_not_found(self, coordinator: AgentCoordinator) -> None:
        """Dispatching to an unknown agent id fails with a 'not found' error."""
        result = coordinator.execute_task("nonexistent", AgentTask.create("test", {}))
        assert result.is_failure()
        assert "not found" in result.failure()

    def test_persist_agent_state(self, coordinator: AgentCoordinator) -> None:
        """persist_agent_state() round-trips a snapshot through the StateStore."""
        coordinator.register_agents()

        state_snapshot = {
            "state": "running",
            "current_task": "task-123",
        }
        coordinator.persist_agent_state("discovery", state_snapshot)

        retrieved = coordinator.state_store.get_agent_state("discovery")
        assert retrieved is not None
        assert retrieved.state == "running"

    def test_notify_failure_logs_warning(self, coordinator: AgentCoordinator) -> None:
        """notify_failure() emits exactly one warning via the module logger."""
        with patch("codeflash.agents.coordinator.logger") as mock_logger:
            coordinator.notify_failure("test-agent", "Test error message")
            mock_logger.warning.assert_called_once()

    def test_reset_all_agents(self, coordinator: AgentCoordinator) -> None:
        """reset_all_agents() drains the queue and returns every agent to IDLE."""
        # Fixed: this import used to sit inside the for-loop below, re-executing
        # the import machinery on every iteration; hoisted to the top of the test.
        from codeflash.agents.base import AgentState

        coordinator.register_agents()
        coordinator.submit_task(AgentTask.create("test", {}))

        coordinator.reset_all_agents()

        assert coordinator.task_queue.empty()
        for agent in coordinator.agents.values():
            assert agent.state == AgentState.IDLE
|
||||
|
||||
|
||||
class TestCoordinatorExecution:
    """End-to-end dispatch through a coordinator with its agents registered."""

    @pytest.fixture
    def coordinator_with_agents(self) -> AgentCoordinator:
        """Yield a coordinator whose full agent roster is already registered."""
        with tempfile.TemporaryDirectory() as tmpdir:
            coord = AgentCoordinator(state_store=StateStore(Path(tmpdir)))
            coord.register_agents()
            yield coord

    def test_execute_task_on_registered_agent(self, coordinator_with_agents: AgentCoordinator) -> None:
        """A numerical-check task routed to the analysis agent succeeds."""
        from codeflash.agents.analysis_agent import create_numerical_check_task

        check_task = create_numerical_check_task("def foo(): return 1")
        outcome = coordinator_with_agents.execute_task("analysis", check_task)

        assert outcome.is_successful()
        inner = outcome.unwrap()
        assert inner.success
        assert "is_numerical" in inner.data
|
||||
# ===== new file: tests/test_agents/test_state.py (220 lines added) =====
|
|||
from __future__ import annotations
|
||||
|
||||
import tempfile
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from codeflash.state.history import OptimizationHistory
|
||||
from codeflash.state.models import AgentStateSnapshot, OptimizationAttempt, OptimizationStatus, PipelineState
|
||||
from codeflash.state.store import StateStore
|
||||
|
||||
|
||||
class TestOptimizationStatus:
    """Sanity-check the serialized value of every OptimizationStatus member."""

    def test_status_values(self) -> None:
        expected = {
            OptimizationStatus.PENDING: "pending",
            OptimizationStatus.IN_PROGRESS: "in_progress",
            OptimizationStatus.COMPLETED: "completed",
            OptimizationStatus.FAILED: "failed",
            OptimizationStatus.SKIPPED: "skipped",
        }
        for member, value in expected.items():
            assert member.value == value
|
||||
|
||||
|
||||
class TestOptimizationAttempt:
    """State-transition tests for the OptimizationAttempt record."""

    def test_create_attempt(self) -> None:
        """create() stores identity fields and starts in PENDING."""
        attempt = OptimizationAttempt.create(
            attempt_id="attempt-1",
            function_qualified_name="module.function",
            file_path="/path/to/file.py",
            code_hash="abc123",
        )
        assert attempt.attempt_id == "attempt-1"
        assert attempt.function_qualified_name == "module.function"
        assert attempt.file_path == "/path/to/file.py"
        assert attempt.status == OptimizationStatus.PENDING
        assert attempt.code_hash == "abc123"

    def test_mark_in_progress(self) -> None:
        """mark_in_progress() flips the status while keeping the same id."""
        pending = OptimizationAttempt.create("id", "func", "/file.py")
        running = pending.mark_in_progress()
        assert running.status == OptimizationStatus.IN_PROGRESS
        assert running.attempt_id == pending.attempt_id

    def test_mark_completed(self) -> None:
        """mark_completed() records timings, speedup, PR url, and a finish time."""
        pending = OptimizationAttempt.create("id", "func", "/file.py")
        done = pending.mark_completed(
            speedup=2.5,
            original_runtime_ns=1000000,
            optimized_runtime_ns=400000,
            pr_url="https://github.com/test/pr/1",
        )
        assert done.status == OptimizationStatus.COMPLETED
        assert done.speedup == 2.5
        assert done.original_runtime_ns == 1000000
        assert done.optimized_runtime_ns == 400000
        assert done.pr_url == "https://github.com/test/pr/1"
        assert done.completed_at is not None

    def test_mark_failed(self) -> None:
        """mark_failed() records the error message and a finish time."""
        pending = OptimizationAttempt.create("id", "func", "/file.py")
        failed = pending.mark_failed("Test failure message")
        assert failed.status == OptimizationStatus.FAILED
        assert failed.error_message == "Test failure message"
        assert failed.completed_at is not None

    def test_mark_skipped(self) -> None:
        """mark_skipped() reuses error_message to carry the skip reason."""
        pending = OptimizationAttempt.create("id", "func", "/file.py")
        skipped = pending.mark_skipped("Already optimized")
        assert skipped.status == OptimizationStatus.SKIPPED
        assert skipped.error_message == "Already optimized"
|
||||
|
||||
|
||||
class TestAgentStateSnapshot:
    """Construction tests for AgentStateSnapshot."""

    def test_create_snapshot(self) -> None:
        """create() carries through every field, including arbitrary context."""
        snap = AgentStateSnapshot.create(
            agent_id="discovery",
            agent_type="DiscoveryAgent",
            state="running",
            current_task_id="task-123",
            context={"extra": "data"},
        )
        assert snap.agent_id == "discovery"
        assert snap.agent_type == "DiscoveryAgent"
        assert snap.state == "running"
        assert snap.current_task_id == "task-123"
        assert snap.context == {"extra": "data"}
|
||||
|
||||
|
||||
class TestPipelineState:
    """Construction tests for PipelineState."""

    def test_create_pipeline_state(self) -> None:
        """create() starts a pipeline in 'pending' with the target function set."""
        pipeline = PipelineState.create(
            pipeline_id="pipeline-1",
            function_qualified_name="module.func",
            file_path="/path/file.py",
        )
        assert pipeline.pipeline_id == "pipeline-1"
        assert pipeline.status == "pending"
        assert pipeline.function_qualified_name == "module.func"
|
||||
|
||||
|
||||
class TestStateStore:
    """Round-trip and query tests for the persistent StateStore."""

    @pytest.fixture
    def temp_store(self) -> StateStore:
        """Yield a StateStore rooted in a throwaway temp directory."""
        with tempfile.TemporaryDirectory() as tmpdir:
            yield StateStore(Path(tmpdir))

    def test_store_initialization(self, temp_store: StateStore) -> None:
        """Constructing the store creates its database file on disk."""
        assert temp_store.db_path.exists()

    def test_persist_and_get_optimization_attempt(self, temp_store: StateStore) -> None:
        """An attempt written to the store can be read back by id."""
        temp_store.persist_optimization_attempt(
            OptimizationAttempt.create("test-id", "func.name", "/file.py")
        )
        fetched = temp_store.get_optimization_attempt("test-id")
        assert fetched is not None
        assert fetched.attempt_id == "test-id"
        assert fetched.function_qualified_name == "func.name"

    def test_get_function_history(self, temp_store: StateStore) -> None:
        """All attempts for one qualified name come back as its history."""
        for i in range(3):
            temp_store.persist_optimization_attempt(
                OptimizationAttempt.create(f"id-{i}", "module.func", "/file.py")
            )
        assert len(temp_store.get_function_history("module.func")) == 3

    def test_get_recent_attempts(self, temp_store: StateStore) -> None:
        """get_recent_attempts() returns attempts across all functions."""
        temp_store.persist_optimization_attempt(
            OptimizationAttempt.create("id-1", "func1", "/file1.py")
        )
        temp_store.persist_optimization_attempt(
            OptimizationAttempt.create("id-2", "func2", "/file2.py")
        )
        assert len(temp_store.get_recent_attempts(limit=10)) == 2

    def test_was_function_recently_optimized(self, temp_store: StateStore) -> None:
        """Only a completed attempt with a matching hash counts as 'recent'."""
        attempt = OptimizationAttempt.create("id", "func", "/file.py", code_hash="hash123")
        temp_store.persist_optimization_attempt(attempt.mark_completed(2.0, 1000, 500))
        assert temp_store.was_function_recently_optimized("func", code_hash="hash123")
        assert not temp_store.was_function_recently_optimized("other_func")

    def test_persist_and_get_agent_state(self, temp_store: StateStore) -> None:
        """Agent snapshots round-trip through the store keyed on agent id."""
        temp_store.persist_agent_state(
            AgentStateSnapshot.create("agent-1", "TestAgent", "running")
        )
        fetched = temp_store.get_agent_state("agent-1")
        assert fetched is not None
        assert fetched.agent_type == "TestAgent"

    def test_persist_and_get_pipeline_state(self, temp_store: StateStore) -> None:
        """Pipeline states round-trip through the store keyed on pipeline id."""
        temp_store.persist_pipeline_state(PipelineState.create("pipe-1", "func", "/file.py"))
        fetched = temp_store.get_pipeline_state("pipe-1")
        assert fetched is not None
        assert fetched.function_qualified_name == "func"
|
||||
|
||||
|
||||
class TestOptimizationHistory:
    """Policy tests for OptimizationHistory built on top of a StateStore."""

    @pytest.fixture
    def history(self) -> OptimizationHistory:
        """Yield a history object backed by a temp-directory StateStore."""
        with tempfile.TemporaryDirectory() as tmpdir:
            yield OptimizationHistory(StateStore(Path(tmpdir)))

    def test_record_and_get_attempts(self, history: OptimizationHistory) -> None:
        """A recorded attempt is retrievable via get_function_attempts()."""
        history.record_attempt(OptimizationAttempt.create("id", "func", "/file.py"))
        assert len(history.get_function_attempts("func")) == 1

    def test_should_skip_recently_optimized(self, history: OptimizationHistory) -> None:
        """Functions with successful completions are skipped with a clear reason."""
        for i in range(3):
            pending = OptimizationAttempt.create(f"id-{i}", "func", "/file.py")
            history.record_attempt(pending.mark_completed(2.0, 1000, 500))

        should_skip, reason = history.should_skip_function("func")
        assert should_skip
        assert "successfully optimized" in reason

    def test_should_skip_repeatedly_failed(self, history: OptimizationHistory) -> None:
        """Functions that keep failing are skipped with a failure reason."""
        for i in range(3):
            pending = OptimizationAttempt.create(f"id-{i}", "func", "/file.py")
            history.record_attempt(pending.mark_failed("error"))

        should_skip, reason = history.should_skip_function("func")
        assert should_skip
        assert "failed" in reason

    def test_get_best_speedup(self, history: OptimizationHistory) -> None:
        """get_best_speedup() reports the maximum speedup across attempts."""
        for speedup in [1.5, 2.0, 1.8]:
            pending = OptimizationAttempt.create(f"id-{speedup}", "func", "/file.py")
            history.record_attempt(
                pending.mark_completed(speedup, 1000, int(1000 / speedup))
            )
        assert history.get_best_speedup("func") == 2.0

    def test_get_statistics(self, history: OptimizationHistory) -> None:
        """get_statistics() tallies totals plus completed/failed breakdowns."""
        first = OptimizationAttempt.create("id-1", "func1", "/file.py")
        second = OptimizationAttempt.create("id-2", "func2", "/file.py")
        history.record_attempt(first.mark_completed(2.0, 1000, 500))
        history.record_attempt(second.mark_failed("error"))

        stats = history.get_statistics()
        assert stats["total_attempts"] == 2
        assert stats["completed"] == 1
        assert stats["failed"] == 1
|
||||
Loading…
Reference in a new issue