feat: add agentic multi-agent optimization pipeline

Implement a new agentic workflow for CodeFlash with a --agentic CLI flag.
The pipeline uses specialized agents (Discovery, Analysis, Generation,
Verification, Selection, Integration) orchestrated by an AgentCoordinator.

New components:
- codeflash/agents/: Agent framework with BaseAgent and specialized agents
- codeflash/state/: SQLite-based state management for persistence
- codeflash/optimization/agentic_optimizer.py: Entry point for agentic mode

Usage: codeflash --agentic --file path/to/file.py

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
aseembits93 2026-02-04 15:40:14 -08:00
parent 6fc2c177f2
commit 84d954336f
20 changed files with 2970 additions and 2 deletions

View file

@ -0,0 +1,25 @@
from __future__ import annotations
from codeflash.agents.analysis_agent import AnalysisAgent
from codeflash.agents.base import AgentMessage, AgentResult, AgentState, AgentTask, BaseAgent
from codeflash.agents.coordinator import AgentCoordinator
from codeflash.agents.discovery_agent import DiscoveryAgent
from codeflash.agents.generation_agent import GenerationAgent
from codeflash.agents.integration_agent import IntegrationAgent
from codeflash.agents.selection_agent import SelectionAgent
from codeflash.agents.verification_agent import VerificationAgent
# Names exported by this package: the coordinator, the shared base types
# from ``codeflash.agents.base``, and each specialized agent class imported
# above.
__all__ = [
    "AgentCoordinator",
    "AgentMessage",
    "AgentResult",
    "AgentState",
    "AgentTask",
    "AnalysisAgent",
    "BaseAgent",
    "DiscoveryAgent",
    "GenerationAgent",
    "IntegrationAgent",
    "SelectionAgent",
    "VerificationAgent",
]

View file

@ -0,0 +1,147 @@
from __future__ import annotations
from pathlib import Path
from typing import TYPE_CHECKING, Any
from codeflash.agents.base import AgentResult, AgentTask, BaseAgent, wrap_error, wrap_result
from codeflash.cli_cmds.console import logger
from codeflash.either import Result
if TYPE_CHECKING:
from codeflash.agents.coordinator import AgentCoordinator
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import CodeOptimizationContext
class AnalysisAgent(BaseAgent):
    """Agent that inspects code ahead of optimization.

    Handled task types (``AgentTask.task_type``):

    * ``"extract_context"`` -- build the code optimization context for a function.
    * ``"assess_complexity"`` -- derive size metrics from a code context.
    * ``"check_numerical"`` -- report whether a source string is numerical code.

    Every handler converts any exception into a failed ``Result`` (the error
    text is ``str(exception)``) and logs the traceback.
    """

    def __init__(self, coordinator: AgentCoordinator | None = None) -> None:
        """Register the agent under the fixed id ``"analysis"``."""
        super().__init__(agent_id="analysis", coordinator=coordinator)

    def process(self, task: AgentTask) -> Result[AgentResult[Any], str]:
        """Dispatch *task* to the handler matching its ``task_type``.

        Returns:
            The handler's result, or a failure for an unknown task type.
        """
        handlers = {
            "extract_context": self._extract_context,
            "assess_complexity": self._assess_complexity,
            "check_numerical": self._check_numerical,
        }
        handler = handlers.get(task.task_type)
        if handler is None:
            return wrap_error(f"Unknown task type: {task.task_type}", task.task_id, self.agent_id)
        return handler(task)

    def _extract_context(self, task: AgentTask) -> Result[AgentResult[Any], str]:
        """Build the optimization context for ``payload["function_to_optimize"]``.

        The payload must also carry ``"project_root"``. On success the result
        data holds ``"code_context"`` and ``"function_to_optimize"``.
        """
        payload = task.payload
        try:
            from codeflash.context.code_context_extractor import get_code_optimization_context

            function_to_optimize: FunctionToOptimize = payload["function_to_optimize"]
            project_root: Path = payload["project_root"]
            code_context = get_code_optimization_context(function_to_optimize, project_root)
            result_data = {
                "code_context": code_context,
                "function_to_optimize": function_to_optimize,
            }
            logger.debug(f"AnalysisAgent: Extracted context for {function_to_optimize.qualified_name}")
            return wrap_result(result_data, task.task_id, self.agent_id)
        except Exception as e:
            logger.exception(f"AnalysisAgent context extraction failed: {e}")
            return wrap_error(str(e), task.task_id, self.agent_id)

    def _assess_complexity(self, task: AgentTask) -> Result[AgentResult[Any], str]:
        """Compute complexity metrics for ``payload["code_context"]``.

        On success the result data holds ``"complexity"`` (see
        ``_analyze_code_complexity``) and the original ``"code_context"``.
        """
        payload = task.payload
        try:
            code_context: CodeOptimizationContext = payload["code_context"]
            complexity_info = self._analyze_code_complexity(code_context)
            result_data = {
                "complexity": complexity_info,
                "code_context": code_context,
            }
            return wrap_result(result_data, task.task_id, self.agent_id)
        except Exception as e:
            # Log the traceback like the other handlers do; the Failure
            # returned below only carries str(e).
            logger.exception(f"AnalysisAgent complexity assessment failed: {e}")
            return wrap_error(str(e), task.task_id, self.agent_id)

    def _check_numerical(self, task: AgentTask) -> Result[AgentResult[Any], str]:
        """Run ``is_numerical_code`` on ``payload["source_code"]``.

        On success the result data holds a single ``"is_numerical"`` entry.
        """
        payload = task.payload
        try:
            from codeflash.code_utils.code_extractor import is_numerical_code

            source_code: str = payload["source_code"]
            result_data = {
                "is_numerical": is_numerical_code(source_code),
            }
            return wrap_result(result_data, task.task_id, self.agent_id)
        except Exception as e:
            # Log the traceback like the other handlers do; the Failure
            # returned below only carries str(e).
            logger.exception(f"AnalysisAgent numerical check failed: {e}")
            return wrap_error(str(e), task.task_id, self.agent_id)

    def _analyze_code_complexity(self, code_context: CodeOptimizationContext) -> dict[str, Any]:
        """Return token, helper and file counts plus a derived complexity level.

        Token counts come from ``encoded_tokens_len``; the read-only context
        counts as 0 tokens when it is empty or missing.
        """
        from codeflash.code_utils.code_utils import encoded_tokens_len

        read_writable_tokens = encoded_tokens_len(code_context.read_writable_code.markdown)
        read_only_tokens = (
            encoded_tokens_len(code_context.read_only_context_code) if code_context.read_only_context_code else 0
        )
        total_tokens = read_writable_tokens + read_only_tokens
        helper_count = len(code_context.helper_functions)
        file_count = len(code_context.read_writable_code.code_strings)
        return {
            "total_tokens": total_tokens,
            "read_writable_tokens": read_writable_tokens,
            "read_only_tokens": read_only_tokens,
            "helper_count": helper_count,
            "file_count": file_count,
            "complexity_level": self._determine_complexity_level(total_tokens, helper_count, file_count),
        }

    def _determine_complexity_level(self, total_tokens: int, helper_count: int, file_count: int) -> str:
        """Classify the metrics as ``"high"``, ``"medium"`` or ``"low"``.

        A level applies as soon as any one metric exceeds its threshold:
        high above 10000 tokens / 10 helpers / 5 files, medium above
        5000 tokens / 5 helpers / 2 files, otherwise low.
        """
        if total_tokens > 10000 or helper_count > 10 or file_count > 5:
            return "high"
        if total_tokens > 5000 or helper_count > 5 or file_count > 2:
            return "medium"
        return "low"
def create_context_extraction_task(
    function_to_optimize: Any,
    project_root: Path,
) -> AgentTask:
    """Build the ``"extract_context"`` task (priority 8) for the analysis agent."""
    task_payload = {
        "function_to_optimize": function_to_optimize,
        "project_root": project_root,
    }
    return AgentTask.create(task_type="extract_context", payload=task_payload, priority=8)
def create_complexity_assessment_task(code_context: Any) -> AgentTask:
    """Build the ``"assess_complexity"`` task (priority 7) for the analysis agent."""
    task_payload = {"code_context": code_context}
    return AgentTask.create(task_type="assess_complexity", payload=task_payload, priority=7)
def create_numerical_check_task(source_code: str) -> AgentTask:
    """Build the ``"check_numerical"`` task (priority 7) for the analysis agent."""
    task_payload = {"source_code": source_code}
    return AgentTask.create(task_type="check_numerical", payload=task_payload, priority=7)

193
codeflash/agents/base.py Normal file
View file

@ -0,0 +1,193 @@
from __future__ import annotations
import uuid
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import Enum
from queue import Queue
from typing import TYPE_CHECKING, Any, Generic, TypeVar
from codeflash.either import Failure, Result, Success
if TYPE_CHECKING:
from codeflash.agents.coordinator import AgentCoordinator
class AgentState(Enum):
    """Lifecycle states an agent can be in; the values are lowercase strings."""

    # Assigned by BaseAgent.__init__ and BaseAgent.reset.
    IDLE = "idle"
    # Assigned at the start of BaseAgent.execute.
    RUNNING = "running"
    # Not assigned anywhere in this module.
    WAITING = "waiting"
    # Assigned by BaseAgent.execute when process() reports success.
    COMPLETED = "completed"
    # Assigned by BaseAgent.execute when process() reports failure or raises.
    FAILED = "failed"
@dataclass
class AgentTask:
    """A unit of work addressed to an agent.

    ``__lt__`` is inverted on ``priority``: the task with the larger priority
    value compares as "less than", so it is retrieved first from a
    smallest-first container such as ``queue.PriorityQueue``.
    """

    task_id: str
    task_type: str
    payload: dict[str, Any]
    priority: int = 0
    parent_task_id: str | None = None

    @classmethod
    def create(
        cls,
        task_type: str,
        payload: dict[str, Any],
        priority: int = 0,
        parent_task_id: str | None = None,
    ) -> AgentTask:
        """Build a task whose ``task_id`` is a freshly generated UUID4 string."""
        generated_id = str(uuid.uuid4())
        return cls(generated_id, task_type, payload, priority, parent_task_id)

    def __lt__(self, other: AgentTask) -> bool:
        """Return True when this task has the higher priority value."""
        return other.priority < self.priority
@dataclass
class AgentMessage:
    """A message placed in an agent's inbox.

    ``timestamp`` defaults to the value of ``time.time()`` at construction.
    """

    message_id: str
    sender_id: str
    message_type: str
    payload: dict[str, Any]
    # The time module is resolved inside the factory, so this block needs no
    # module-level ``import time``.
    timestamp: float = field(default_factory=lambda: __import__("time").time())

    @classmethod
    def create(
        cls,
        sender_id: str,
        message_type: str,
        payload: dict[str, Any],
    ) -> AgentMessage:
        """Build a message whose ``message_id`` is a freshly generated UUID4 string."""
        generated_id = str(uuid.uuid4())
        return cls(generated_id, sender_id, message_type, payload)
# Type of the payload carried by a successful AgentResult.
T = TypeVar("T")


@dataclass
class AgentResult(Generic[T]):
    """Outcome of one task as reported by one agent.

    Use ``success_result`` / ``failure_result`` rather than the constructor
    so that ``success``, ``data`` and ``error`` stay consistent.
    """

    task_id: str
    agent_id: str
    success: bool
    data: T | None = None
    error: str | None = None
    metadata: dict[str, Any] = field(default_factory=dict)

    @classmethod
    def success_result(
        cls,
        task_id: str,
        agent_id: str,
        data: T,
        metadata: dict[str, Any] | None = None,
    ) -> AgentResult[T]:
        """Build a result with ``success=True`` carrying *data*.

        A missing or empty *metadata* is replaced by a new empty dict.
        """
        extra = metadata if metadata else {}
        return cls(task_id=task_id, agent_id=agent_id, success=True, data=data, metadata=extra)

    @classmethod
    def failure_result(
        cls,
        task_id: str,
        agent_id: str,
        error: str,
        metadata: dict[str, Any] | None = None,
    ) -> AgentResult[T]:
        """Build a result with ``success=False`` carrying *error*.

        A missing or empty *metadata* is replaced by a new empty dict.
        """
        extra = metadata if metadata else {}
        return cls(task_id=task_id, agent_id=agent_id, success=False, error=error, metadata=extra)
class BaseAgent(ABC):
    """Common behaviour for all agents: state tracking, an inbox and execution.

    Subclasses implement ``process``; callers run tasks through ``execute``,
    which maintains ``state`` and ``current_task`` around the call.
    """

    def __init__(self, agent_id: str, coordinator: AgentCoordinator | None = None) -> None:
        """Start idle, with an empty inbox and no current task."""
        self.agent_id = agent_id
        self.state = AgentState.IDLE
        self.coordinator = coordinator
        self.inbox: Queue[AgentMessage] = Queue()
        self.current_task: AgentTask | None = None

    @property
    def name(self) -> str:
        """Name of the concrete agent class."""
        return type(self).__name__

    def set_coordinator(self, coordinator: AgentCoordinator) -> None:
        """Attach (or replace) the coordinator this agent reports to."""
        self.coordinator = coordinator

    @abstractmethod
    def process(self, task: AgentTask) -> Result[AgentResult[Any], str]:
        """Handle *task*; implemented by each concrete agent."""

    def execute(self, task: AgentTask) -> Result[AgentResult[Any], str]:
        """Run ``process(task)`` while tracking state.

        The state becomes RUNNING, then COMPLETED or FAILED depending on the
        result. If an exception escapes, the state becomes FAILED,
        ``handle_failure`` is invoked and a ``Failure`` describing the error
        is returned. ``current_task`` is always cleared afterwards.
        """
        self.state = AgentState.RUNNING
        self.current_task = task
        try:
            outcome = self.process(task)
            self.state = AgentState.COMPLETED if outcome.is_successful() else AgentState.FAILED
            return outcome
        except Exception as exc:
            self.state = AgentState.FAILED
            # Runs before the finally clause, so the snapshot taken by
            # handle_failure still sees current_task.
            self.handle_failure(str(exc))
            return Failure(f"Agent {self.agent_id} failed: {exc}")
        finally:
            self.current_task = None

    def handle_failure(self, error: str) -> None:
        """Persist a state snapshot and notify the coordinator, if one is set."""
        if not self.coordinator:
            return
        self.coordinator.persist_agent_state(self.agent_id, self.get_state_snapshot())
        self.coordinator.notify_failure(self.agent_id, error)

    def get_state_snapshot(self) -> dict[str, Any]:
        """Return the agent id, state value, current task id and inbox size."""
        active_task_id = self.current_task.task_id if self.current_task else None
        return {
            "agent_id": self.agent_id,
            "state": self.state.value,
            "current_task": active_task_id,
            "inbox_size": self.inbox.qsize(),
        }

    def receive_message(self, message: AgentMessage) -> None:
        """Append *message* to the inbox."""
        self.inbox.put(message)

    def has_pending_messages(self) -> bool:
        """Return True when the inbox is not empty."""
        inbox_is_empty = self.inbox.empty()
        return not inbox_is_empty

    def get_next_message(self) -> AgentMessage | None:
        """Pop the oldest inbox message, or return None when the inbox is empty."""
        return None if self.inbox.empty() else self.inbox.get_nowait()

    def reset(self) -> None:
        """Return to IDLE, drop the current task and discard all inbox messages."""
        self.state = AgentState.IDLE
        self.current_task = None
        pending = self.inbox
        while not pending.empty():
            pending.get_nowait()
def wrap_result(data: Any, task_id: str, agent_id: str) -> Result[AgentResult[Any], str]:
    """Wrap *data* in a successful ``AgentResult`` and return it inside ``Success``."""
    agent_result = AgentResult.success_result(task_id=task_id, agent_id=agent_id, data=data)
    return Success(agent_result)
