Merge branch 'main' into feat/hypothesis-tests
Some checks are pending
CodeFlash / Optimize new Python code (pull_request) Waiting to run
E2E - Async / async-optimization (pull_request) Waiting to run
E2E - Bubble Sort Benchmark / benchmark-bubble-sort-optimization (pull_request) Waiting to run
E2E - Bubble Sort Pytest (No Git) / bubble-sort-optimization-pytest-no-git (pull_request) Waiting to run
E2E - Bubble Sort Unittest / bubble-sort-optimization-unittest (pull_request) Waiting to run
Coverage E2E / end-to-end-test-coverage (pull_request) Waiting to run
E2E - Futurehouse Structure / futurehouse-structure (pull_request) Waiting to run
PR Labeler / label-workflow-changes (pull_request) Waiting to run
/ Run pr agent on every pull request, respond to user comments (pull_request) Waiting to run
Lint / Run pre-commit hooks (pull_request) Waiting to run
unit-tests / unit-tests (3.10) (pull_request) Waiting to run
unit-tests / unit-tests (3.11) (pull_request) Waiting to run
unit-tests / unit-tests (3.12) (pull_request) Waiting to run
unit-tests / unit-tests (3.13) (pull_request) Waiting to run
unit-tests / unit-tests (3.14) (pull_request) Waiting to run
E2E - Init Optimization / init-optimization (pull_request) Waiting to run
E2E - Topological Sort (Worktree) / topological-sort-worktree-optimization (pull_request) Waiting to run
E2E - Tracer Replay / tracer-replay (pull_request) Waiting to run
Mypy Type Checking for CLI / type-check-cli (pull_request) Waiting to run
unit-tests / unit-tests (3.9) (pull_request) Waiting to run
windows-unit-tests / windows-unit-tests (pull_request) Waiting to run

This commit is contained in:
Kevin Turcios 2025-10-30 16:30:21 -05:00 committed by GitHub
commit ba04f88496
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 275 additions and 155 deletions

View file

@ -52,10 +52,11 @@ def pytest_addoption(parser: pytest.Parser) -> None:
parser.addoption(
"--codeflash-trace", action="store_true", default=False, help="Enable CodeFlash tracing for benchmarks"
)
# These options are ignored when --codeflash-trace is used
for option, action, default, help_text in benchmark_options:
help_suffix = " (ignored when --codeflash-trace is used)"
parser.addoption(option, action=action, default=default, help=help_text + help_suffix)
# Only add benchmark options if pytest-benchmark is not installed for backward compatibility with existing pytest-benchmark setup
if not PYTEST_BENCHMARK_INSTALLED:
for option, action, default, help_text in benchmark_options:
help_suffix = " (ignored when --codeflash-trace is used)"
parser.addoption(option, action=action, default=default, help=help_text + help_suffix)
@pytest.fixture

View file