def wrap_error(error: str, task_id: str, agent_id: str) -> Result[AgentResult[Any], str]:
    """Return *error* wrapped in ``Failure``.

    ``task_id`` and ``agent_id`` are accepted but not used; the failure
    carries only the error string.
    """
    failure = Failure(error)
    return failure

View file

@ -0,0 +1,284 @@
from __future__ import annotations
import uuid
from pathlib import Path
from queue import PriorityQueue
from typing import TYPE_CHECKING, Any
from codeflash.agents.base import AgentResult, AgentTask, BaseAgent
from codeflash.cli_cmds.console import console, logger
from codeflash.either import Failure, Result, Success
from codeflash.state.models import AgentStateSnapshot, OptimizationAttempt, PipelineState
from codeflash.state.store import StateStore
if TYPE_CHECKING:
from argparse import Namespace
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import BestOptimization, CodeOptimizationContext
from codeflash.verification.verification_utils import TestConfig
class AgentCoordinator:
    """Registers the specialized agents and drives the optimization pipeline.

    The coordinator holds a ``StateStore`` used to persist pipeline state,
    optimization attempts and agent snapshots, and a priority queue of
    submitted tasks.
    """

    def __init__(self, state_store: StateStore | None = None) -> None:
        """Create a coordinator; a default ``StateStore`` is built if none is given."""
        self.agents: dict[str, BaseAgent] = {}
        self.state_store = state_store or StateStore()
        self.task_queue: PriorityQueue[AgentTask] = PriorityQueue()
        # Agents are created on demand by run_optimization_pipeline().
        self._initialized = False

    def register_agents(self) -> None:
        """Instantiate every specialized agent and index it by its agent id."""
        from codeflash.agents.analysis_agent import AnalysisAgent
        from codeflash.agents.discovery_agent import DiscoveryAgent
        from codeflash.agents.generation_agent import GenerationAgent
        from codeflash.agents.integration_agent import IntegrationAgent
        from codeflash.agents.selection_agent import SelectionAgent
        from codeflash.agents.verification_agent import VerificationAgent

        self.agents = {
            "discovery": DiscoveryAgent(coordinator=self, state_store=self.state_store),
            "analysis": AnalysisAgent(coordinator=self),
            "generation": GenerationAgent(coordinator=self),
            "verification": VerificationAgent(coordinator=self),
            "selection": SelectionAgent(coordinator=self),
            "integration": IntegrationAgent(coordinator=self),
        }
        self._initialized = True

    def get_agent(self, agent_id: str) -> BaseAgent | None:
        """Return the registered agent for *agent_id*, or None."""
        return self.agents.get(agent_id)

    def submit_task(self, task: AgentTask) -> None:
        """Put *task* on the priority queue."""
        self.task_queue.put(task)

    def execute_task(self, agent_id: str, task: AgentTask) -> Result[AgentResult[Any], str]:
        """Run *task* on the named agent.

        Returns:
            The agent's result, or a ``Failure`` when no such agent is registered.
        """
        agent = self.agents.get(agent_id)
        if agent is None:
            return Failure(f"Agent {agent_id} not found")
        logger.debug(f"Coordinator: Executing task {task.task_type} on agent {agent_id}")
        return agent.execute(task)

    def persist_agent_state(self, agent_id: str, state_snapshot: dict[str, Any]) -> None:
        """Store *state_snapshot* for a registered agent; unknown ids are ignored."""
        agent = self.agents.get(agent_id)
        if agent is None:
            return
        snapshot = AgentStateSnapshot.create(
            agent_id=agent_id,
            agent_type=agent.name,
            state=state_snapshot.get("state", "unknown"),
            current_task_id=state_snapshot.get("current_task"),
            context=state_snapshot,
        )
        self.state_store.persist_agent_state(snapshot)

    def notify_failure(self, agent_id: str, error: str) -> None:
        """Log a warning that *agent_id* failed with *error*."""
        logger.warning(f"Agent {agent_id} failed: {error}")

    def run_optimization_pipeline(
        self,
        function_to_optimize: FunctionToOptimize,
        test_cfg: TestConfig,
        args: Namespace,
        function_to_tests: dict[str, set[Any]] | None = None,
    ) -> Result[BestOptimization, str]:
        """Run analysis, generation, verification/selection and integration in order.

        A pipeline state and an in-progress optimization attempt are persisted
        first. A failure in analysis, generation or verification/selection
        (or an empty candidate list) marks the attempt failed and returns a
        ``Failure``. An integration failure is only logged as a warning.

        Returns:
            ``Success(best_optimization)`` once the completed attempt has been
            persisted, otherwise ``Failure(error)``. Exceptions raised inside
            the stages are logged and converted to ``Failure``.
        """
        if not self._initialized:
            self.register_agents()
        pipeline_id = str(uuid.uuid4())
        trace_id = str(uuid.uuid4())
        pipeline_state = PipelineState.create(
            pipeline_id=pipeline_id,
            function_qualified_name=function_to_optimize.qualified_name,
            file_path=function_to_optimize.file_path,
        )
        self.state_store.persist_pipeline_state(pipeline_state)
        attempt = OptimizationAttempt.create(
            attempt_id=pipeline_id,
            function_qualified_name=function_to_optimize.qualified_name,
            file_path=function_to_optimize.file_path,
        )
        self.state_store.persist_optimization_attempt(attempt.mark_in_progress())
        logger.info(f"Starting agentic optimization pipeline for {function_to_optimize.qualified_name}")
        console.rule()
        try:
            context_result = self._run_analysis_stage(function_to_optimize, args.project_root)
            if context_result.is_failure():
                return self._handle_pipeline_failure(attempt, context_result.failure())
            code_context = context_result.unwrap()

            candidates_result = self._run_generation_stage(code_context, trace_id, function_to_optimize)
            if candidates_result.is_failure():
                return self._handle_pipeline_failure(attempt, candidates_result.failure())
            candidates = candidates_result.unwrap()
            if not candidates:
                return self._handle_pipeline_failure(attempt, "No optimization candidates generated")

            best_result = self._run_verification_and_selection_stage(
                candidates=candidates,
                code_context=code_context,
                function_to_optimize=function_to_optimize,
                test_cfg=test_cfg,
                args=args,
                function_to_tests=function_to_tests,
                trace_id=trace_id,
            )
            if best_result.is_failure():
                return self._handle_pipeline_failure(attempt, best_result.failure())
            best_optimization = best_result.unwrap()

            integration_result = self._run_integration_stage(
                best_optimization=best_optimization,
                function_to_optimize=function_to_optimize,
                args=args,
            )
            if integration_result.is_failure():
                logger.warning(f"Integration stage failed: {integration_result.failure()}")

            # The helper always yields an original runtime, so recording the
            # attempt cannot fail on an unbound name when the optimized
            # runtime is not positive.
            speedup, original_runtime = self._compute_speedup(best_optimization)
            completed_attempt = attempt.mark_completed(
                speedup=speedup,
                original_runtime_ns=original_runtime,
                optimized_runtime_ns=best_optimization.runtime,
                pr_url=integration_result.unwrap().get("pr_url") if integration_result.is_successful() else None,
            )
            self.state_store.persist_optimization_attempt(completed_attempt)
            logger.info(f"Agentic optimization complete for {function_to_optimize.qualified_name}")
            return Success(best_optimization)
        except Exception as e:
            logger.exception(f"Pipeline failed: {e}")
            return self._handle_pipeline_failure(attempt, str(e))

    @staticmethod
    def _compute_speedup(best_optimization: BestOptimization) -> tuple[float, int]:
        """Return ``(speedup, original_runtime)`` for *best_optimization*.

        The original runtime is read from
        ``winning_benchmarking_test_results.get_best_runtime()`` only when the
        optimized runtime is positive; a missing value counts as 0. The
        speedup is ``original_runtime / runtime`` when both are positive and
        1.0 otherwise.
        """
        speedup = 1.0
        original_runtime = 0
        if best_optimization.runtime > 0:
            original_runtime = best_optimization.winning_benchmarking_test_results.get_best_runtime() or 0
            if original_runtime > 0:
                speedup = original_runtime / best_optimization.runtime
        return speedup, original_runtime

    def _run_analysis_stage(
        self,
        function_to_optimize: FunctionToOptimize,
        project_root: Path,
    ) -> Result[CodeOptimizationContext, str]:
        """Run context extraction on the analysis agent and return the code context."""
        from codeflash.agents.analysis_agent import create_context_extraction_task

        task = create_context_extraction_task(function_to_optimize, project_root)
        result = self.execute_task("analysis", task)
        if result.is_failure():
            return Failure(result.failure())
        agent_result = result.unwrap()
        if not agent_result.success:
            return Failure(agent_result.error or "Analysis failed")
        return Success(agent_result.data["code_context"])

    def _run_generation_stage(
        self,
        code_context: CodeOptimizationContext,
        trace_id: str,
        function_to_optimize: FunctionToOptimize,
    ) -> Result[list[Any], str]:
        """Run candidate generation on the generation agent and return the candidates.

        The language defaults to ``"python"`` when the function declares none.
        """
        from codeflash.agents.generation_agent import create_generation_task

        language = function_to_optimize.language or "python"
        is_async = function_to_optimize.is_async
        task = create_generation_task(
            code_context=code_context,
            trace_id=trace_id,
            is_async=is_async,
            language=language,
        )
        result = self.execute_task("generation", task)
        if result.is_failure():
            return Failure(result.failure())
        agent_result = result.unwrap()
        if not agent_result.success:
            return Failure(agent_result.error or "Generation failed")
        return Success(agent_result.data["candidates"])

    def _run_verification_and_selection_stage(
        self,
        candidates: list[Any],
        code_context: CodeOptimizationContext,
        function_to_optimize: FunctionToOptimize,
        test_cfg: TestConfig,
        args: Namespace,
        function_to_tests: dict[str, set[Any]] | None,
        trace_id: str,
    ) -> Result[BestOptimization, str]:
        """Delegate verification and selection to ``FunctionOptimizer.optimize_function``.

        The optimizer's executor is shut down and its generated files cleaned
        up whether or not optimization succeeds. ``candidates`` and
        ``trace_id`` are accepted but not passed to the optimizer here.
        """
        from codeflash.optimization.function_optimizer import FunctionOptimizer

        code_strings = code_context.read_writable_code.code_strings
        # The first read-writable code string is used as the function's
        # source; an empty string when there is none.
        source_code = code_strings[0].code if code_strings else ""
        aiservice_client = self.agents["generation"].aiservice_client if "generation" in self.agents else None
        function_optimizer = FunctionOptimizer(
            function_to_optimize=function_to_optimize,
            test_cfg=test_cfg,
            function_to_optimize_source_code=source_code,
            function_to_tests=function_to_tests,
            function_to_optimize_ast=None,
            aiservice_client=aiservice_client,
            args=args,
        )
        try:
            return function_optimizer.optimize_function()
        finally:
            function_optimizer.executor.shutdown(wait=True)
            function_optimizer.cleanup_generated_files()

    def _run_integration_stage(
        self,
        best_optimization: BestOptimization,
        function_to_optimize: FunctionToOptimize,
        args: Namespace,
    ) -> Result[dict[str, Any], str]:
        """Create a PR through the integration agent unless ``args.no_pr`` is set.

        Returns:
            The integration agent's result data, or a ``pr_created: False``
            marker when PR creation is disabled.
        """
        from codeflash.agents.integration_agent import create_pr_task

        if getattr(args, "no_pr", False):
            return Success({"pr_created": False, "reason": "PR creation disabled"})
        task = create_pr_task(
            best_optimization=best_optimization,
            function_to_optimize=function_to_optimize,
            args=args,
        )
        result = self.execute_task("integration", task)
        if result.is_failure():
            return Failure(result.failure())
        agent_result = result.unwrap()
        if not agent_result.success:
            return Failure(agent_result.error or "Integration failed")
        return Success(agent_result.data)

    def _handle_pipeline_failure(
        self,
        attempt: OptimizationAttempt,
        error: str,
    ) -> Result[BestOptimization, str]:
        """Persist *attempt* as failed with *error* and return ``Failure(error)``."""
        failed_attempt = attempt.mark_failed(error)
        self.state_store.persist_optimization_attempt(failed_attempt)
        return Failure(error)

    def reset_all_agents(self) -> None:
        """Reset every registered agent and discard all queued tasks."""
        for agent in self.agents.values():
            agent.reset()
        while not self.task_queue.empty():
            self.task_queue.get_nowait()

View file

@ -0,0 +1,187 @@
from __future__ import annotations
from pathlib import Path
from typing import TYPE_CHECKING, Any
from codeflash.agents.base import AgentResult, AgentTask, BaseAgent, wrap_error, wrap_result
from codeflash.cli_cmds.console import logger
from codeflash.either import Result
if TYPE_CHECKING:
from codeflash.agents.coordinator import AgentCoordinator
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.state.store import StateStore
from codeflash.verification.verification_utils import TestConfig
class DiscoveryAgent(BaseAgent):
    """Agent that finds functions to optimize and the tests that cover them.

    Handled task types (``AgentTask.task_type``):

    * ``"discover_functions"`` -- collect candidate functions.
    * ``"discover_tests"`` -- collect unit and replay tests for those functions.
    * ``"filter_functions"`` -- drop functions that the optimization history
      says to skip.

    Every handler converts any exception into a failed ``Result`` (the error
    text is ``str(exception)``) and logs the traceback.
    """

    def __init__(
        self,
        coordinator: AgentCoordinator | None = None,
        state_store: StateStore | None = None,
    ) -> None:
        """Register the agent under the fixed id ``"discovery"``.

        Without a *state_store*, history-based filtering is a no-op.
        """
        super().__init__(agent_id="discovery", coordinator=coordinator)
        self.state_store = state_store

    def process(self, task: AgentTask) -> Result[AgentResult[Any], str]:
        """Dispatch *task* to the handler matching its ``task_type``.

        Returns:
            The handler's result, or a failure for an unknown task type.
        """
        handlers = {
            "discover_functions": self._discover_functions,
            "discover_tests": self._discover_tests,
            "filter_functions": self._filter_functions,
        }
        handler = handlers.get(task.task_type)
        if handler is None:
            return wrap_error(f"Unknown task type: {task.task_type}", task.task_id, self.agent_id)
        return handler(task)

    def _discover_functions(self, task: AgentTask) -> Result[AgentResult[Any], str]:
        """Call ``get_functions_to_optimize`` with the payload's settings.

        The payload must carry ``"test_cfg"``, ``"project_root"`` and
        ``"module_root"``; the remaining keys are optional. When a state store
        is configured the functions are filtered by history and the count is
        recomputed. On success the result data holds ``"file_to_funcs"``,
        ``"num_functions"`` and ``"trace_file"``.
        """
        payload = task.payload
        try:
            from codeflash.discovery.functions_to_optimize import get_functions_to_optimize

            test_cfg: TestConfig = payload["test_cfg"]
            optimize_all = payload.get("optimize_all")
            replay_test = payload.get("replay_test")
            file_path = payload.get("file")
            function_name = payload.get("function")
            ignore_paths = payload.get("ignore_paths", [])
            project_root = payload["project_root"]
            module_root = payload["module_root"]
            previous_checkpoint_functions = payload.get("previous_checkpoint_functions")
            file_to_funcs, num_functions, trace_file = get_functions_to_optimize(
                optimize_all=optimize_all,
                replay_test=replay_test,
                file=file_path,
                only_get_this_function=function_name,
                test_cfg=test_cfg,
                ignore_paths=ignore_paths,
                project_root=project_root,
                module_root=module_root,
                previous_checkpoint_functions=previous_checkpoint_functions,
            )
            if self.state_store:
                file_to_funcs = self._filter_by_history(file_to_funcs)
                num_functions = sum(len(funcs) for funcs in file_to_funcs.values())
            result_data = {
                "file_to_funcs": file_to_funcs,
                "num_functions": num_functions,
                "trace_file": trace_file,
            }
            logger.info(f"DiscoveryAgent: Found {num_functions} functions to optimize")
            return wrap_result(result_data, task.task_id, self.agent_id)
        except Exception as e:
            logger.exception(f"DiscoveryAgent failed: {e}")
            return wrap_error(str(e), task.task_id, self.agent_id)

    def _discover_tests(self, task: AgentTask) -> Result[AgentResult[Any], str]:
        """Call ``discover_unit_tests`` for ``payload["file_to_funcs"]``.

        On success the result data holds ``"function_to_tests"``,
        ``"num_tests"`` and ``"num_replay_tests"``.
        """
        payload = task.payload
        try:
            from codeflash.discovery.discover_unit_tests import discover_unit_tests

            test_cfg: TestConfig = payload["test_cfg"]
            file_to_funcs: dict[Path, list[FunctionToOptimize]] = payload["file_to_funcs"]
            function_to_tests, num_tests, num_replay_tests = discover_unit_tests(
                test_cfg, file_to_funcs_to_optimize=file_to_funcs
            )
            result_data = {
                "function_to_tests": function_to_tests,
                "num_tests": num_tests,
                "num_replay_tests": num_replay_tests,
            }
            logger.info(f"DiscoveryAgent: Discovered {num_tests} tests and {num_replay_tests} replay tests")
            return wrap_result(result_data, task.task_id, self.agent_id)
        except Exception as e:
            logger.exception(f"DiscoveryAgent test discovery failed: {e}")
            return wrap_error(str(e), task.task_id, self.agent_id)

    def _filter_functions(self, task: AgentTask) -> Result[AgentResult[Any], str]:
        """Filter ``payload["file_to_funcs"]`` by optimization history.

        On success the result data holds the filtered ``"file_to_funcs"`` and
        the matching ``"num_functions"``.
        """
        payload = task.payload
        try:
            file_to_funcs: dict[Path, list[FunctionToOptimize]] = payload["file_to_funcs"]
            filtered = self._filter_by_history(file_to_funcs)
            result_data = {
                "file_to_funcs": filtered,
                "num_functions": sum(len(funcs) for funcs in filtered.values()),
            }
            return wrap_result(result_data, task.task_id, self.agent_id)
        except Exception as e:
            # Log the traceback like the other handlers do; the Failure
            # returned below only carries str(e).
            logger.exception(f"DiscoveryAgent function filtering failed: {e}")
            return wrap_error(str(e), task.task_id, self.agent_id)

    def _filter_by_history(
        self, file_to_funcs: dict[Path, list[FunctionToOptimize]]
    ) -> dict[Path, list[FunctionToOptimize]]:
        """Drop functions that ``OptimizationHistory.should_skip_function`` rejects.

        Returns *file_to_funcs* unchanged when no state store is configured.
        Files left with no functions are omitted from the returned mapping.
        """
        if not self.state_store:
            return file_to_funcs
        from codeflash.state.history import OptimizationHistory

        history = OptimizationHistory(self.state_store)
        filtered: dict[Path, list[FunctionToOptimize]] = {}
        for file_path, functions in file_to_funcs.items():
            kept_functions = []
            for func in functions:
                should_skip, reason = history.should_skip_function(func.qualified_name)
                if should_skip:
                    logger.debug(f"Skipping {func.qualified_name}: {reason}")
                    continue
                kept_functions.append(func)
            if kept_functions:
                filtered[file_path] = kept_functions
        return filtered
def create_discovery_task(
    test_cfg: Any,
    project_root: Path,
    module_root: Path,
    optimize_all: Path | str | None = None,
    replay_test: list[Path] | None = None,
    file: Path | None = None,
    function: str | None = None,
    ignore_paths: list[Path] | None = None,
    previous_checkpoint_functions: dict[str, dict[str, str]] | None = None,
) -> AgentTask:
    """Build the ``"discover_functions"`` task (priority 10) for the discovery agent.

    A missing or empty *ignore_paths* is stored as a new empty list.
    """
    task_payload = {
        "test_cfg": test_cfg,
        "project_root": project_root,
        "module_root": module_root,
        "optimize_all": optimize_all,
        "replay_test": replay_test,
        "file": file,
        "function": function,
        "ignore_paths": ignore_paths if ignore_paths else [],
        "previous_checkpoint_functions": previous_checkpoint_functions,
    }
    return AgentTask.create(task_type="discover_functions", payload=task_payload, priority=10)
def create_test_discovery_task(
    test_cfg: Any,
    file_to_funcs: dict[Path, list[Any]],
) -> AgentTask:
    """Build the ``"discover_tests"`` task (priority 9) for the discovery agent."""
    task_payload = {
        "test_cfg": test_cfg,
        "file_to_funcs": file_to_funcs,
    }
    return AgentTask.create(task_type="discover_tests", payload=task_payload, priority=9)

View file

@ -0,0 +1,197 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any
from codeflash.agents.base import AgentResult, AgentTask, BaseAgent, wrap_error, wrap_result
from codeflash.cli_cmds.console import logger
from codeflash.either import Result
if TYPE_CHECKING:
from codeflash.agents.coordinator import AgentCoordinator
from codeflash.api.aiservice import AiServiceClient
from codeflash.models.models import CodeOptimizationContext
class GenerationAgent(BaseAgent):
    """Agent that requests optimization candidates from the AI service.

    Handled task types (``AgentTask.task_type``): ``"generate_candidates"``,
    ``"refine_candidate"``, ``"repair_candidate"`` and ``"adaptive_optimize"``.
    Every handler logs any exception and returns it as a failed ``Result``.
    """

    def __init__(
        self,
        coordinator: AgentCoordinator | None = None,
        aiservice_client: AiServiceClient | None = None,
    ) -> None:
        """Register the agent under the fixed id ``"generation"``."""
        super().__init__(agent_id="generation", coordinator=coordinator)
        self._aiservice_client = aiservice_client

    @property
    def aiservice_client(self) -> AiServiceClient:
        """The AI service client, created on first access if none was supplied."""
        client = self._aiservice_client
        if client is None:
            from codeflash.api.aiservice import AiServiceClient

            client = AiServiceClient()
            self._aiservice_client = client
        return client

    def process(self, task: AgentTask) -> Result[AgentResult[Any], str]:
        """Dispatch *task* to the handler matching its ``task_type``."""
        kind = task.task_type
        dispatch = {
            "generate_candidates": self._generate_candidates,
            "refine_candidate": self._refine_candidate,
            "repair_candidate": self._repair_candidate,
            "adaptive_optimize": self._adaptive_optimize,
        }
        if kind in dispatch:
            return dispatch[kind](task)
        return wrap_error(f"Unknown task type: {kind}", task.task_id, self.agent_id)

    def _generate_candidates(self, task: AgentTask) -> Result[AgentResult[Any], str]:
        """Ask the AI service for optimization candidates.

        Requires ``"code_context"`` and ``"trace_id"`` in the payload; the
        other settings fall back to defaults. On success the result data
        holds ``"candidates"`` and ``"candidate_count"``.
        """
        options = task.payload
        try:
            code_context: CodeOptimizationContext = options["code_context"]
            trace_id: str = options["trace_id"]
            is_async: bool = options.get("is_async", False)
            n_candidates: int = options.get("n_candidates", 5)
            is_numerical: bool | None = options.get("is_numerical")
            language: str = options.get("language", "python")
            candidates = self.aiservice_client.optimize_code(
                source_code=code_context.read_writable_code.markdown,
                # An absent read-only context is sent as an empty string.
                dependency_code=code_context.read_only_context_code or "",
                trace_id=trace_id,
                language=language,
                is_async=is_async,
                n_candidates=n_candidates,
                is_numerical_code=is_numerical,
            )
            outcome = {
                "candidates": candidates,
                "candidate_count": len(candidates),
            }
            logger.info(f"GenerationAgent: Generated {len(candidates)} optimization candidates")
            return wrap_result(outcome, task.task_id, self.agent_id)
        except Exception as e:
            logger.exception(f"GenerationAgent candidate generation failed: {e}")
            return wrap_error(str(e), task.task_id, self.agent_id)

    def _refine_candidate(self, task: AgentTask) -> Result[AgentResult[Any], str]:
        """Send ``payload["refiner_request"]`` to the AI service's ``refine_code``.

        On success the result data holds ``"refined_candidates"`` and
        ``"original_optimization_id"``.
        """
        options = task.payload
        try:
            from codeflash.models.models import AIServiceRefinerRequest

            refiner_request: AIServiceRefinerRequest = options["refiner_request"]
            trace_id: str = options["trace_id"]
            refined = self.aiservice_client.refine_code(refiner_request, trace_id)
            outcome = {
                "refined_candidates": refined,
                "original_optimization_id": refiner_request.optimization_id,
            }
            logger.debug(f"GenerationAgent: Refined candidate {refiner_request.optimization_id}")
            return wrap_result(outcome, task.task_id, self.agent_id)
        except Exception as e:
            logger.exception(f"GenerationAgent refinement failed: {e}")
            return wrap_error(str(e), task.task_id, self.agent_id)

    def _repair_candidate(self, task: AgentTask) -> Result[AgentResult[Any], str]:
        """Send ``payload["repair_request"]`` to the AI service's ``repair_code``.

        On success the result data holds ``"repaired_candidate"`` and
        ``"original_optimization_id"``.
        """
        options = task.payload
        try:
            from codeflash.models.models import AIServiceCodeRepairRequest

            repair_request: AIServiceCodeRepairRequest = options["repair_request"]
            trace_id: str = options["trace_id"]
            repaired = self.aiservice_client.repair_code(repair_request, trace_id)
            outcome = {
                "repaired_candidate": repaired,
                "original_optimization_id": repair_request.optimization_id,
            }
            logger.debug(f"GenerationAgent: Repaired candidate {repair_request.optimization_id}")
            return wrap_result(outcome, task.task_id, self.agent_id)
        except Exception as e:
            logger.exception(f"GenerationAgent repair failed: {e}")
            return wrap_error(str(e), task.task_id, self.agent_id)

    def _adaptive_optimize(self, task: AgentTask) -> Result[AgentResult[Any], str]:
        """Send ``payload["adaptive_request"]`` to the AI service's ``adaptive_optimize``.

        On success the result data holds ``"adaptive_candidates"``.
        """
        options = task.payload
        try:
            from codeflash.models.models import AIServiceAdaptiveOptimizeRequest

            adaptive_request: AIServiceAdaptiveOptimizeRequest = options["adaptive_request"]
            trace_id: str = options["trace_id"]
            outcome = {
                "adaptive_candidates": self.aiservice_client.adaptive_optimize(adaptive_request, trace_id),
            }
            logger.debug("GenerationAgent: Generated adaptive optimization candidates")
            return wrap_result(outcome, task.task_id, self.agent_id)
        except Exception as e:
            logger.exception(f"GenerationAgent adaptive optimization failed: {e}")
            return wrap_error(str(e), task.task_id, self.agent_id)
def create_generation_task(
    code_context: Any,
    trace_id: str,
    is_async: bool = False,
    n_candidates: int = 5,
    is_numerical: bool | None = None,
    language: str = "python",
) -> AgentTask:
    """Build the ``"generate_candidates"`` task (priority 7) for the generation agent."""
    task_payload = {
        "code_context": code_context,
        "trace_id": trace_id,
        "is_async": is_async,
        "n_candidates": n_candidates,
        "is_numerical": is_numerical,
        "language": language,
    }
    return AgentTask.create(task_type="generate_candidates", payload=task_payload, priority=7)
def create_refinement_task(refiner_request: Any, trace_id: str) -> AgentTask:
    """Build the ``"refine_candidate"`` task (priority 5) for the generation agent."""
    task_payload = {
        "refiner_request": refiner_request,
        "trace_id": trace_id,
    }
    return AgentTask.create(task_type="refine_candidate", payload=task_payload, priority=5)
def create_repair_task(repair_request: Any, trace_id: str) -> AgentTask:
    """Build the ``"repair_candidate"`` task (priority 5) for the generation agent."""
    task_payload = {
        "repair_request": repair_request,
        "trace_id": trace_id,
    }
    return AgentTask.create(task_type="repair_candidate", payload=task_payload, priority=5)
def create_adaptive_optimization_task(adaptive_request: Any, trace_id: str) -> AgentTask:
    """Build the ``"adaptive_optimize"`` task (priority 4) for the generation agent."""
    task_payload = {
        "adaptive_request": adaptive_request,
        "trace_id": trace_id,
    }
    return AgentTask.create(task_type="adaptive_optimize", payload=task_payload, priority=4)

View file

@ -0,0 +1,158 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any
from codeflash.agents.base import AgentResult, AgentTask, BaseAgent, wrap_error, wrap_result
from codeflash.cli_cmds.console import logger
from codeflash.either import Result
if TYPE_CHECKING:
from argparse import Namespace
from codeflash.agents.coordinator import AgentCoordinator
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import BestOptimization
class IntegrationAgent(BaseAgent):
    """Agent handling the ``create_pr``, ``apply_optimization`` and ``generate_explanation`` tasks.

    Every handler reads its inputs from ``task.payload``, returns
    ``wrap_result(...)`` on success, and converts any exception into
    ``wrap_error(str(exc), ...)`` after logging it.
    """

    def __init__(self, coordinator: AgentCoordinator | None = None) -> None:
        """Register the agent under the fixed id ``"integration"``."""
        super().__init__(agent_id="integration", coordinator=coordinator)

    def process(self, task: AgentTask) -> Result[AgentResult[Any], str]:
        """Route ``task`` to the handler matching its ``task_type``.

        Returns:
            The handler's result, or a ``wrap_error`` result when the task
            type is not one this agent supports.
        """
        task_type = task.task_type
        handlers = {
            "create_pr": self._create_pr,
            "apply_optimization": self._apply_optimization,
            "generate_explanation": self._generate_explanation,
        }
        handler = handlers.get(task_type)
        if handler is None:
            return wrap_error(f"Unknown task type: {task_type}", task.task_id, self.agent_id)
        return handler(task)

    def _create_pr(self, task: AgentTask) -> Result[AgentResult[Any], str]:
        """Call ``check_create_pr`` for the optimization carried by the task.

        Required payload keys: ``best_optimization``, ``function_to_optimize``
        and ``args``. The result data holds the raw ``pr_result`` and the
        function's qualified name.
        """
        payload = task.payload
        try:
            # Imported lazily, as in the rest of this module's handlers.
            from codeflash.result.create_pr import check_create_pr

            best_optimization: BestOptimization = payload["best_optimization"]
            function_to_optimize: FunctionToOptimize = payload["function_to_optimize"]
            args: Namespace = payload["args"]

            pr_result = check_create_pr(
                best_optimization=best_optimization, function_to_optimize=function_to_optimize, args=args
            )
            target_name = function_to_optimize.qualified_name
            outcome = {"pr_result": pr_result, "function_name": target_name}

            # A falsy pr_result is reported as "skipped", not as a failure.
            if not pr_result:
                logger.debug(f"IntegrationAgent: PR creation skipped for {target_name}")
            else:
                logger.info(f"IntegrationAgent: Created PR for {target_name}")
            return wrap_result(outcome, task.task_id, self.agent_id)
        except Exception as e:
            logger.exception(f"IntegrationAgent PR creation failed: {e}")
            return wrap_error(str(e), task.task_id, self.agent_id)

    def _apply_optimization(self, task: AgentTask) -> Result[AgentResult[Any], str]:
        """Write each code string of the optimization's read-writable code to its file.

        Required payload keys: ``best_optimization`` and
        ``function_to_optimize``. The result data lists the written paths
        under ``files_modified``.
        """
        payload = task.payload
        try:
            best_optimization: BestOptimization = payload["best_optimization"]
            function_to_optimize: FunctionToOptimize = payload["function_to_optimize"]
            code_strings = best_optimization.code_context.read_writable_code.code_strings

            # NOTE(review): files are written one at a time; if a later write
            # raises, earlier files stay modified. Confirm that is acceptable.
            touched_paths = []
            for entry in code_strings:
                target_path = entry.file_path
                target_path.write_text(entry.code, encoding="utf-8")
                touched_paths.append(target_path)

            outcome = {"files_modified": touched_paths, "function_name": function_to_optimize.qualified_name}
            logger.info(f"IntegrationAgent: Applied optimization to {len(touched_paths)} files")
            return wrap_result(outcome, task.task_id, self.agent_id)
        except Exception as e:
            logger.exception(f"IntegrationAgent optimization application failed: {e}")
            return wrap_error(str(e), task.task_id, self.agent_id)

    def _generate_explanation(self, task: AgentTask) -> Result[AgentResult[Any], str]:
        """Construct an ``Explanation`` for the optimization carried by the task.

        Required payload keys: ``best_optimization``, ``function_to_optimize``
        and ``original_runtime``. The result data holds the explanation object
        and its ``summary`` attribute (``None`` when it has none).
        """
        payload = task.payload
        try:
            from codeflash.result.explanation import Explanation

            best_optimization: BestOptimization = payload["best_optimization"]
            function_to_optimize: FunctionToOptimize = payload["function_to_optimize"]
            original_runtime: int = payload["original_runtime"]

            explanation = Explanation(
                best_optimization=best_optimization,
                function_to_optimize=function_to_optimize,
                original_runtime=original_runtime,
            )
            outcome = {"explanation": explanation, "summary": getattr(explanation, "summary", None)}
            return wrap_result(outcome, task.task_id, self.agent_id)
        except Exception as e:
            logger.exception(f"IntegrationAgent explanation generation failed: {e}")
            return wrap_error(str(e), task.task_id, self.agent_id)