@ -80,9 +80,16 @@ def paneled_text(
console.print(panel)
def code_print(code_str: str, file_name: Optional[str] = None, function_name: Optional[str] = None) -> None:
def code_print(
code_str: str,
file_name: Optional[str] = None,
function_name: Optional[str] = None,
lsp_message_id: Optional[str] = None,
) -> None:
if is_LSP_enabled():
lsp_log(LspCodeMessage(code=code_str, file_name=file_name, function_name=function_name))
lsp_log(
LspCodeMessage(code=code_str, file_name=file_name, function_name=function_name, message_id=lsp_message_id)
)
return
"""Print code with syntax highlighting."""
from rich.syntax import Syntax

View file

@ -360,3 +360,19 @@ def exit_with_message(message: str, *, error_on_exit: bool = False) -> None:
paneled_text(message, panel_args={"style": "red"})
sys.exit(1 if error_on_exit else 0)
def extract_unique_errors(pytest_output: str) -> set[str]:
    """Collect the distinct error messages from a pytest output dump.

    Pytest prefixes assertion/exception detail lines with an ``E`` marker;
    this captures the text following that marker on every such line,
    trims surrounding whitespace, and de-duplicates the results.
    """
    # Lines look like: "E   ValueError: ..." — grab everything after the marker.
    error_line_pattern = r"^E\s+(.*)$"
    matches = re.findall(error_line_pattern, pytest_output, re.MULTILINE)
    return {line.strip() for line in matches if line.strip()}

View file

@ -2,7 +2,9 @@ from __future__ import annotations
import asyncio
import contextlib
import contextvars
import os
import threading
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Optional
@ -27,8 +29,8 @@ from codeflash.discovery.functions_to_optimize import (
get_functions_within_git_diff,
)
from codeflash.either import is_successful
from codeflash.lsp.features.perform_optimization import sync_perform_optimization
from codeflash.lsp.server import CodeflashLanguageServer
from codeflash.lsp.features.perform_optimization import get_cancelled_reponse, sync_perform_optimization
from codeflash.lsp.server import CodeflashLanguageServer, CodeflashLanguageServerProtocol
if TYPE_CHECKING:
from argparse import Namespace
@ -47,6 +49,7 @@ class OptimizableFunctionsParams:
class FunctionOptimizationInitParams:
textDocument: types.TextDocumentIdentifier # noqa: N815
functionName: str # noqa: N815
task_id: str
@dataclass
@ -84,30 +87,24 @@ class WriteConfigParams:
config: any
server = CodeflashLanguageServer("codeflash-language-server", "v1.0")
server = CodeflashLanguageServer("codeflash-language-server", "v1.0", protocol_cls=CodeflashLanguageServerProtocol)
@server.feature("getOptimizableFunctionsInCurrentDiff")
def get_functions_in_current_git_diff(
server: CodeflashLanguageServer, _params: OptimizableFunctionsParams
) -> dict[str, str | dict[str, list[str]]]:
def get_functions_in_current_git_diff(_params: OptimizableFunctionsParams) -> dict[str, str | dict[str, list[str]]]:
functions = get_functions_within_git_diff(uncommitted_changes=True)
file_to_qualified_names = _group_functions_by_file(server, functions)
file_to_qualified_names = _group_functions_by_file(functions)
return {"functions": file_to_qualified_names, "status": "success"}
@server.feature("getOptimizableFunctionsInCommit")
def get_functions_in_commit(
server: CodeflashLanguageServer, params: OptimizableFunctionsInCommitParams
) -> dict[str, str | dict[str, list[str]]]:
def get_functions_in_commit(params: OptimizableFunctionsInCommitParams) -> dict[str, str | dict[str, list[str]]]:
functions = get_functions_inside_a_commit(params.commit_hash)
file_to_qualified_names = _group_functions_by_file(server, functions)
file_to_qualified_names = _group_functions_by_file(functions)
return {"functions": file_to_qualified_names, "status": "success"}
def _group_functions_by_file(
server: CodeflashLanguageServer, functions: dict[str, list[FunctionToOptimize]]
) -> dict[str, list[str]]:
def _group_functions_by_file(functions: dict[str, list[FunctionToOptimize]]) -> dict[str, list[str]]:
file_to_funcs_to_optimize, _ = filter_functions(
modified_functions=functions,
tests_root=server.optimizer.test_cfg.tests_root,
@ -123,9 +120,7 @@ def _group_functions_by_file(
@server.feature("getOptimizableFunctions")
def get_optimizable_functions(
server: CodeflashLanguageServer, params: OptimizableFunctionsParams
) -> dict[str, list[str]]:
def get_optimizable_functions(params: OptimizableFunctionsParams) -> dict[str, list[str]]:
document_uri = params.textDocument.uri
document = server.workspace.get_text_document(document_uri)
@ -172,7 +167,7 @@ def _find_pyproject_toml(workspace_path: str) -> tuple[Path | None, bool]:
@server.feature("writeConfig")
def write_config(_server: CodeflashLanguageServer, params: WriteConfigParams) -> dict[str, any]:
def write_config(params: WriteConfigParams) -> dict[str, any]:
cfg = params.config
cfg_file = Path(params.config_file) if params.config_file else None
@ -196,7 +191,7 @@ def write_config(_server: CodeflashLanguageServer, params: WriteConfigParams) ->
@server.feature("getConfigSuggestions")
def get_config_suggestions(_server: CodeflashLanguageServer, _params: any) -> dict[str, any]:
def get_config_suggestions(_params: any) -> dict[str, any]:
module_root_suggestions, default_module_root = get_suggestions(CommonSections.module_root)
tests_root_suggestions, default_tests_root = get_suggestions(CommonSections.tests_root)
test_framework_suggestions, default_test_framework = get_suggestions(CommonSections.test_framework)
@ -212,9 +207,9 @@ def get_config_suggestions(_server: CodeflashLanguageServer, _params: any) -> di
# should be called the first thing to initialize and validate the project
@server.feature("initProject")
def init_project(server: CodeflashLanguageServer, params: ValidateProjectParams) -> dict[str, str]:
def init_project(params: ValidateProjectParams) -> dict[str, str]:
# Always process args in the init project, the extension can call
server.args_processed_before = False
server.initialized = False
pyproject_toml_path: Path | None = getattr(params, "config_file", None) or getattr(server.args, "config_file", None)
if pyproject_toml_path is not None:
@ -255,19 +250,16 @@ def init_project(server: CodeflashLanguageServer, params: ValidateProjectParams)
"existingConfig": config,
}
args = process_args(server)
args = _init()
return {"status": "success", "moduleRoot": args.module_root, "pyprojectPath": pyproject_toml_path, "root": root}
def _initialize_optimizer_if_api_key_is_valid(
server: CodeflashLanguageServer, api_key: Optional[str] = None
) -> dict[str, str]:
def _initialize_optimizer_if_api_key_is_valid(api_key: Optional[str] = None) -> dict[str, str]:
key_check_result = _check_api_key_validity(api_key)
if key_check_result.get("status") != "success":
return key_check_result
_initialize_optimizer(server)
_init()
return key_check_result
@ -283,48 +275,56 @@ def _check_api_key_validity(api_key: Optional[str]) -> dict[str, str]:
return {"status": "success", "user_id": user_id}
def _initialize_optimizer(server: CodeflashLanguageServer) -> None:
def _initialize_optimizer(args: Namespace) -> None:
from codeflash.optimization.optimizer import Optimizer
new_args = process_args(server)
if not server.optimizer:
server.optimizer = Optimizer(new_args)
server.optimizer = Optimizer(args)
def process_args(server: CodeflashLanguageServer) -> Namespace:
if server.args_processed_before:
return server.args
def process_args() -> Namespace:
new_args = process_pyproject_config(server.args)
server.args = new_args
server.args_processed_before = True
return new_args
def _init() -> Namespace:
if server.initialized:
return server.args
new_args = process_args()
_initialize_optimizer(new_args)
server.initialized = True
return new_args
@server.feature("apiKeyExistsAndValid")
def check_api_key(server: CodeflashLanguageServer, _params: any) -> dict[str, str]:
def check_api_key(_params: any) -> dict[str, str]:
try:
return _initialize_optimizer_if_api_key_is_valid(server)
return _initialize_optimizer_if_api_key_is_valid()
except Exception:
return {"status": "error", "message": "something went wrong while validating the api key"}
@server.feature("provideApiKey")
def provide_api_key(server: CodeflashLanguageServer, params: ProvideApiKeyParams) -> dict[str, str]:
def provide_api_key(params: ProvideApiKeyParams) -> dict[str, str]:
try:
api_key = params.api_key
if not api_key.startswith("cf-"):
return {"status": "error", "message": "Api key is not valid"}
# # clear cache to ensure the new api key is used
# clear cache to ensure the new api key is used
get_codeflash_api_key.cache_clear()
get_user_id.cache_clear()
key_check_result = _check_api_key_validity(api_key)
if key_check_result.get("status") != "success":
return key_check_result
user_id = key_check_result["user_id"]
result = save_api_key_to_rc(api_key)
# initialize optimizer with the new api key
_initialize_optimizer(server)
_init()
if not is_successful(result):
return {"status": "error", "message": result.failure()}
return {"status": "success", "message": "Api key saved successfully", "user_id": user_id} # noqa: TRY300
@ -332,85 +332,106 @@ def provide_api_key(server: CodeflashLanguageServer, params: ProvideApiKeyParams
return {"status": "error", "message": "something went wrong while saving the api key"}
@contextlib.contextmanager
def execution_context(**kwargs: str) -> None:
    """Overlay *kwargs* onto the per-task execution context for the duration of the block."""
    # Merge into a fresh dict so concurrent tasks never share mutable state.
    merged = dict(server.execution_context_vars.get())
    merged.update(kwargs)
    reset_token = server.execution_context_vars.set(merged)
    try:
        yield
    finally:
        # Always restore the previous context, even if the body raised.
        server.execution_context_vars.reset(reset_token)
@server.feature("cleanupCurrentOptimizerSession")
def cleanup_optimizer(_params: any) -> dict[str, str]:
    """LSP feature: tear down the active optimizer session and report the outcome."""
    if server.cleanup_the_optimizer():
        return {"status": "success"}
    return {"status": "error", "message": "Failed to cleanup optimizer"}
@server.feature("initializeFunctionOptimization")
def initialize_function_optimization(
server: CodeflashLanguageServer, params: FunctionOptimizationInitParams
) -> dict[str, str]:
document_uri = params.textDocument.uri
document = server.workspace.get_text_document(document_uri)
file_path = Path(document.path)
def initialize_function_optimization(params: FunctionOptimizationInitParams) -> dict[str, str]:
with execution_context(task_id=getattr(params, "task_id", None)):
document_uri = params.textDocument.uri
document = server.workspace.get_text_document(document_uri)
file_path = Path(document.path)
server.show_message_log(f"Initializing optimization for function: {params.functionName} in {document_uri}", "Info")
server.show_message_log(
f"Initializing optimization for function: {params.functionName} in {document_uri}", "Info"
)
if server.optimizer is None:
_initialize_optimizer_if_api_key_is_valid(server)
if server.optimizer is None:
_initialize_optimizer_if_api_key_is_valid()
server.optimizer.args.file = file_path
server.optimizer.args.function = params.functionName
server.optimizer.args.previous_checkpoint_functions = False
server.optimizer.args.file = file_path
server.optimizer.args.function = params.functionName
server.optimizer.args.previous_checkpoint_functions = False
server.optimizer.worktree_mode()
server.optimizer.worktree_mode()
server.show_message_log(
f"Args set - function: {server.optimizer.args.function}, file: {server.optimizer.args.file}", "Info"
)
server.show_message_log(
f"Args set - function: {server.optimizer.args.function}, file: {server.optimizer.args.file}", "Info"
)
optimizable_funcs, count, _ = server.optimizer.get_optimizable_functions()
optimizable_funcs, count, _ = server.optimizer.get_optimizable_functions()
if count == 0:
server.show_message_log(f"No optimizable functions found for {params.functionName}", "Warning")
server.cleanup_the_optimizer()
return {"functionName": params.functionName, "status": "error", "message": "not found", "args": None}
if count == 0:
server.show_message_log(f"No optimizable functions found for {params.functionName}", "Warning")
server.cleanup_the_optimizer()
return {"functionName": params.functionName, "status": "error", "message": "not found", "args": None}
fto = optimizable_funcs.popitem()[1][0]
fto = optimizable_funcs.popitem()[1][0]
module_prep_result = server.optimizer.prepare_module_for_optimization(fto.file_path)
if not module_prep_result:
return {
"functionName": params.functionName,
"status": "error",
"message": "Failed to prepare module for optimization",
}
module_prep_result = server.optimizer.prepare_module_for_optimization(fto.file_path)
if not module_prep_result:
return {
"functionName": params.functionName,
"status": "error",
"message": "Failed to prepare module for optimization",
}
validated_original_code, original_module_ast = module_prep_result
validated_original_code, original_module_ast = module_prep_result
function_optimizer = server.optimizer.create_function_optimizer(
fto,
function_to_optimize_source_code=validated_original_code[fto.file_path].source_code,
original_module_ast=original_module_ast,
original_module_path=fto.file_path,
function_to_tests={},
)
function_optimizer = server.optimizer.create_function_optimizer(
fto,
function_to_optimize_source_code=validated_original_code[fto.file_path].source_code,
original_module_ast=original_module_ast,
original_module_path=fto.file_path,
function_to_tests={},
)
server.optimizer.current_function_optimizer = function_optimizer
if not function_optimizer:
return {"functionName": params.functionName, "status": "error", "message": "No function optimizer found"}
server.optimizer.current_function_optimizer = function_optimizer
if not function_optimizer:
return {"functionName": params.functionName, "status": "error", "message": "No function optimizer found"}
initialization_result = function_optimizer.can_be_optimized()
if not is_successful(initialization_result):
return {"functionName": params.functionName, "status": "error", "message": initialization_result.failure()}
initialization_result = function_optimizer.can_be_optimized()
if not is_successful(initialization_result):
return {"functionName": params.functionName, "status": "error", "message": initialization_result.failure()}
server.current_optimization_init_result = initialization_result.unwrap()
server.show_message_log(f"Successfully initialized optimization for {params.functionName}", "Info")
server.current_optimization_init_result = initialization_result.unwrap()
server.show_message_log(f"Successfully initialized optimization for {params.functionName}", "Info")
files = [function_optimizer.function_to_optimize.file_path]
files = [function_optimizer.function_to_optimize.file_path]
_, _, original_helpers = server.current_optimization_init_result
files.extend([str(helper_path) for helper_path in original_helpers])
_, _, original_helpers = server.current_optimization_init_result
files.extend([str(helper_path) for helper_path in original_helpers])
return {"functionName": params.functionName, "status": "success", "files_inside_context": files}
return {"functionName": params.functionName, "status": "success", "files_inside_context": files}
@server.feature("performFunctionOptimization")
async def perform_function_optimization(
server: CodeflashLanguageServer, params: FunctionOptimizationParams
) -> dict[str, str]:
loop = asyncio.get_running_loop()
try:
result = await loop.run_in_executor(None, sync_perform_optimization, server, params)
except asyncio.CancelledError:
return {"status": "canceled", "message": "Task was canceled"}
else:
return result
finally:
server.cleanup_the_optimizer()
async def perform_function_optimization(params: FunctionOptimizationParams) -> dict[str, str]:
with execution_context(task_id=getattr(params, "task_id", None)):
loop = asyncio.get_running_loop()
cancel_event = threading.Event()
try:
ctx = contextvars.copy_context()
return await loop.run_in_executor(None, ctx.run, sync_perform_optimization, server, cancel_event, params)
except asyncio.CancelledError:
cancel_event.set()
return get_cancelled_reponse()
finally:
server.cleanup_the_optimizer()

View file

@ -1,13 +1,29 @@
from __future__ import annotations
import contextlib
import os
from typing import TYPE_CHECKING
from codeflash.cli_cmds.console import code_print
from codeflash.code_utils.git_worktree_utils import create_diff_patch_from_worktree
from codeflash.either import is_successful
from codeflash.lsp.server import CodeflashLanguageServer
if TYPE_CHECKING:
import threading
from codeflash.lsp.server import CodeflashLanguageServer
def sync_perform_optimization(server: CodeflashLanguageServer, params) -> dict[str, str]: # noqa: ANN001
def get_cancelled_reponse() -> dict[str, str]:
    """Build the standard payload returned when an optimization task is canceled."""
    # NOTE(review): name keeps the existing "reponse" spelling — callers import it as-is.
    return dict(status="canceled", message="Task was canceled")
def abort_if_cancelled(cancel_event: threading.Event) -> None:
    """Raise ``RuntimeError`` if the cooperative cancellation flag has been set.

    Called between pipeline stages so a long-running synchronous task can
    bail out once the client cancels the request.
    """
    if not cancel_event.is_set():
        return
    raise RuntimeError("cancelled")
def sync_perform_optimization(server: CodeflashLanguageServer, cancel_event: threading.Event, params) -> dict[str, str]: # noqa
server.show_message_log(f"Starting optimization for function: {params.functionName}", "Info")
should_run_experiment, code_context, original_helper_code = server.current_optimization_init_result
function_optimizer = server.optimizer.current_function_optimizer
@ -18,6 +34,7 @@ def sync_perform_optimization(server: CodeflashLanguageServer, params) -> dict[s
file_name=current_function.file_path,
function_name=current_function.function_name,
)
abort_if_cancelled(cancel_event)
optimizable_funcs = {current_function.file_path: [current_function]}
@ -26,9 +43,11 @@ def sync_perform_optimization(server: CodeflashLanguageServer, params) -> dict[s
function_to_tests, _num_discovered_tests = server.optimizer.discover_tests(optimizable_funcs)
function_optimizer.function_to_tests = function_to_tests
abort_if_cancelled(cancel_event)
test_setup_result = function_optimizer.generate_and_instrument_tests(
code_context, should_run_experiment=should_run_experiment
)
abort_if_cancelled(cancel_event)
if not is_successful(test_setup_result):
return {"functionName": params.functionName, "status": "error", "message": test_setup_result.failure()}
(
@ -52,6 +71,7 @@ def sync_perform_optimization(server: CodeflashLanguageServer, params) -> dict[s
original_conftest_content=original_conftest_content,
)
abort_if_cancelled(cancel_event)
if not is_successful(baseline_setup_result):
return {"functionName": params.functionName, "status": "error", "message": baseline_setup_result.failure()}
@ -76,6 +96,7 @@ def sync_perform_optimization(server: CodeflashLanguageServer, params) -> dict[s
concolic_test_str=concolic_test_str,
)
abort_if_cancelled(cancel_event)
if not best_optimization:
server.show_message_log(
f"No best optimizations found for function {function_to_optimize_qualified_name}", "Warning"
@ -93,6 +114,7 @@ def sync_perform_optimization(server: CodeflashLanguageServer, params) -> dict[s
server.optimizer.current_worktree, relative_file_paths, function_to_optimize_qualified_name
)
abort_if_cancelled(cancel_event)
if not patch_path:
return {
"functionName": params.functionName,

View file

@ -3,13 +3,15 @@ from __future__ import annotations
import logging
import sys
from dataclasses import dataclass
from typing import Any, Callable
from typing import Any, Callable, Optional
from codeflash.lsp.helpers import is_LSP_enabled
from codeflash.lsp.lsp_message import LspTextMessage, message_delimiter
from codeflash.lsp.lsp_message import LSPMessageId, LspTextMessage, message_delimiter
root_logger = None
message_id_prefix = "id:"
@dataclass
class LspMessageTags:
@ -18,6 +20,7 @@ class LspMessageTags:
lsp: bool = False # lsp (lsp only)
force_lsp: bool = False # force_lsp (you can use this to force a message to be sent to the LSP even if the level is not supported)
loading: bool = False # loading (you can use this to indicate that the message is a loading message)
message_id: Optional[LSPMessageId] = None # example: id:best_candidate
highlight: bool = False # highlight (you can use this to highlight the message by wrapping it in ``)
h1: bool = False # h1
h2: bool = False # h2
@ -52,24 +55,27 @@ def extract_tags(msg: str) -> tuple[LspMessageTags, str]:
tags = {tag.strip() for tag in tags_str.split(",")}
message_tags = LspMessageTags()
# manually check and set to avoid repeated membership tests
if "lsp" in tags:
message_tags.lsp = True
if "!lsp" in tags:
message_tags.not_lsp = True
if "force_lsp" in tags:
message_tags.force_lsp = True
if "loading" in tags:
message_tags.loading = True
if "highlight" in tags:
message_tags.highlight = True
if "h1" in tags:
message_tags.h1 = True
if "h2" in tags:
message_tags.h2 = True
if "h3" in tags:
message_tags.h3 = True
if "h4" in tags:
message_tags.h4 = True
for tag in tags:
if tag.startswith(message_id_prefix):
message_tags.message_id = LSPMessageId(tag[len(message_id_prefix) :]).value
elif tag == "lsp":
message_tags.lsp = True
elif tag == "!lsp":
message_tags.not_lsp = True
elif tag == "force_lsp":
message_tags.force_lsp = True
elif tag == "loading":
message_tags.loading = True
elif tag == "highlight":
message_tags.highlight = True
elif tag == "h1":
message_tags.h1 = True
elif tag == "h2":
message_tags.h2 = True
elif tag == "h3":
message_tags.h3 = True
elif tag == "h4":
message_tags.h4 = True
return message_tags, content
return LspMessageTags(), msg
@ -110,11 +116,15 @@ def enhanced_log(
actual_log_fn(clean_msg, *args, **kwargs)
return
if not lsp_enabled:
# it's for LSP and LSP is disabled
return
# ---- LSP logging path ----
if is_normal_text_message:
clean_msg = add_heading_tags(clean_msg, tags)
clean_msg = add_highlight_tags(clean_msg, tags)
clean_msg = LspTextMessage(text=clean_msg, takes_time=tags.loading).serialize()
clean_msg = LspTextMessage(text=clean_msg, takes_time=tags.loading, message_id=tags.message_id).serialize()
actual_log_fn(clean_msg, *args, **kwargs)

View file

@ -1,11 +1,12 @@
from __future__ import annotations
import enum
import json
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Any, Optional
from codeflash.lsp.helpers import replace_quotes_with_backticks, simplify_worktree_paths
from codeflash.lsp.helpers import is_LSP_enabled, replace_quotes_with_backticks, simplify_worktree_paths
json_primitive_types = (str, float, int, bool)
max_code_lines_before_collapse = 45
@ -14,10 +15,17 @@ max_code_lines_before_collapse = 45
message_delimiter = "\u241f"
# allow the client to know which message it is receiving
class LSPMessageId(enum.Enum):
    """Tags attached to outgoing LSP messages so the client can tell which kind it received."""

    BEST_CANDIDATE = "best_candidate"  # the final selected optimization
    CANDIDATE = "candidate"  # an intermediate optimization candidate
@dataclass
class LspMessage:
# to show a loading indicator if the operation is taking time like generating candidates or tests
takes_time: bool = False
message_id: Optional[str] = None
def _loop_through(self, obj: Any) -> Any: # noqa: ANN401
if isinstance(obj, list):
@ -34,8 +42,14 @@ class LspMessage:
raise NotImplementedError
def serialize(self) -> str:
if not is_LSP_enabled():
return ""
from codeflash.lsp.beta import server
execution_ctx = server.execution_context_vars.get()
current_task_id = execution_ctx.get("task_id", None)
data = self._loop_through(asdict(self))
ordered = {"type": self.type(), **data}
ordered = {"type": self.type(), "task_id": current_task_id, **data}
return message_delimiter + json.dumps(ordered) + message_delimiter

View file

@ -1,6 +1,7 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any
import contextvars
from typing import TYPE_CHECKING
from lsprotocol.types import LogMessageParams, MessageType
from pygls.lsp.server import LanguageServer
@ -18,12 +19,16 @@ class CodeflashLanguageServerProtocol(LanguageServerProtocol):
class CodeflashLanguageServer(LanguageServer):
def __init__(self, *args: Any, **kwargs: Any) -> None: # noqa: ANN401
super().__init__(*args, **kwargs)
def __init__(self, name: str, version: str, protocol_cls: type[LanguageServerProtocol]) -> None:
super().__init__(name, version, protocol_cls=protocol_cls)
self.initialized: bool = False
self.optimizer: Optimizer | None = None
self.args_processed_before: bool = False
self.args = None
self.current_optimization_init_result: tuple[bool, CodeOptimizationContext, dict[Path, str]] | None = None
self.execution_context_vars: contextvars.ContextVar[dict[str, str]] = contextvars.ContextVar(
"execution_context_vars",
default={}, # noqa: B039
)
def prepare_optimizer_arguments(self, config_file: Path) -> None:
from codeflash.cli_cmds.cli import parse_args
@ -58,10 +63,10 @@ class CodeflashLanguageServer(LanguageServer):
log_params = LogMessageParams(type=lsp_message_type, message=message)
self.protocol.notify("window/logMessage", log_params)
def cleanup_the_optimizer(self) -> None:
def cleanup_the_optimizer(self) -> bool:
self.current_optimization_init_result = None
if not self.optimizer:
return
return False
try:
self.optimizer.cleanup_temporary_paths()
# restore args and test cfg
@ -72,6 +77,8 @@ class CodeflashLanguageServer(LanguageServer):
self.optimizer.current_function_optimizer = None
except Exception:
self.show_message_log("Failed to cleanup optimizer", "Error")
return False
return True
def shutdown(self) -> None:
"""Gracefully shutdown the server."""

View file

@ -30,10 +30,10 @@ from codeflash.code_utils.code_replacer import (
replace_function_definitions_in_module,
)
from codeflash.code_utils.code_utils import (
ImportErrorPattern,
cleanup_paths,
create_rank_dictionary_compact,
diff_length,
extract_unique_errors,
file_name_from_test_module_name,
get_run_tmp_file,
module_name_from_file_path,
@ -66,7 +66,7 @@ from codeflash.context.unused_definition_remover import detect_unused_helper_fun
from codeflash.discovery.functions_to_optimize import was_function_previously_optimized
from codeflash.either import Failure, Success, is_successful
from codeflash.lsp.helpers import is_LSP_enabled, report_to_markdown_table, tree_to_markdown
from codeflash.lsp.lsp_message import LspCodeMessage, LspMarkdownMessage
from codeflash.lsp.lsp_message import LspCodeMessage, LspMarkdownMessage, LSPMessageId
from codeflash.models.ExperimentMetadata import ExperimentMetadata
from codeflash.models.models import (
BestOptimization,
@ -532,7 +532,11 @@ class FunctionOptimizer:
get_run_tmp_file(Path(f"test_return_values_{candidate_index}.bin")).unlink(missing_ok=True)
get_run_tmp_file(Path(f"test_return_values_{candidate_index}.sqlite")).unlink(missing_ok=True)
logger.info(f"h3|Optimization candidate {candidate_index}/{processor.candidate_len}:")
code_print(candidate.source_code.flat, file_name=f"candidate_{candidate_index}.py")
code_print(
candidate.source_code.flat,
file_name=f"candidate_{candidate_index}.py",
lsp_message_id=LSPMessageId.CANDIDATE.value,
)
# map ast normalized code to diff len, unnormalized code
# map opt id to the shortest unnormalized code
try:
@ -1344,6 +1348,7 @@ class FunctionOptimizer:
best_optimization.candidate.source_code.flat,
file_name="best_candidate.py",
function_name=self.function_to_optimize.function_name,
lsp_message_id=LSPMessageId.BEST_CANDIDATE.value,
)
processed_benchmark_info = None
if self.args.benchmark:
@ -1629,17 +1634,20 @@ class FunctionOptimizer:
)
if not behavioral_results:
logger.warning(
f"force_lsp|Couldn't run any tests for original function {self.function_to_optimize.function_name}. SKIPPING OPTIMIZING THIS FUNCTION."
f"force_lsp|Couldn't run any tests for original function {self.function_to_optimize.function_name}. Skipping optimization."
)
console.rule()
return Failure("Failed to establish a baseline for the original code - bevhavioral tests failed.")
if not coverage_critic(coverage_results, self.args.test_framework):
did_pass_all_tests = all(result.did_pass for result in behavioral_results)
if not did_pass_all_tests:
return Failure("Tests failed to pass for the original code.")
return Failure(
f"Test coverage is {coverage_results.coverage}%, which is below the required threshold of {COVERAGE_THRESHOLD}%."
)
if test_framework == "pytest":
with progress_bar("Running line profiling to identify performance bottlenecks..."):
with progress_bar("Running line profiler to identify performance bottlenecks..."):
line_profile_results = self.line_profiler_step(
code_context=code_context, original_helper_code=original_helper_code, candidate_index=0
)
@ -1939,7 +1947,7 @@ class FunctionOptimizer:
*,
enable_coverage: bool = False,
pytest_min_loops: int = 5,
pytest_max_loops: int = 100_000,
pytest_max_loops: int = 250,
code_context: CodeOptimizationContext | None = None,
unittest_loop_index: int | None = None,
line_profiler_output_file: Path | None = None,
@ -1997,12 +2005,19 @@ class FunctionOptimizer:
f"stdout: {run_result.stdout}\n"
f"stderr: {run_result.stderr}\n"
)
if "ModuleNotFoundError" in run_result.stdout:
unique_errors = extract_unique_errors(run_result.stdout)
if unique_errors:
from rich.text import Text
match = ImportErrorPattern.search(run_result.stdout).group()
panel = Panel(Text.from_markup(f"⚠️ {match} ", style="bold red"), expand=False)
console.print(panel)
for error in unique_errors:
if is_LSP_enabled():
lsp_log(LspCodeMessage(code=error, file_name="errors"))
else:
panel = Panel(Text.from_markup(f"⚠️ {error} ", style="bold red"), expand=False)
console.print(panel)
if testing_type in {TestingMode.BEHAVIOR, TestingMode.PERFORMANCE}:
results, coverage_results = parse_test_results(
test_xml_path=result_file_path,
@ -2101,6 +2116,13 @@ class FunctionOptimizer:
self.write_code_and_helpers(
self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
)
# this will happen when a timeoutexpired exception happens
if isinstance(line_profile_results, TestResults) and not line_profile_results.test_results:
logger.warning(
f"Timeout occurred while running line profiler for original function {self.function_to_optimize.function_name}"
)
# set default value for line profiler results
return {"timings": {}, "unit": 0, "str_out": ""}
if line_profile_results["str_out"] == "":
logger.warning(
f"Couldn't run line profiler for original function {self.function_to_optimize.function_name}"

View file

@ -98,8 +98,8 @@ def _apply_deterministic_patches() -> None:
_original_random = random.random
# Fixed deterministic values
fixed_timestamp = 1609459200.0 # 2021-01-01 00:00:00 UTC
fixed_datetime = datetime.datetime(2021, 1, 1, 0, 0, 0, tzinfo=datetime.timezone.utc)
fixed_timestamp = 1761717605.108106
fixed_datetime = datetime.datetime(2021, 1, 1, 2, 5, 10, tzinfo=datetime.timezone.utc)
fixed_uuid = uuid.UUID("12345678-1234-5678-9abc-123456789012")
# Counter for perf_counter to maintain relative timing

View file

@ -48,8 +48,8 @@ class TestDeterministicPatches:
original_os_urandom = os.urandom
# Create deterministic implementations (matching pytest_plugin.py)
fixed_timestamp = 1609459200.0 # 2021-01-01 00:00:00 UTC
fixed_datetime = datetime.datetime(2021, 1, 1, 0, 0, 0, tzinfo=datetime.timezone.utc)
fixed_timestamp = 1761717605.108106
fixed_datetime = datetime.datetime(2021, 1, 1, 2, 5, 10, tzinfo=datetime.timezone.utc)
fixed_uuid = uuid.UUID("12345678-1234-5678-9abc-123456789012")
# Counter for perf_counter
@ -159,7 +159,7 @@ class TestDeterministicPatches:
def test_time_time_deterministic(self, setup_deterministic_environment):
"""Test that time.time() returns a fixed deterministic value."""
expected_timestamp = 1609459200.0 # 2021-01-01 00:00:00 UTC
expected_timestamp = 1761717605.108106
# Call multiple times and verify consistent results
result1 = time.time()
@ -311,7 +311,7 @@ class TestDeterministicPatches:
result1 = mock_now()
result2 = mock_utcnow()
expected_dt = datetime.datetime(2021, 1, 1, 0, 0, 0, tzinfo=datetime.timezone.utc)
expected_dt = datetime.datetime(2021, 1, 1, 2, 5, 10, tzinfo=datetime.timezone.utc)
assert result1 == expected_dt
assert result2 == expected_dt
@ -355,7 +355,7 @@ class TestDeterministicPatches:
def test_patches_applied_correctly(self, setup_deterministic_environment):
"""Test that patches are applied correctly."""
# Test that functions return expected deterministic values
assert time.time() == 1609459200.0
assert time.time() == 1761717605.108106
assert uuid.uuid4() == uuid.UUID("12345678-1234-5678-9abc-123456789012")
assert random.random() == 0.123456789
assert os.urandom(4) == b"\x42\x42\x42\x42"
@ -378,7 +378,7 @@ class TestDeterministicPatches:
# Test with different timezone
utc_tz = datetime.timezone.utc
result_with_tz = mock_now(utc_tz)
expected_with_tz = datetime.datetime(2021, 1, 1, 0, 0, 0, tzinfo=utc_tz)
expected_with_tz = datetime.datetime(2021, 1, 1, 2, 5, 10, tzinfo=datetime.timezone.utc)
assert result_with_tz == expected_with_tz
def test_integration_with_actual_optimization_scenario(self, setup_deterministic_environment):