def create_pr_task(
    best_optimization: Any,
    function_to_optimize: Any,
    args: Any,
) -> AgentTask:
    """Build a ``create_pr`` task.

    Args:
        best_optimization: Stored under the ``"best_optimization"`` payload key.
        function_to_optimize: Stored under the ``"function_to_optimize"`` payload key.
        args: Stored under the ``"args"`` payload key.

    Returns:
        The task produced by ``AgentTask.create`` with priority 3.
    """
    task_payload = {
        "best_optimization": best_optimization,
        "function_to_optimize": function_to_optimize,
        "args": args,
    }
    return AgentTask.create(task_type="create_pr", payload=task_payload, priority=3)
def create_apply_optimization_task(
    best_optimization: Any,
    function_to_optimize: Any,
) -> AgentTask:
    """Build an ``apply_optimization`` task.

    Args:
        best_optimization: Stored under the ``"best_optimization"`` payload key.
        function_to_optimize: Stored under the ``"function_to_optimize"`` payload key.

    Returns:
        The task produced by ``AgentTask.create`` with priority 3.
    """
    task_payload = {
        "best_optimization": best_optimization,
        "function_to_optimize": function_to_optimize,
    }
    return AgentTask.create(task_type="apply_optimization", payload=task_payload, priority=3)
def create_explanation_task(
    best_optimization: Any,
    function_to_optimize: Any,
    original_runtime: int,
) -> AgentTask:
    """Build a ``generate_explanation`` task.

    Args:
        best_optimization: Stored under the ``"best_optimization"`` payload key.
        function_to_optimize: Stored under the ``"function_to_optimize"`` payload key.
        original_runtime: Stored under the ``"original_runtime"`` payload key.

    Returns:
        The task produced by ``AgentTask.create`` with priority 2.
    """
    task_payload = {
        "best_optimization": best_optimization,
        "function_to_optimize": function_to_optimize,
        "original_runtime": original_runtime,
    }
    return AgentTask.create(task_type="generate_explanation", payload=task_payload, priority=2)

View file

@ -0,0 +1,207 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any
from codeflash.agents.base import AgentResult, AgentTask, BaseAgent, wrap_error, wrap_result
from codeflash.cli_cmds.console import logger
from codeflash.either import Result
if TYPE_CHECKING:
from codeflash.agents.coordinator import AgentCoordinator
from codeflash.models.models import ConcurrencyMetrics, OptimizedCandidateResult
class SelectionAgent(BaseAgent):
    """Agent handling the ``evaluate_candidate``, ``select_best`` and ``rank_candidates`` tasks.

    Every handler reads its inputs from ``task.payload``, returns
    ``wrap_result(...)`` on success, and converts any exception into
    ``wrap_error(str(exc), ...)`` after logging it.
    """

    def __init__(self, coordinator: AgentCoordinator | None = None) -> None:
        """Register the agent under the fixed id ``"selection"``."""
        super().__init__(agent_id="selection", coordinator=coordinator)

    def process(self, task: AgentTask) -> Result[AgentResult[Any], str]:
        """Route ``task`` to the handler matching its ``task_type``.

        Returns:
            The handler's result, or a ``wrap_error`` result when the task
            type is not one this agent supports.
        """
        task_type = task.task_type
        if task_type == "evaluate_candidate":
            return self._evaluate_candidate(task)
        if task_type == "select_best":
            return self._select_best(task)
        if task_type == "rank_candidates":
            return self._rank_candidates(task)
        return wrap_error(f"Unknown task type: {task_type}", task.task_id, self.agent_id)

    def _evaluate_candidate(self, task: AgentTask) -> Result[AgentResult[Any], str]:
        """Run the quantity, speedup and coverage critics on one candidate result.

        Required payload keys: ``candidate_result`` and ``original_runtime``.
        The remaining keys are optional and default to ``None``. A candidate
        is acceptable only when all three checks pass.
        """
        payload = task.payload
        try:
            from codeflash.result.critic import coverage_critic, quantity_of_tests_critic, speedup_critic

            candidate_result: OptimizedCandidateResult = payload["candidate_result"]
            original_runtime: int = payload["original_runtime"]
            best_runtime_until_now: int | None = payload.get("best_runtime_until_now")
            original_async_throughput: int | None = payload.get("original_async_throughput")
            best_throughput_until_now: int | None = payload.get("best_throughput_until_now")
            original_concurrency_metrics: ConcurrencyMetrics | None = payload.get("original_concurrency_metrics")
            best_concurrency_ratio_until_now: float | None = payload.get("best_concurrency_ratio_until_now")
            original_coverage = payload.get("original_coverage")

            passes_quantity_check = quantity_of_tests_critic(candidate_result)
            passes_speedup_check = speedup_critic(
                candidate_result,
                original_runtime,
                best_runtime_until_now,
                original_async_throughput=original_async_throughput,
                best_throughput_until_now=best_throughput_until_now,
                original_concurrency_metrics=original_concurrency_metrics,
                best_concurrency_ratio_until_now=best_concurrency_ratio_until_now,
            )
            # Without coverage data the coverage check is treated as passed.
            passes_coverage_check = coverage_critic(original_coverage) if original_coverage else True
            is_acceptable = passes_quantity_check and passes_speedup_check and passes_coverage_check

            result_data = {
                "candidate_result": candidate_result,
                "is_acceptable": is_acceptable,
                "passes_quantity_check": passes_quantity_check,
                "passes_speedup_check": passes_speedup_check,
                "passes_coverage_check": passes_coverage_check,
            }
            logger.debug(
                f"SelectionAgent: Evaluated candidate, acceptable={is_acceptable} "
                f"(qty={passes_quantity_check}, speed={passes_speedup_check}, cov={passes_coverage_check})"
            )
            return wrap_result(result_data, task.task_id, self.agent_id)
        except Exception as e:
            logger.exception(f"SelectionAgent evaluation failed: {e}")
            return wrap_error(str(e), task.task_id, self.agent_id)

    def _select_best(self, task: AgentTask) -> Result[AgentResult[Any], str]:
        """Pick the acceptable candidate with the lowest ``runtime``.

        Required payload keys: ``candidates`` (a list of dicts) and
        ``original_runtime``. When the list is empty, or no entry has a truthy
        ``is_acceptable``, the result carries ``best_candidate=None`` and a
        ``reason`` string instead of a selection.
        """
        payload = task.payload
        try:
            from codeflash.result.critic import get_acceptance_reason, performance_gain

            candidates: list[dict[str, Any]] = payload["candidates"]
            original_runtime: int = payload["original_runtime"]
            original_async_throughput: int | None = payload.get("original_async_throughput")
            original_concurrency_metrics: ConcurrencyMetrics | None = payload.get("original_concurrency_metrics")

            if not candidates:
                return wrap_result(
                    {"best_candidate": None, "reason": "No candidates provided"}, task.task_id, self.agent_id
                )
            acceptable_candidates = [c for c in candidates if c.get("is_acceptable", False)]
            if not acceptable_candidates:
                return wrap_result(
                    {"best_candidate": None, "reason": "No acceptable candidates"}, task.task_id, self.agent_id
                )

            # Candidates without a runtime compare as infinitely slow.
            best = min(acceptable_candidates, key=lambda c: c.get("runtime", float("inf")))
            best_runtime = best.get("runtime", original_runtime)
            perf_gain = performance_gain(original_runtime_ns=original_runtime, optimized_runtime_ns=best_runtime)
            # Guard the division; a non-positive runtime reports a 0.0 speedup.
            speedup = (original_runtime / best_runtime) if best_runtime > 0 else 0.0
            acceptance_reason = get_acceptance_reason(
                original_runtime_ns=original_runtime,
                optimized_runtime_ns=best_runtime,
                original_async_throughput=original_async_throughput,
                optimized_async_throughput=best.get("async_throughput"),
                original_concurrency_metrics=original_concurrency_metrics,
                optimized_concurrency_metrics=best.get("concurrency_metrics"),
            )

            result_data = {
                "best_candidate": best,
                "performance_gain": perf_gain,
                "speedup": speedup,
                "acceptance_reason": acceptance_reason.value,
            }
            logger.info(f"SelectionAgent: Selected best candidate with {speedup:.2f}x speedup")
            return wrap_result(result_data, task.task_id, self.agent_id)
        except Exception as e:
            logger.exception(f"SelectionAgent selection failed: {e}")
            return wrap_error(str(e), task.task_id, self.agent_id)

    def _rank_candidates(self, task: AgentTask) -> Result[AgentResult[Any], str]:
        """Annotate candidates with gain and speedup, sorted by gain descending.

        Required payload keys: ``candidates`` (a list of dicts) and
        ``original_runtime``. Each returned entry is a copy of the input dict
        extended with ``performance_gain`` and ``speedup``.

        Candidates whose ``runtime`` is missing or ``None`` are not passed to
        ``performance_gain``: they get a gain of ``-inf`` and a speedup of
        ``0.0``, so they always sort after every measured candidate.
        """
        payload = task.payload
        try:
            from codeflash.result.critic import performance_gain

            candidates: list[dict[str, Any]] = payload["candidates"]
            original_runtime: int = payload["original_runtime"]

            ranked = []
            for candidate in candidates:
                runtime = candidate.get("runtime")
                if runtime is None:
                    # Substituting float("inf") here and calling performance_gain
                    # presumably yields NaN (inf / inf), and NaN sort keys make
                    # the resulting order unreliable. Use a well-ordered sentinel.
                    perf = float("-inf")
                    speedup = 0.0
                else:
                    perf = performance_gain(original_runtime_ns=original_runtime, optimized_runtime_ns=runtime)
                    speedup = (original_runtime / runtime) if runtime > 0 else 0.0
                ranked.append({**candidate, "performance_gain": perf, "speedup": speedup})

            # list.sort is stable, so equal gains keep their input order.
            ranked.sort(key=lambda c: c.get("performance_gain", 0), reverse=True)
            result_data = {
                "ranked_candidates": ranked,
            }
            return wrap_result(result_data, task.task_id, self.agent_id)
        except Exception as e:
            logger.exception(f"SelectionAgent ranking failed: {e}")
            return wrap_error(str(e), task.task_id, self.agent_id)
def create_evaluation_task(
    candidate_result: Any,
    original_runtime: int,
    best_runtime_until_now: int | None = None,
    original_async_throughput: int | None = None,
    best_throughput_until_now: int | None = None,
    original_concurrency_metrics: Any | None = None,
    best_concurrency_ratio_until_now: float | None = None,
    original_coverage: Any | None = None,
) -> AgentTask:
    """Build an ``evaluate_candidate`` task.

    Every argument is stored in the payload under a key equal to its
    parameter name; optional arguments left at ``None`` are stored as ``None``.

    Returns:
        The task produced by ``AgentTask.create`` with priority 5.
    """
    task_payload = {
        "candidate_result": candidate_result,
        "original_runtime": original_runtime,
        "best_runtime_until_now": best_runtime_until_now,
        "original_async_throughput": original_async_throughput,
        "best_throughput_until_now": best_throughput_until_now,
        "original_concurrency_metrics": original_concurrency_metrics,
        "best_concurrency_ratio_until_now": best_concurrency_ratio_until_now,
        "original_coverage": original_coverage,
    }
    return AgentTask.create(task_type="evaluate_candidate", payload=task_payload, priority=5)
def create_selection_task(
    candidates: list[dict[str, Any]],
    original_runtime: int,
    original_async_throughput: int | None = None,
    original_concurrency_metrics: Any | None = None,
) -> AgentTask:
    """Build a ``select_best`` task.

    Every argument is stored in the payload under a key equal to its
    parameter name.

    Returns:
        The task produced by ``AgentTask.create`` with priority 4.
    """
    task_payload = {
        "candidates": candidates,
        "original_runtime": original_runtime,
        "original_async_throughput": original_async_throughput,
        "original_concurrency_metrics": original_concurrency_metrics,
    }
    return AgentTask.create(task_type="select_best", payload=task_payload, priority=4)
def create_ranking_task(
    candidates: list[dict[str, Any]],
    original_runtime: int,
) -> AgentTask:
    """Build a ``rank_candidates`` task.

    Args:
        candidates: Stored under the ``"candidates"`` payload key.
        original_runtime: Stored under the ``"original_runtime"`` payload key.

    Returns:
        The task produced by ``AgentTask.create`` with priority 4.
    """
    task_payload = {"candidates": candidates, "original_runtime": original_runtime}
    return AgentTask.create(task_type="rank_candidates", payload=task_payload, priority=4)

View file

@ -0,0 +1,282 @@
from __future__ import annotations
from pathlib import Path
from typing import TYPE_CHECKING, Any
from codeflash.agents.base import AgentResult, AgentTask, BaseAgent, wrap_error, wrap_result
from codeflash.cli_cmds.console import logger
from codeflash.either import Result
if TYPE_CHECKING:
from codeflash.agents.coordinator import AgentCoordinator
from codeflash.models.models import OptimizedCandidate, TestFiles
from codeflash.verification.verification_utils import TestConfig
class VerificationAgent(BaseAgent):
    """Agent handling test generation, test execution and candidate verification tasks.

    Supported task types: ``generate_tests``, ``run_behavioral_tests``,
    ``run_benchmark_tests``, ``establish_baseline`` and ``verify_candidate``.
    Every handler reads its inputs from ``task.payload``, returns
    ``wrap_result(...)`` on success, and converts any exception into
    ``wrap_error(str(exc), ...)`` after logging it.
    """

    def __init__(self, coordinator: AgentCoordinator | None = None) -> None:
        """Register the agent under the fixed id ``"verification"``."""
        super().__init__(agent_id="verification", coordinator=coordinator)

    def process(self, task: AgentTask) -> Result[AgentResult[Any], str]:
        """Route ``task`` to the handler matching its ``task_type``.

        Returns:
            The handler's result, or a ``wrap_error`` result when the task
            type is not one this agent supports.
        """
        task_type = task.task_type
        handlers = {
            "generate_tests": self._generate_tests,
            "run_behavioral_tests": self._run_behavioral_tests,
            "run_benchmark_tests": self._run_benchmark_tests,
            "establish_baseline": self._establish_baseline,
            "verify_candidate": self._verify_candidate,
        }
        handler = handlers.get(task_type)
        if handler is None:
            return wrap_error(f"Unknown task type: {task_type}", task.task_id, self.agent_id)
        return handler(task)

    def _generate_tests(self, task: AgentTask) -> Result[AgentResult[Any], str]:
        """Call ``generate_tests`` and return its output under ``generated_tests``.

        Required payload keys: ``code_context``, ``function_to_optimize``,
        ``test_cfg``, ``project_root`` and ``trace_id``.
        """
        payload = task.payload
        try:
            from codeflash.verification.verifier import generate_tests

            code_context = payload["code_context"]
            target = payload["function_to_optimize"]
            tests = generate_tests(
                code_context=code_context,
                function_to_optimize=target,
                test_cfg=payload["test_cfg"],
                project_root_path=payload["project_root"],
                trace_id=payload["trace_id"],
            )
            logger.debug(f"VerificationAgent: Generated tests for {target.qualified_name}")
            return wrap_result({"generated_tests": tests}, task.task_id, self.agent_id)
        except Exception as e:
            logger.exception(f"VerificationAgent test generation failed: {e}")
            return wrap_error(str(e), task.task_id, self.agent_id)

    def _run_behavioral_tests(self, task: AgentTask) -> Result[AgentResult[Any], str]:
        """Call ``run_behavioral_tests`` and return its output under ``test_output``.

        Required payload keys: ``test_files``, ``test_cfg`` and
        ``project_root``. ``test_timeout`` is optional and defaults to 60.
        """
        payload = task.payload
        try:
            from codeflash.verification.test_runner import run_behavioral_tests

            output = run_behavioral_tests(
                test_files=payload["test_files"],
                test_cfg=payload["test_cfg"],
                project_root_path=payload["project_root"],
                test_timeout=payload.get("test_timeout", 60),
            )
            return wrap_result({"test_output": output}, task.task_id, self.agent_id)
        except Exception as e:
            logger.exception(f"VerificationAgent behavioral tests failed: {e}")
            return wrap_error(str(e), task.task_id, self.agent_id)

    def _run_benchmark_tests(self, task: AgentTask) -> Result[AgentResult[Any], str]:
        """Call ``run_benchmarking_tests`` and return its output under ``benchmark_output``.

        Required payload keys: ``test_files``, ``test_cfg`` and
        ``project_root``. ``test_timeout`` is optional and defaults to 300.
        """
        payload = task.payload
        try:
            from codeflash.verification.test_runner import run_benchmarking_tests

            output = run_benchmarking_tests(
                test_files=payload["test_files"],
                test_cfg=payload["test_cfg"],
                project_root_path=payload["project_root"],
                test_timeout=payload.get("test_timeout", 300),
            )
            return wrap_result({"benchmark_output": output}, task.task_id, self.agent_id)
        except Exception as e:
            logger.exception(f"VerificationAgent benchmark tests failed: {e}")
            return wrap_error(str(e), task.task_id, self.agent_id)

    def _establish_baseline(self, task: AgentTask) -> Result[AgentResult[Any], str]:
        """Run behavioral then benchmark tests and parse both outputs.

        Required payload keys: ``test_files``, ``test_cfg``, ``project_root``
        and ``function_qualified_name``. The result data holds both raw
        outputs and both parsed results. No ``test_timeout`` is forwarded, so
        the runners use their own defaults.
        """
        payload = task.payload
        try:
            from codeflash.verification.parse_test_output import parse_test_results
            from codeflash.verification.test_runner import run_behavioral_tests, run_benchmarking_tests

            test_files: TestFiles = payload["test_files"]
            test_cfg: TestConfig = payload["test_cfg"]
            project_root: Path = payload["project_root"]
            function_qualified_name: str = payload["function_qualified_name"]

            # Both runners receive exactly the same three keyword arguments.
            runner_kwargs = {"test_files": test_files, "test_cfg": test_cfg, "project_root_path": project_root}

            behavior_output = run_behavioral_tests(**runner_kwargs)
            behavior_results = parse_test_results(
                test_output=behavior_output, function_qualified_name=function_qualified_name
            )
            benchmark_output = run_benchmarking_tests(**runner_kwargs)
            benchmark_results = parse_test_results(
                test_output=benchmark_output, function_qualified_name=function_qualified_name
            )

            baseline = {
                "behavior_results": behavior_results,
                "benchmark_results": benchmark_results,
                "behavior_output": behavior_output,
                "benchmark_output": benchmark_output,
            }
            logger.debug(f"VerificationAgent: Established baseline for {function_qualified_name}")
            return wrap_result(baseline, task.task_id, self.agent_id)
        except Exception as e:
            logger.exception(f"VerificationAgent baseline establishment failed: {e}")
            return wrap_error(str(e), task.task_id, self.agent_id)

    def _verify_candidate(self, task: AgentTask) -> Result[AgentResult[Any], str]:
        """Compare fresh behavioral results with the baseline; benchmark only if equivalent.

        Required payload keys: ``candidate``, ``test_files``, ``test_cfg``,
        ``project_root``, ``function_qualified_name`` and
        ``original_behavior_results``. ``benchmark_results`` in the result data
        stays ``None`` when the behavioral comparison reports a difference.
        """
        payload = task.payload
        try:
            from codeflash.verification.equivalence import compare_test_results
            from codeflash.verification.parse_test_output import parse_test_results
            from codeflash.verification.test_runner import run_behavioral_tests, run_benchmarking_tests

            # NOTE(review): the candidate is only echoed back and logged here;
            # presumably its code is applied before this task runs. Confirm.
            candidate: OptimizedCandidate = payload["candidate"]
            test_files: TestFiles = payload["test_files"]
            test_cfg: TestConfig = payload["test_cfg"]
            project_root: Path = payload["project_root"]
            function_qualified_name: str = payload["function_qualified_name"]
            original_behavior_results = payload["original_behavior_results"]

            runner_kwargs = {"test_files": test_files, "test_cfg": test_cfg, "project_root_path": project_root}

            behavior_results = parse_test_results(
                test_output=run_behavioral_tests(**runner_kwargs),
                function_qualified_name=function_qualified_name,
            )
            is_equivalent, diffs = compare_test_results(
                original_results=original_behavior_results, optimized_results=behavior_results
            )

            # Benchmarking is skipped for candidates that changed behavior.
            benchmark_results = None
            if is_equivalent:
                benchmark_results = parse_test_results(
                    test_output=run_benchmarking_tests(**runner_kwargs),
                    function_qualified_name=function_qualified_name,
                )

            verdict = {
                "candidate": candidate,
                "is_equivalent": is_equivalent,
                "diffs": diffs,
                "behavior_results": behavior_results,
                "benchmark_results": benchmark_results,
            }
            logger.debug(
                f"VerificationAgent: Verified candidate {candidate.optimization_id}, equivalent={is_equivalent}"
            )
            return wrap_result(verdict, task.task_id, self.agent_id)
        except Exception as e:
            logger.exception(f"VerificationAgent candidate verification failed: {e}")
            return wrap_error(str(e), task.task_id, self.agent_id)
def create_test_generation_task(
    code_context: Any,
    function_to_optimize: Any,
    test_cfg: Any,
    project_root: Path,
    trace_id: str,
) -> AgentTask:
    """Build a ``generate_tests`` task.

    Every argument is stored in the payload under a key equal to its
    parameter name.

    Returns:
        The task produced by ``AgentTask.create`` with priority 8.
    """
    task_payload = {
        "code_context": code_context,
        "function_to_optimize": function_to_optimize,
        "test_cfg": test_cfg,
        "project_root": project_root,
        "trace_id": trace_id,
    }
    return AgentTask.create(task_type="generate_tests", payload=task_payload, priority=8)
def create_baseline_task(
    test_files: Any,
    test_cfg: Any,
    project_root: Path,
    function_qualified_name: str,
) -> AgentTask:
    """Build an ``establish_baseline`` task.

    Every argument is stored in the payload under a key equal to its
    parameter name.

    Returns:
        The task produced by ``AgentTask.create`` with priority 7.
    """
    task_payload = {
        "test_files": test_files,
        "test_cfg": test_cfg,
        "project_root": project_root,
        "function_qualified_name": function_qualified_name,
    }
    return AgentTask.create(task_type="establish_baseline", payload=task_payload, priority=7)
def create_verification_task(
    candidate: Any,
    test_files: Any,
    test_cfg: Any,
    project_root: Path,
    function_qualified_name: str,
    original_behavior_results: Any,
) -> AgentTask:
    """Build a ``verify_candidate`` task.

    Every argument is stored in the payload under a key equal to its
    parameter name.

    Returns:
        The task produced by ``AgentTask.create`` with priority 6.
    """
    task_payload = {
        "candidate": candidate,
        "test_files": test_files,
        "test_cfg": test_cfg,
        "project_root": project_root,
        "function_qualified_name": function_qualified_name,
        "original_behavior_results": original_behavior_results,
    }
    return AgentTask.create(task_type="verify_candidate", payload=task_payload, priority=6)

View file

@ -130,6 +130,11 @@ def parse_args() -> Namespace:
"--reset-config", action="store_true", help="Remove codeflash configuration from project config file."
)
parser.add_argument("-y", "--yes", action="store_true", help="Skip confirmation prompts (useful for CI/scripts).")
parser.add_argument(
"--agentic",
action="store_true",
help="Use the agentic multi-agent optimization pipeline instead of the default batch pipeline.",
)
args, unknown_args = parser.parse_known_args()
sys.argv[:] = [sys.argv[0], *unknown_args]

View file

@ -59,9 +59,14 @@ def main() -> None:
init_sentry(enabled=not args.disable_telemetry, exclude_errors=True)
posthog_cf.initialize_posthog(enabled=not args.disable_telemetry)
from codeflash.optimization import optimizer
if getattr(args, "agentic", False):
from codeflash.optimization.agentic_optimizer import run_agentic_with_args
optimizer.run_with_args(args)
run_agentic_with_args(args)
else:
from codeflash.optimization import optimizer
optimizer.run_with_args(args)
def _handle_config_loading(args: Namespace) -> Namespace | None:

View file

@ -0,0 +1,203 @@
from __future__ import annotations
import tempfile
import time
from pathlib import Path
from typing import TYPE_CHECKING
from codeflash.agents.coordinator import AgentCoordinator
from codeflash.cli_cmds.console import console, logger
from codeflash.code_utils import env_utils
from codeflash.code_utils.code_utils import cleanup_paths
from codeflash.either import is_successful
from codeflash.languages import is_javascript, set_current_language
from codeflash.state.store import StateStore
from codeflash.telemetry.posthog_cf import ph
from codeflash.verification.verification_utils import TestConfig
if TYPE_CHECKING:
from argparse import Namespace
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
class AgenticOptimizer:
    """Driver for one agentic optimization run.

    Holds the parsed CLI arguments, a ``StateStore``, an ``AgentCoordinator``
    built on that store, and the ``TestConfig`` used for every function.
    """

    def __init__(self, args: Namespace) -> None:
        """Create the state store, the coordinator and the test configuration from ``args``."""
        self.args = args
        self.state_store = StateStore()
        self.coordinator = AgentCoordinator(state_store=self.state_store)
        self.test_cfg = TestConfig(
            tests_root=args.tests_root,
            tests_project_rootdir=args.test_project_root,
            project_root_path=args.project_root,
            # Falls back to plain "pytest" when the option is absent or empty.
            pytest_cmd=args.pytest_cmd if hasattr(args, "pytest_cmd") and args.pytest_cmd else "pytest",
            # Set only when both "benchmark" and "benchmarks_root" exist on the namespace.
            benchmark_tests_root=args.benchmarks_root if "benchmark" in args and "benchmarks_root" in args else None,
        )

    def run(self) -> None:
        """Discover functions, rank them, and run the coordinator pipeline on each.

        Returns early when ``ensure_codeflash_api_key`` fails or when no
        functions are found. The temporary concolic test directory is removed
        in the ``finally`` clause regardless of how the loop ends.
        """
        ph("cli-agentic-optimize-run-start")
        logger.info("Running agentic optimizer.")
        console.rule()
        if not env_utils.ensure_codeflash_api_key():
            return
        self.coordinator.register_agents()
        file_to_funcs, num_functions, trace_file_path = self._get_optimizable_functions()
        if file_to_funcs:
            # Configure the language from the first file whose first function declares one.
            # NOTE(review): nesting of the `break` was reconstructed from flattened
            # source; confirm it belongs inside the language check.
            for file_path, funcs in file_to_funcs.items():
                if funcs and funcs[0].language:
                    set_current_language(funcs[0].language)
                    self.test_cfg.set_language(funcs[0].language)
                    if is_javascript():
                        self.test_cfg.js_project_root = self._find_js_project_root(file_path)
                    break
        if num_functions == 0:
            logger.info("No functions found to optimize. Exiting…")
            return
        function_to_tests = self._discover_tests(file_to_funcs)
        # Temporary directory created under tests_root; cleaned up in `finally` below.
        self.test_cfg.concolic_test_root_dir = Path(
            tempfile.mkdtemp(dir=self.args.tests_root, prefix="codeflash_concolic_")
        )
        try:
            optimizations_found = 0
            globally_ranked = self._rank_functions(file_to_funcs, trace_file_path)
            for i, (file_path, function_to_optimize) in enumerate(globally_ranked):
                logger.info(
                    f"Optimizing function {i + 1} of {len(globally_ranked)}: "
                    f"{function_to_optimize.qualified_name} (in {file_path.name})"
                )
                console.rule()
                result = self.coordinator.run_optimization_pipeline(
                    function_to_optimize=function_to_optimize,
                    test_cfg=self.test_cfg,
                    args=self.args,
                    function_to_tests=function_to_tests,
                )
                # A failed pipeline is logged and the loop moves on to the next function.
                if is_successful(result):
                    optimizations_found += 1
                    logger.info(f"Successfully optimized {function_to_optimize.qualified_name}")
                else:
                    logger.warning(f"Failed to optimize {function_to_optimize.qualified_name}: {result.failure()}")
                # NOTE(review): placement inside the loop is reconstructed; confirm.
                console.rule()
            ph("cli-agentic-optimize-run-finished", {"optimizations_found": optimizations_found})
            if optimizations_found == 0:
                logger.info("No optimizations found.")
            else:
                logger.info(f"Found {optimizations_found} optimization(s).")
        finally:
            cleanup_paths([self.test_cfg.concolic_test_root_dir])

    def _get_optimizable_functions(self) -> tuple[dict[Path, list[FunctionToOptimize]], int, Path | None]:
        """Forward the relevant CLI arguments to ``get_functions_to_optimize``.

        Returns:
            The tuple returned by ``get_functions_to_optimize``; ``run`` unpacks
            it as (functions per file, function count, trace file path).
        """
        from codeflash.discovery.functions_to_optimize import get_functions_to_optimize
        return get_functions_to_optimize(
            optimize_all=self.args.all,
            replay_test=self.args.replay_test,
            file=self.args.file,
            only_get_this_function=self.args.function,
            test_cfg=self.test_cfg,
            ignore_paths=self.args.ignore_paths,
            project_root=self.args.project_root,
            module_root=self.args.module_root,
            previous_checkpoint_functions=self.args.previous_checkpoint_functions,
        )

    def _discover_tests(
        self, file_to_funcs: dict[Path, list[FunctionToOptimize]]
    ) -> dict[str, set]:
        """Run ``discover_unit_tests`` and log the counts and the elapsed time.

        Returns:
            The first element returned by ``discover_unit_tests``.
        """
        from codeflash.discovery.discover_unit_tests import discover_unit_tests
        console.rule()
        start_time = time.time()
        logger.info("Discovering existing function tests...")
        function_to_tests, num_tests, num_replay_tests = discover_unit_tests(
            self.test_cfg, file_to_funcs_to_optimize=file_to_funcs
        )
        console.rule()
        logger.info(
            f"Discovered {num_tests} existing tests and {num_replay_tests} replay tests "
            f"in {(time.time() - start_time):.1f}s"
        )
        console.rule()
        return function_to_tests

    def _rank_functions(
        self,
        file_to_funcs: dict[Path, list[FunctionToOptimize]],
        trace_file_path: Path | None,
    ) -> list[tuple[Path, FunctionToOptimize]]:
        """Return ``(file_path, function)`` pairs, ordered by ``FunctionRanker`` when possible.

        Without an existing trace file, or when ranking raises, the pairs are
        returned in ``file_to_funcs`` iteration order.
        """
        all_functions: list[tuple[Path, FunctionToOptimize]] = []
        for file_path, functions in file_to_funcs.items():
            all_functions.extend((file_path, func) for func in functions)
        if not trace_file_path or not trace_file_path.exists():
            return all_functions
        try:
            from codeflash.benchmarking.function_ranker import FunctionRanker
            logger.info("Ranking functions by performance impact...")
            ranker = FunctionRanker(trace_file_path)
            functions_only = [func for _, func in all_functions]
            ranked_functions = ranker.rank_functions(functions_only)
            # Map each function back to the file it was discovered under.
            func_to_file_map = {}
            for file_path, func in all_functions:
                key = (func.file_path, func.qualified_name, func.starting_line)
                func_to_file_map[key] = file_path
            globally_ranked = []
            for func in ranked_functions:
                key = (func.file_path, func.qualified_name, func.starting_line)
                file_path = func_to_file_map.get(key)
                # Ranked functions with no matching key are dropped.
                if file_path:
                    globally_ranked.append((file_path, func))
            logger.info(f"Ranked {len(ranked_functions)} functions by addressable time")
            return globally_ranked
        except Exception as e:
            # Ranking is best-effort: fall back to the unranked list.
            logger.warning(f"Could not perform ranking: {e}")
            return all_functions

    def _find_js_project_root(self, file_path: Path) -> Path | None:
        """Walk upward from ``file_path`` to the nearest directory holding a JS project marker.

        The markers are ``package.json``, ``jest.config.js``, ``jest.config.ts``
        and ``tsconfig.json``.

        Returns:
            The first matching directory, or ``None`` when the walk reaches the
            filesystem root (which is itself not checked) without a match.
        """
        current = file_path.parent if file_path.is_file() else file_path
        # The loop stops once `current` is its own parent, i.e. the filesystem root.
        while current != current.parent:
            if (
                (current / "package.json").exists()
                or (current / "jest.config.js").exists()
                or (current / "jest.config.ts").exists()
                or (current / "tsconfig.json").exists()
            ):
                return current
            current = current.parent
        return None
def run_agentic_with_args(args: Namespace) -> None:
    """Build an ``AgenticOptimizer`` from ``args`` and run it.

    Raises:
        SystemExit: When a ``KeyboardInterrupt`` is received during
            construction or the run; the original traceback is suppressed.
    """
    try:
        AgenticOptimizer(args).run()
    except KeyboardInterrupt:
        logger.warning("Keyboard interrupt received. Cleaning up and exiting...")
        raise SystemExit from None

View file

@ -0,0 +1,13 @@
from __future__ import annotations
from codeflash.state.history import OptimizationHistory
from codeflash.state.models import AgentStateSnapshot, OptimizationAttempt, OptimizationStatus
from codeflash.state.store import StateStore
__all__ = [
"AgentStateSnapshot",
"OptimizationAttempt",
"OptimizationHistory",
"OptimizationStatus",
"StateStore",
]

View file

@ -0,0 +1,67 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from codeflash.state.models import OptimizationAttempt, OptimizationStatus
if TYPE_CHECKING:
from codeflash.state.store import StateStore
class OptimizationHistory:
    """Query and bookkeeping helpers layered over a ``StateStore``."""

    def __init__(self, store: StateStore) -> None:
        """Keep a reference to the store that every method delegates to."""
        self.store = store

    def record_attempt(self, attempt: OptimizationAttempt) -> None:
        """Persist ``attempt`` through the store."""
        self.store.persist_optimization_attempt(attempt)

    def get_function_attempts(self, qualified_name: str, limit: int = 100) -> list[OptimizationAttempt]:
        """Return the store's history for ``qualified_name``, passing ``limit`` through."""
        return self.store.get_function_history(qualified_name, limit=limit)

    def _attempts_with_status(
        self, status: OptimizationStatus, qualified_name: str | None, limit: int
    ) -> list[OptimizationAttempt]:
        """Return attempts having ``status``, per function when a name is given.

        With a name, ``limit`` applies to the fetched history before filtering,
        so fewer than ``limit`` matches may be returned.
        """
        if not qualified_name:
            return self.store.get_recent_attempts(status=status, limit=limit)
        history = self.store.get_function_history(qualified_name, limit=limit)
        return [attempt for attempt in history if attempt.status == status]

    def get_successful_optimizations(
        self, qualified_name: str | None = None, limit: int = 100
    ) -> list[OptimizationAttempt]:
        """Return attempts whose status is ``COMPLETED``."""
        return self._attempts_with_status(OptimizationStatus.COMPLETED, qualified_name, limit)

    def get_failed_optimizations(self, qualified_name: str | None = None, limit: int = 100) -> list[OptimizationAttempt]:
        """Return attempts whose status is ``FAILED``."""
        return self._attempts_with_status(OptimizationStatus.FAILED, qualified_name, limit)

    def should_skip_function(self, qualified_name: str, code_hash: str | None = None) -> tuple[bool, str | None]:
        """Decide whether optimizing ``qualified_name`` should be skipped.

        Returns:
            ``(True, reason)`` when the store reports the same code hash as
            recently optimized, or when at least 3 of the last 5 attempts are
            completed (checked first) or failed; otherwise ``(False, None)``.
        """
        if code_hash and self.store.was_function_recently_optimized(qualified_name, code_hash=code_hash):
            return True, "Function with same code hash was recently optimized"

        last_attempts = self.store.get_function_history(qualified_name, limit=5)

        def occurrences(status: OptimizationStatus) -> int:
            return sum(1 for attempt in last_attempts if attempt.status == status)

        if occurrences(OptimizationStatus.COMPLETED) >= 3:
            return True, "Function has been successfully optimized multiple times recently"
        if occurrences(OptimizationStatus.FAILED) >= 3:
            return True, "Function has failed optimization multiple times recently"
        return False, None

    def get_best_speedup(self, qualified_name: str) -> float | None:
        """Return the largest speedup among completed attempts, or ``None`` if there is none."""
        history = self.store.get_function_history(qualified_name)
        completed_speedups = [
            attempt.speedup
            for attempt in history
            if attempt.speedup is not None and attempt.status == OptimizationStatus.COMPLETED
        ]
        if not completed_speedups:
            return None
        return max(completed_speedups)

    def get_statistics(self) -> dict[str, int]:
        """Count the most recent attempts (up to 1000) in total and per status."""
        recent = self.store.get_recent_attempts(limit=1000)
        tallies = {"total_attempts": len(recent)}
        labelled_statuses = (
            ("completed", OptimizationStatus.COMPLETED),
            ("failed", OptimizationStatus.FAILED),
            ("skipped", OptimizationStatus.SKIPPED),
            ("in_progress", OptimizationStatus.IN_PROGRESS),
        )
        for label, status in labelled_statuses:
            tallies[label] = sum(1 for attempt in recent if attempt.status == status)
        return tallies

    def cleanup(self, days: int = 30) -> int:
        """Delegate to the store's ``cleanup_old_records`` and return its result."""
        return self.store.cleanup_old_records(days=days)

172
codeflash/state/models.py Normal file
View file

@ -0,0 +1,172 @@
from __future__ import annotations
import time
from enum import Enum
from pathlib import Path
from typing import Any
from pydantic import BaseModel, ConfigDict, Field
class OptimizationStatus(str, Enum):
    """Status of an optimization attempt.

    Members are also ``str`` instances, so each one compares equal to its
    lowercase value.
    """

    PENDING = "pending"  # status assigned by OptimizationAttempt.create
    IN_PROGRESS = "in_progress"
    COMPLETED = "completed"
    FAILED = "failed"
    SKIPPED = "skipped"
class OptimizationAttempt(BaseModel):
model_config = ConfigDict(frozen=True)
attempt_id: str
function_qualified_name: str
file_path: str
status: OptimizationStatus
started_at: float
completed_at: float | None = None
speedup: float | None = None
original_runtime_ns: int | None = None
optimized_runtime_ns: int | None = None
error_message: str | None = None
pr_url: str | None = None
code_hash: str | None = None
metadata: dict[str, Any] = Field(default_factory=dict)
@classmethod
def create(
cls,
attempt_id: str,
function_qualified_name: str,
file_path: str | Path,
code_hash: str | None = None,
metadata: dict[str, Any] | None = None,
) -> OptimizationAttempt:
return cls(
attempt_id=attempt_id,
function_qualified_name=function_qualified_name,
file_path=str(file_path),
status=OptimizationStatus.PENDING,
started_at=time.time(),
code_hash=code_hash,
metadata=metadata or {},
)
def mark_in_progress(self) -> OptimizationAttempt:
return OptimizationAttempt(
attempt_id=self.attempt_id,
function_qualified_name=self.function_qualified_name,
file_path=self.file_path,
status=OptimizationStatus.IN_PROGRESS,
started_at=self.started_at,
code_hash=self.code_hash,
metadata=self.metadata,
)
def mark_completed(
self,
speedup: float,
original_runtime_ns: int,
optimized_runtime_ns: int,
pr_url: str | None = None,
) -> OptimizationAttempt:
return OptimizationAttempt(
attempt_id=self.attempt_id,
function_qualified_name=self.function_qualified_name,
file_path=self.file_path,
status=OptimizationStatus.COMPLETED,
started_at=self.started_at,
completed_at=time.time(),
speedup=speedup,
original_runtime_ns=original_runtime_ns,
optimized_runtime_ns=optimized_runtime_ns,
pr_url=pr_url,
code_hash=self.code_hash,
metadata=self.metadata,
)
def mark_failed(self, error_message: str) -> OptimizationAttempt:
return OptimizationAttempt(
attempt_id=self.attempt_id,
function_qualified_name=self.function_qualified_name,
file_path=self.file_path,
status=OptimizationStatus.FAILED,
started_at=self.started_at,
completed_at=time.time(),
error_message=error_message,
code_hash=self.code_hash,
metadata=self.metadata,
)
def mark_skipped(self, reason: str) -> OptimizationAttempt:
    """Return a new attempt with status SKIPPED.

    The skip ``reason`` is stored in ``error_message`` and ``completed_at``
    is set to the current time.
    """
    finished_at = time.time()
    return OptimizationAttempt(
        status=OptimizationStatus.SKIPPED,
        # There is no dedicated field for a skip reason; it reuses error_message.
        error_message=reason,
        completed_at=finished_at,
        attempt_id=self.attempt_id,
        function_qualified_name=self.function_qualified_name,
        file_path=self.file_path,
        started_at=self.started_at,
        code_hash=self.code_hash,
        metadata=self.metadata,
    )
class AgentStateSnapshot(BaseModel):
    """Frozen record of one agent's state, suitable for persisting."""

    model_config = ConfigDict(frozen=True)

    agent_id: str
    agent_type: str
    state: str
    current_task_id: str | None = None
    # Defaults to time.time() at the moment the snapshot object is built.
    last_updated: float = Field(default_factory=time.time)
    context: dict[str, Any] = Field(default_factory=dict)

    @classmethod
    def create(
        cls,
        agent_id: str,
        agent_type: str,
        state: str,
        current_task_id: str | None = None,
        context: dict[str, Any] | None = None,
    ) -> AgentStateSnapshot:
        """Build a snapshot; ``last_updated`` is filled in by its default factory.

        Args:
            agent_id: Identifier of the agent.
            agent_type: Type label of the agent.
            state: State label to record.
            current_task_id: Optional id of the task the agent is working on.
            context: Optional extra data; an empty dict is used when it is
                omitted or empty.
        """
        snapshot_fields = {
            "agent_id": agent_id,
            "agent_type": agent_type,
            "state": state,
            "current_task_id": current_task_id,
            "context": context if context else {},
        }
        return cls(**snapshot_fields)
class PipelineState(BaseModel):
    """Frozen record of a pipeline run for a single function."""

    model_config = ConfigDict(frozen=True)

    pipeline_id: str
    status: str
    current_stage: str | None = None
    started_at: float
    completed_at: float | None = None
    function_qualified_name: str
    # Stored as a plain string; create() converts Path inputs with str().
    file_path: str
    stages_completed: list[str] = Field(default_factory=list)
    error_message: str | None = None
    metadata: dict[str, Any] = Field(default_factory=dict)

    @classmethod
    def create(
        cls,
        pipeline_id: str,
        function_qualified_name: str,
        file_path: str | Path,
        metadata: dict[str, Any] | None = None,
    ) -> PipelineState:
        """Build a pipeline state with status ``"pending"`` started at the current time.

        Args:
            pipeline_id: Identifier of the pipeline run.
            function_qualified_name: Name of the function being processed.
            file_path: Location of the function; stored as ``str``.
            metadata: Optional extra data; an empty dict is used when it is
                omitted or empty.
        """
        initial_fields: dict[str, Any] = {
            "pipeline_id": pipeline_id,
            "status": "pending",
            "started_at": time.time(),
            "function_qualified_name": function_qualified_name,
            "file_path": str(file_path),
            "metadata": metadata if metadata else {},
        }
        return cls(**initial_fields)

310
codeflash/state/store.py Normal file
View file

@ -0,0 +1,310 @@
from __future__ import annotations
import json
import sqlite3
import time
from collections.abc import Iterator
from contextlib import contextmanager
from pathlib import Path
from typing import Any
from codeflash.state.models import AgentStateSnapshot, OptimizationAttempt, OptimizationStatus, PipelineState
class StateStore:
    """SQLite-backed persistence for optimization attempts, agent states and pipeline states.

    The database file ``codeflash_agent.db`` lives inside ``storage_path``.
    Every public method opens its own connection and closes it before
    returning, so no connection is held between calls.
    """

    # Schema statements executed in order by _init_database(). Each one uses
    # IF NOT EXISTS, so running them on every construction is safe.
    _SCHEMA_STATEMENTS = (
        (
            "\n"
            "CREATE TABLE IF NOT EXISTS optimization_attempts (\n"
            "attempt_id TEXT PRIMARY KEY,\n"
            "function_qualified_name TEXT NOT NULL,\n"
            "file_path TEXT NOT NULL,\n"
            "status TEXT NOT NULL,\n"
            "started_at REAL NOT NULL,\n"
            "completed_at REAL,\n"
            "speedup REAL,\n"
            "original_runtime_ns INTEGER,\n"
            "optimized_runtime_ns INTEGER,\n"
            "error_message TEXT,\n"
            "pr_url TEXT,\n"
            "code_hash TEXT,\n"
            "metadata TEXT\n"
            ")\n"
        ),
        (
            "\n"
            "CREATE INDEX IF NOT EXISTS idx_function_name\n"
            "ON optimization_attempts(function_qualified_name)\n"
        ),
        (
            "\n"
            "CREATE INDEX IF NOT EXISTS idx_status\n"
            "ON optimization_attempts(status)\n"
        ),
        (
            "\n"
            "CREATE TABLE IF NOT EXISTS agent_states (\n"
            "agent_id TEXT PRIMARY KEY,\n"
            "agent_type TEXT NOT NULL,\n"
            "state TEXT NOT NULL,\n"
            "current_task_id TEXT,\n"
            "last_updated REAL NOT NULL,\n"
            "context TEXT\n"
            ")\n"
        ),
        (
            "\n"
            "CREATE TABLE IF NOT EXISTS pipeline_states (\n"
            "pipeline_id TEXT PRIMARY KEY,\n"
            "status TEXT NOT NULL,\n"
            "current_stage TEXT,\n"
            "started_at REAL NOT NULL,\n"
            "completed_at REAL,\n"
            "function_qualified_name TEXT NOT NULL,\n"
            "file_path TEXT NOT NULL,\n"
            "stages_completed TEXT,\n"
            "error_message TEXT,\n"
            "metadata TEXT\n"
            ")\n"
        ),
    )

    def __init__(self, storage_path: Path | str | None = None) -> None:
        """Create the storage directory and the database schema if missing.

        Args:
            storage_path: Directory that holds the database file. A ``str``
                is accepted and converted to ``Path``. When ``None``, the
                project's ``codeflash_temp_dir`` is used.
        """
        if storage_path is None:
            from codeflash.code_utils.compat import codeflash_temp_dir

            storage_path = codeflash_temp_dir
        # Coerce so that str inputs work with mkdir() and the "/" operator below.
        self.storage_path = Path(storage_path)
        self.storage_path.mkdir(parents=True, exist_ok=True)
        self.db_path = self.storage_path / "codeflash_agent.db"
        self._init_database()

    def _init_database(self) -> None:
        """Create the tables and indexes used by the store (idempotent)."""
        with self._get_connection() as conn:
            cursor = conn.cursor()
            for statement in self._SCHEMA_STATEMENTS:
                cursor.execute(statement)
            conn.commit()

    @contextmanager
    def _get_connection(self) -> Iterator[sqlite3.Connection]:
        """Yield a new connection with ``sqlite3.Row`` rows; always close it on exit.

        Callers are responsible for calling ``commit()``; nothing is committed here.
        """
        conn = sqlite3.connect(str(self.db_path), timeout=30.0)
        conn.row_factory = sqlite3.Row  # allows row["column"] access in the mappers
        try:
            yield conn
        finally:
            conn.close()

    def persist_optimization_attempt(self, attempt: OptimizationAttempt) -> None:
        """Insert the attempt, replacing any existing row with the same ``attempt_id``."""
        values = (
            attempt.attempt_id,
            attempt.function_qualified_name,
            attempt.file_path,
            attempt.status.value,
            attempt.started_at,
            attempt.completed_at,
            attempt.speedup,
            attempt.original_runtime_ns,
            attempt.optimized_runtime_ns,
            attempt.error_message,
            attempt.pr_url,
            attempt.code_hash,
            json.dumps(attempt.metadata),  # metadata is stored as JSON text
        )
        with self._get_connection() as conn:
            conn.cursor().execute(
                "\n"
                "INSERT OR REPLACE INTO optimization_attempts\n"
                "(attempt_id, function_qualified_name, file_path, status, started_at,\n"
                "completed_at, speedup, original_runtime_ns, optimized_runtime_ns,\n"
                "error_message, pr_url, code_hash, metadata)\n"
                "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)\n",
                values,
            )
            conn.commit()

    def get_optimization_attempt(self, attempt_id: str) -> OptimizationAttempt | None:
        """Return the attempt with the given id, or ``None`` when no row matches."""
        with self._get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute("SELECT * FROM optimization_attempts WHERE attempt_id = ?", (attempt_id,))
            row = cursor.fetchone()
            return None if row is None else self._row_to_optimization_attempt(row)

    def get_function_history(self, qualified_name: str, limit: int = 100) -> list[OptimizationAttempt]:
        """Return up to ``limit`` attempts for one function, newest ``started_at`` first."""
        with self._get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute(
                "\n"
                "SELECT * FROM optimization_attempts\n"
                "WHERE function_qualified_name = ?\n"
                "ORDER BY started_at DESC\n"
                "LIMIT ?\n",
                (qualified_name, limit),
            )
            return [self._row_to_optimization_attempt(row) for row in cursor.fetchall()]

    def get_recent_attempts(
        self, since_timestamp: float | None = None, status: OptimizationStatus | None = None, limit: int = 100
    ) -> list[OptimizationAttempt]:
        """Return up to ``limit`` attempts, newest first, with optional filters.

        Args:
            since_timestamp: When given, only attempts with
                ``started_at >= since_timestamp`` are returned.
            status: When given, only attempts with this status are returned.
            limit: Maximum number of rows to return.
        """
        # "WHERE 1=1" lets every optional filter be appended uniformly as " AND ...".
        query = "SELECT * FROM optimization_attempts WHERE 1=1"
        params: list[Any] = []
        if since_timestamp is not None:
            query += " AND started_at >= ?"
            params.append(since_timestamp)
        if status is not None:
            query += " AND status = ?"
            params.append(status.value)
        query += " ORDER BY started_at DESC LIMIT ?"
        params.append(limit)

        with self._get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute(query, params)
            return [self._row_to_optimization_attempt(row) for row in cursor.fetchall()]

    def was_function_recently_optimized(
        self, qualified_name: str, code_hash: str | None = None, within_days: int = 7
    ) -> bool:
        """Report whether a COMPLETED attempt for the function started within the window.

        Args:
            qualified_name: Function name to look up.
            code_hash: When truthy, only attempts with this exact hash count.
                An empty string is treated the same as ``None``.
            within_days: Size of the look-back window, in days.
        """
        cutoff_time = time.time() - (within_days * 24 * 60 * 60)
        completed = OptimizationStatus.COMPLETED.value
        if code_hash:
            sql = (
                "\n"
                "SELECT COUNT(*) FROM optimization_attempts\n"
                "WHERE function_qualified_name = ?\n"
                "AND code_hash = ?\n"
                "AND status = ?\n"
                "AND started_at >= ?\n"
            )
            params: tuple[Any, ...] = (qualified_name, code_hash, completed, cutoff_time)
        else:
            sql = (
                "\n"
                "SELECT COUNT(*) FROM optimization_attempts\n"
                "WHERE function_qualified_name = ?\n"
                "AND status = ?\n"
                "AND started_at >= ?\n"
            )
            params = (qualified_name, completed, cutoff_time)

        with self._get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute(sql, params)
            count = cursor.fetchone()[0]
            return count > 0

    def persist_agent_state(self, snapshot: AgentStateSnapshot) -> None:
        """Insert the snapshot, replacing any existing row with the same ``agent_id``."""
        values = (
            snapshot.agent_id,
            snapshot.agent_type,
            snapshot.state,
            snapshot.current_task_id,
            snapshot.last_updated,
            json.dumps(snapshot.context),
        )
        with self._get_connection() as conn:
            conn.cursor().execute(
                "\n"
                "INSERT OR REPLACE INTO agent_states\n"
                "(agent_id, agent_type, state, current_task_id, last_updated, context)\n"
                "VALUES (?, ?, ?, ?, ?, ?)\n",
                values,
            )
            conn.commit()

    def get_agent_state(self, agent_id: str) -> AgentStateSnapshot | None:
        """Return the stored snapshot for the agent, or ``None`` when no row matches."""
        with self._get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute("SELECT * FROM agent_states WHERE agent_id = ?", (agent_id,))
            row = cursor.fetchone()
            return None if row is None else self._row_to_agent_state(row)

    def persist_pipeline_state(self, state: PipelineState) -> None:
        """Insert the pipeline state, replacing any existing row with the same ``pipeline_id``."""
        values = (
            state.pipeline_id,
            state.status,
            state.current_stage,
            state.started_at,
            state.completed_at,
            state.function_qualified_name,
            state.file_path,
            json.dumps(state.stages_completed),
            state.error_message,
            json.dumps(state.metadata),
        )
        with self._get_connection() as conn:
            conn.cursor().execute(
                "\n"
                "INSERT OR REPLACE INTO pipeline_states\n"
                "(pipeline_id, status, current_stage, started_at, completed_at,\n"
                "function_qualified_name, file_path, stages_completed, error_message, metadata)\n"
                "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)\n",
                values,
            )
            conn.commit()

    def get_pipeline_state(self, pipeline_id: str) -> PipelineState | None:
        """Return the stored pipeline state, or ``None`` when no row matches."""
        with self._get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute("SELECT * FROM pipeline_states WHERE pipeline_id = ?", (pipeline_id,))
            row = cursor.fetchone()
            return None if row is None else self._row_to_pipeline_state(row)

    def cleanup_old_records(self, days: int = 30) -> int:
        """Delete attempts and pipeline states that started more than ``days`` days ago.

        Rows in ``agent_states`` are not touched.

        Returns:
            The combined number of deleted attempt and pipeline rows.
        """
        cutoff_time = time.time() - (days * 24 * 60 * 60)
        with self._get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute("DELETE FROM optimization_attempts WHERE started_at < ?", (cutoff_time,))
            deleted_attempts = cursor.rowcount
            cursor.execute("DELETE FROM pipeline_states WHERE started_at < ?", (cutoff_time,))
            deleted_pipelines = cursor.rowcount
            conn.commit()
            return deleted_attempts + deleted_pipelines

    @staticmethod
    def _load_json(raw: str | None, default: Any) -> Any:
        """Decode a JSON text column, returning ``default`` for NULL or empty text."""
        return json.loads(raw) if raw else default

    def _row_to_optimization_attempt(self, row: sqlite3.Row) -> OptimizationAttempt:
        """Convert an ``optimization_attempts`` row into an ``OptimizationAttempt``."""
        return OptimizationAttempt(
            attempt_id=row["attempt_id"],
            function_qualified_name=row["function_qualified_name"],
            file_path=row["file_path"],
            status=OptimizationStatus(row["status"]),
            started_at=row["started_at"],
            completed_at=row["completed_at"],
            speedup=row["speedup"],
            original_runtime_ns=row["original_runtime_ns"],
            optimized_runtime_ns=row["optimized_runtime_ns"],
            error_message=row["error_message"],
            pr_url=row["pr_url"],
            code_hash=row["code_hash"],
            metadata=self._load_json(row["metadata"], {}),
        )

    def _row_to_agent_state(self, row: sqlite3.Row) -> AgentStateSnapshot:
        """Convert an ``agent_states`` row into an ``AgentStateSnapshot``."""
        return AgentStateSnapshot(
            agent_id=row["agent_id"],
            agent_type=row["agent_type"],
            state=row["state"],
            current_task_id=row["current_task_id"],
            last_updated=row["last_updated"],
            context=self._load_json(row["context"], {}),
        )

    def _row_to_pipeline_state(self, row: sqlite3.Row) -> PipelineState:
        """Convert a ``pipeline_states`` row into a ``PipelineState``."""
        return PipelineState(
            pipeline_id=row["pipeline_id"],
            status=row["status"],
            current_stage=row["current_stage"],
            started_at=row["started_at"],
            completed_at=row["completed_at"],
            function_qualified_name=row["function_qualified_name"],
            file_path=row["file_path"],
            stages_completed=self._load_json(row["stages_completed"], []),
            error_message=row["error_message"],
            metadata=self._load_json(row["metadata"], {}),
        )

View file

View file

@ -0,0 +1,181 @@
from __future__ import annotations
import pytest
from codeflash.agents.base import (
AgentMessage,
AgentResult,
AgentState,
AgentTask,
BaseAgent,
wrap_error,
wrap_result,
)
from codeflash.either import Result
class MockAgent(BaseAgent):
    """BaseAgent stand-in whose ``process`` outcome is fixed at construction time."""

    def __init__(self, agent_id: str = "mock", should_fail: bool = False) -> None:
        super().__init__(agent_id=agent_id, coordinator=None)
        # Flipped to True by process() so tests can tell that it ran.
        self.process_called = False
        self.should_fail = should_fail

    def process(self, task: AgentTask) -> Result:
        """Record the call, then return a canned success or failure result."""
        self.process_called = True
        if not self.should_fail:
            return wrap_result({"mock_data": "test"}, task.task_id, self.agent_id)
        return wrap_error("Mock failure", task.task_id, self.agent_id)
class TestAgentState:
    """Pins the string value of each AgentState member."""

    def test_agent_states(self) -> None:
        """Every member maps to its lowercase name."""
        expected_values = (
            (AgentState.IDLE, "idle"),
            (AgentState.RUNNING, "running"),
            (AgentState.WAITING, "waiting"),
            (AgentState.COMPLETED, "completed"),
            (AgentState.FAILED, "failed"),
        )
        for member, expected in expected_values:
            assert member.value == expected
class TestAgentTask:
    """Tests for AgentTask.create() and task ordering."""

    def test_create_task(self) -> None:
        """create() stores type, payload and priority, and assigns a task id."""
        created = AgentTask.create(task_type="test_task", payload={"key": "value"}, priority=5)

        assert created.task_type == "test_task"
        assert created.payload == {"key": "value"}
        assert created.priority == 5
        assert created.task_id is not None
        assert created.parent_task_id is None

    def test_task_with_parent(self) -> None:
        """A parent task id passed to create() is kept on the task."""
        child = AgentTask.create(task_type="child_task", payload={}, parent_task_id="parent-123")

        assert child.parent_task_id == "parent-123"

    def test_task_comparison(self) -> None:
        """A priority-10 task compares as less than a priority-1 task."""
        urgent = AgentTask.create("high", {}, priority=10)
        routine = AgentTask.create("low", {}, priority=1)

        assert urgent < routine
class TestAgentMessage:
    """Tests for AgentMessage.create()."""

    def test_create_message(self) -> None:
        """create() keeps sender, type and payload, and fills in id and timestamp."""
        payload = {"status": "ready"}
        msg = AgentMessage.create(sender_id="agent-1", message_type="notification", payload=payload)

        assert msg.sender_id == "agent-1"
        assert msg.message_type == "notification"
        assert msg.payload == {"status": "ready"}
        assert msg.message_id is not None
        assert msg.timestamp > 0
class TestAgentResult:
    """Tests for the AgentResult factory methods."""

    def test_success_result(self) -> None:
        """success_result() sets the success flag, keeps the data and has no error."""
        outcome = AgentResult.success_result(task_id="task-1", agent_id="agent-1", data={"result": "success"})

        assert outcome.success is True
        assert outcome.data == {"result": "success"}
        assert outcome.error is None

    def test_failure_result(self) -> None:
        """failure_result() clears the success flag, keeps the error and has no data."""
        outcome = AgentResult.failure_result(task_id="task-1", agent_id="agent-1", error="Something went wrong")

        assert outcome.success is False
        assert outcome.data is None
        assert outcome.error == "Something went wrong"

    def test_result_with_metadata(self) -> None:
        """Metadata passed to success_result() is stored unchanged."""
        outcome = AgentResult.success_result(
            task_id="task-1", agent_id="agent-1", data={}, metadata={"duration_ms": 100}
        )

        assert outcome.metadata == {"duration_ms": 100}
class TestBaseAgent:
    """Tests for BaseAgent behaviour, exercised through MockAgent."""

    def test_agent_initialization(self) -> None:
        """A fresh agent is idle and has no coordinator or current task."""
        fresh = MockAgent("test-agent")

        assert fresh.agent_id == "test-agent"
        assert fresh.state == AgentState.IDLE
        assert fresh.coordinator is None
        assert fresh.current_task is None

    def test_agent_execute_success(self) -> None:
        """A successful process() leaves the agent in the COMPLETED state."""
        worker = MockAgent("test-agent")
        outcome = worker.execute(AgentTask.create("test", {"data": "value"}))

        assert outcome.is_successful()
        assert worker.state == AgentState.COMPLETED
        assert worker.process_called

    def test_agent_execute_failure(self) -> None:
        """A failing process() leaves the agent in the FAILED state."""
        worker = MockAgent("test-agent", should_fail=True)
        outcome = worker.execute(AgentTask.create("test", {}))

        assert outcome.is_failure()
        assert worker.state == AgentState.FAILED

    def test_agent_messaging(self) -> None:
        """A received message is pending until get_next_message() returns it."""
        receiver = MockAgent()
        assert not receiver.has_pending_messages()

        sent = AgentMessage.create("sender", "test", {})
        receiver.receive_message(sent)
        assert receiver.has_pending_messages()

        delivered = receiver.get_next_message()
        assert delivered == sent
        assert not receiver.has_pending_messages()

    def test_agent_reset(self) -> None:
        """reset() returns the agent to IDLE and drops pending messages."""
        busy = MockAgent()
        busy.state = AgentState.RUNNING
        busy.receive_message(AgentMessage.create("sender", "test", {}))

        busy.reset()

        assert busy.state == AgentState.IDLE
        assert not busy.has_pending_messages()

    def test_get_state_snapshot(self) -> None:
        """The snapshot of a fresh agent reports its id, the idle state and no task."""
        snap = MockAgent("snapshot-agent").get_state_snapshot()

        assert snap["agent_id"] == "snapshot-agent"
        assert snap["state"] == "idle"
        assert snap["current_task"] is None
class TestWrapFunctions:
    """Tests for the wrap_result / wrap_error helpers."""

    def test_wrap_result(self) -> None:
        """wrap_result() yields a successful Result whose payload carries the data."""
        wrapped = wrap_result({"data": "test"}, "task-1", "agent-1")

        assert wrapped.is_successful()
        inner = wrapped.unwrap()
        assert inner.data == {"data": "test"}

    def test_wrap_error(self) -> None:
        """wrap_error() yields a failed Result exposing the original message."""
        wrapped = wrap_error("error message", "task-1", "agent-1")

        assert wrapped.is_failure()
        assert wrapped.failure() == "error message"

View file

@ -0,0 +1,112 @@
from __future__ import annotations
import tempfile
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
from codeflash.agents.base import AgentResult, AgentTask
from codeflash.agents.coordinator import AgentCoordinator
from codeflash.either import Failure, Success
from codeflash.state.store import StateStore
class TestAgentCoordinator:
    """Tests for AgentCoordinator: registration, lookup, task queue and persistence."""

    @pytest.fixture
    def coordinator(self) -> AgentCoordinator:
        """Yield a coordinator backed by a StateStore in a temporary directory.

        The directory is removed once the test using the fixture finishes.
        """
        with tempfile.TemporaryDirectory() as tmpdir:
            store = StateStore(Path(tmpdir))
            coord = AgentCoordinator(state_store=store)
            yield coord

    def test_coordinator_initialization(self, coordinator: AgentCoordinator) -> None:
        """A new coordinator has no agents, has a store and is not initialized."""
        assert coordinator.agents == {}
        assert coordinator.state_store is not None
        assert not coordinator._initialized

    def test_register_agents(self, coordinator: AgentCoordinator) -> None:
        """register_agents() adds all six pipeline agents and marks the coordinator initialized."""
        coordinator.register_agents()

        assert "discovery" in coordinator.agents
        assert "analysis" in coordinator.agents
        assert "generation" in coordinator.agents
        assert "verification" in coordinator.agents
        assert "selection" in coordinator.agents
        assert "integration" in coordinator.agents
        assert coordinator._initialized

    def test_get_agent(self, coordinator: AgentCoordinator) -> None:
        """get_agent() returns a registered agent by id and None for unknown ids."""
        coordinator.register_agents()

        agent = coordinator.get_agent("discovery")
        assert agent is not None
        assert agent.agent_id == "discovery"

        missing = coordinator.get_agent("nonexistent")
        assert missing is None

    def test_submit_task(self, coordinator: AgentCoordinator) -> None:
        """A submitted task ends up in the coordinator's task queue."""
        task = AgentTask.create("test", {"data": "value"})
        coordinator.submit_task(task)

        assert not coordinator.task_queue.empty()

    def test_execute_task_agent_not_found(self, coordinator: AgentCoordinator) -> None:
        """Executing on an unknown agent id fails with a 'not found' message."""
        task = AgentTask.create("test", {})
        result = coordinator.execute_task("nonexistent", task)

        assert result.is_failure()
        assert "not found" in result.failure()

    def test_persist_agent_state(self, coordinator: AgentCoordinator) -> None:
        """A persisted agent state can be read back from the state store."""
        coordinator.register_agents()
        state_snapshot = {
            "state": "running",
            "current_task": "task-123",
        }
        coordinator.persist_agent_state("discovery", state_snapshot)

        retrieved = coordinator.state_store.get_agent_state("discovery")
        assert retrieved is not None
        assert retrieved.state == "running"

    def test_notify_failure_logs_warning(self, coordinator: AgentCoordinator) -> None:
        """notify_failure() emits exactly one warning through the module logger."""
        with patch("codeflash.agents.coordinator.logger") as mock_logger:
            coordinator.notify_failure("test-agent", "Test error message")
            mock_logger.warning.assert_called_once()

    def test_reset_all_agents(self, coordinator: AgentCoordinator) -> None:
        """reset_all_agents() empties the task queue and returns every agent to IDLE."""
        # Imported once here rather than inside the loop body below.
        from codeflash.agents.base import AgentState

        coordinator.register_agents()
        task = AgentTask.create("test", {})
        coordinator.submit_task(task)

        coordinator.reset_all_agents()

        assert coordinator.task_queue.empty()
        for agent in coordinator.agents.values():
            assert agent.state == AgentState.IDLE
class TestCoordinatorExecution:
    """Runs a real task through a coordinator with all agents registered."""

    @pytest.fixture
    def coordinator_with_agents(self) -> AgentCoordinator:
        """Yield a coordinator with agents registered, stored in a temporary directory."""
        with tempfile.TemporaryDirectory() as tmpdir:
            ready = AgentCoordinator(state_store=StateStore(Path(tmpdir)))
            ready.register_agents()
            yield ready

    def test_execute_task_on_registered_agent(self, coordinator_with_agents: AgentCoordinator) -> None:
        """The analysis agent answers a numerical-check task with an ``is_numerical`` entry."""
        from codeflash.agents.analysis_agent import create_numerical_check_task

        check_task = create_numerical_check_task("def foo(): return 1")
        outcome = coordinator_with_agents.execute_task("analysis", check_task)

        assert outcome.is_successful()
        payload = outcome.unwrap()
        assert payload.success
        assert "is_numerical" in payload.data

View file

@ -0,0 +1,220 @@
from __future__ import annotations
import tempfile
import time
from pathlib import Path
import pytest
from codeflash.state.history import OptimizationHistory
from codeflash.state.models import AgentStateSnapshot, OptimizationAttempt, OptimizationStatus, PipelineState
from codeflash.state.store import StateStore
class TestOptimizationStatus:
    """Pins the string value of each OptimizationStatus member."""

    def test_status_values(self) -> None:
        """Every member maps to its lowercase name."""
        expected_values = (
            (OptimizationStatus.PENDING, "pending"),
            (OptimizationStatus.IN_PROGRESS, "in_progress"),
            (OptimizationStatus.COMPLETED, "completed"),
            (OptimizationStatus.FAILED, "failed"),
            (OptimizationStatus.SKIPPED, "skipped"),
        )
        for member, expected in expected_values:
            assert member.value == expected
class TestOptimizationAttempt:
    """Tests for OptimizationAttempt.create() and the mark_* transitions."""

    def test_create_attempt(self) -> None:
        """create() stores its arguments and starts in the PENDING state."""
        fresh = OptimizationAttempt.create(
            attempt_id="attempt-1",
            function_qualified_name="module.function",
            file_path="/path/to/file.py",
            code_hash="abc123",
        )

        assert fresh.attempt_id == "attempt-1"
        assert fresh.function_qualified_name == "module.function"
        assert fresh.file_path == "/path/to/file.py"
        assert fresh.status == OptimizationStatus.PENDING
        assert fresh.code_hash == "abc123"

    def test_mark_in_progress(self) -> None:
        """mark_in_progress() changes the status and keeps the attempt id."""
        base = OptimizationAttempt.create("id", "func", "/file.py")
        running = base.mark_in_progress()

        assert running.status == OptimizationStatus.IN_PROGRESS
        assert running.attempt_id == base.attempt_id

    def test_mark_completed(self) -> None:
        """mark_completed() records the results and a completion time."""
        base = OptimizationAttempt.create("id", "func", "/file.py")
        done = base.mark_completed(
            speedup=2.5,
            original_runtime_ns=1000000,
            optimized_runtime_ns=400000,
            pr_url="https://github.com/test/pr/1",
        )

        assert done.status == OptimizationStatus.COMPLETED
        assert done.speedup == 2.5
        assert done.original_runtime_ns == 1000000
        assert done.optimized_runtime_ns == 400000
        assert done.pr_url == "https://github.com/test/pr/1"
        assert done.completed_at is not None

    def test_mark_failed(self) -> None:
        """mark_failed() records the error message and a completion time."""
        base = OptimizationAttempt.create("id", "func", "/file.py")
        broken = base.mark_failed("Test failure message")

        assert broken.status == OptimizationStatus.FAILED
        assert broken.error_message == "Test failure message"
        assert broken.completed_at is not None

    def test_mark_skipped(self) -> None:
        """mark_skipped() stores the skip reason in ``error_message``."""
        base = OptimizationAttempt.create("id", "func", "/file.py")
        passed_over = base.mark_skipped("Already optimized")

        assert passed_over.status == OptimizationStatus.SKIPPED
        assert passed_over.error_message == "Already optimized"
class TestAgentStateSnapshot:
    """Tests for AgentStateSnapshot.create()."""

    def test_create_snapshot(self) -> None:
        """create() stores every argument on the snapshot unchanged."""
        snap = AgentStateSnapshot.create(
            agent_id="discovery",
            agent_type="DiscoveryAgent",
            state="running",
            current_task_id="task-123",
            context={"extra": "data"},
        )

        assert snap.agent_id == "discovery"
        assert snap.agent_type == "DiscoveryAgent"
        assert snap.state == "running"
        assert snap.current_task_id == "task-123"
        assert snap.context == {"extra": "data"}
class TestPipelineState:
    """Tests for PipelineState.create()."""

    def test_create_pipeline_state(self) -> None:
        """create() stores the id and function name and starts as ``pending``."""
        pipeline = PipelineState.create(
            pipeline_id="pipeline-1", function_qualified_name="module.func", file_path="/path/file.py"
        )

        assert pipeline.pipeline_id == "pipeline-1"
        assert pipeline.status == "pending"
        assert pipeline.function_qualified_name == "module.func"
class TestStateStore:
    """Round-trip tests for StateStore against a database in a temporary directory."""

    @pytest.fixture
    def temp_store(self) -> StateStore:
        """Yield a StateStore whose directory is removed after the test."""
        with tempfile.TemporaryDirectory() as tmpdir:
            yield StateStore(Path(tmpdir))

    def test_store_initialization(self, temp_store: StateStore) -> None:
        """Constructing the store creates the database file."""
        assert temp_store.db_path.exists()

    def test_persist_and_get_optimization_attempt(self, temp_store: StateStore) -> None:
        """A persisted attempt can be fetched back by its id."""
        temp_store.persist_optimization_attempt(OptimizationAttempt.create("test-id", "func.name", "/file.py"))

        loaded = temp_store.get_optimization_attempt("test-id")
        assert loaded is not None
        assert loaded.attempt_id == "test-id"
        assert loaded.function_qualified_name == "func.name"

    def test_get_function_history(self, temp_store: StateStore) -> None:
        """All attempts stored for one function show up in its history."""
        for index in range(3):
            temp_store.persist_optimization_attempt(
                OptimizationAttempt.create(f"id-{index}", "module.func", "/file.py")
            )

        assert len(temp_store.get_function_history("module.func")) == 3

    def test_get_recent_attempts(self, temp_store: StateStore) -> None:
        """Attempts for different functions are all returned by get_recent_attempts()."""
        first = OptimizationAttempt.create("id-1", "func1", "/file1.py")
        second = OptimizationAttempt.create("id-2", "func2", "/file2.py")
        temp_store.persist_optimization_attempt(first)
        temp_store.persist_optimization_attempt(second)

        assert len(temp_store.get_recent_attempts(limit=10)) == 2

    def test_was_function_recently_optimized(self, temp_store: StateStore) -> None:
        """A completed attempt with a matching hash counts; unknown functions do not."""
        pending = OptimizationAttempt.create("id", "func", "/file.py", code_hash="hash123")
        temp_store.persist_optimization_attempt(pending.mark_completed(2.0, 1000, 500))

        assert temp_store.was_function_recently_optimized("func", code_hash="hash123")
        assert not temp_store.was_function_recently_optimized("other_func")

    def test_persist_and_get_agent_state(self, temp_store: StateStore) -> None:
        """A persisted agent snapshot can be fetched back by agent id."""
        temp_store.persist_agent_state(AgentStateSnapshot.create("agent-1", "TestAgent", "running"))

        loaded = temp_store.get_agent_state("agent-1")
        assert loaded is not None
        assert loaded.agent_type == "TestAgent"

    def test_persist_and_get_pipeline_state(self, temp_store: StateStore) -> None:
        """A persisted pipeline state can be fetched back by pipeline id."""
        temp_store.persist_pipeline_state(PipelineState.create("pipe-1", "func", "/file.py"))

        loaded = temp_store.get_pipeline_state("pipe-1")
        assert loaded is not None
        assert loaded.function_qualified_name == "func"
class TestOptimizationHistory:
    """Tests for OptimizationHistory on top of a temporary StateStore."""

    @pytest.fixture
    def history(self) -> OptimizationHistory:
        """Yield an OptimizationHistory whose backing directory is removed after the test."""
        with tempfile.TemporaryDirectory() as tmpdir:
            yield OptimizationHistory(StateStore(Path(tmpdir)))

    def test_record_and_get_attempts(self, history: OptimizationHistory) -> None:
        """A recorded attempt is returned when querying by function name."""
        history.record_attempt(OptimizationAttempt.create("id", "func", "/file.py"))

        assert len(history.get_function_attempts("func")) == 1

    def test_should_skip_recently_optimized(self, history: OptimizationHistory) -> None:
        """Three completed attempts make the function skippable as already optimized."""
        for index in range(3):
            pending = OptimizationAttempt.create(f"id-{index}", "func", "/file.py")
            history.record_attempt(pending.mark_completed(2.0, 1000, 500))

        should_skip, reason = history.should_skip_function("func")
        assert should_skip
        assert "successfully optimized" in reason

    def test_should_skip_repeatedly_failed(self, history: OptimizationHistory) -> None:
        """Three failed attempts make the function skippable with a failure reason."""
        for index in range(3):
            pending = OptimizationAttempt.create(f"id-{index}", "func", "/file.py")
            history.record_attempt(pending.mark_failed("error"))

        should_skip, reason = history.should_skip_function("func")
        assert should_skip
        assert "failed" in reason

    def test_get_best_speedup(self, history: OptimizationHistory) -> None:
        """get_best_speedup() returns the largest recorded speedup."""
        for speedup in (1.5, 2.0, 1.8):
            pending = OptimizationAttempt.create(f"id-{speedup}", "func", "/file.py")
            history.record_attempt(pending.mark_completed(speedup, 1000, int(1000 / speedup)))

        assert history.get_best_speedup("func") == 2.0

    def test_get_statistics(self, history: OptimizationHistory) -> None:
        """Statistics count total, completed and failed attempts."""
        succeeded = OptimizationAttempt.create("id-1", "func1", "/file.py").mark_completed(2.0, 1000, 500)
        errored = OptimizationAttempt.create("id-2", "func2", "/file.py").mark_failed("error")
        history.record_attempt(succeeded)
        history.record_attempt(errored)

        stats = history.get_statistics()
        assert stats["total_attempts"] == 2
        assert stats["completed"] == 1
        assert stats["failed"] == 1