Merge branch 'main' into feat/hypothesis-tests
Some checks are pending
CodeFlash / Optimize new Python code (pull_request) Waiting to run
E2E - Async / async-optimization (pull_request) Waiting to run
E2E - Bubble Sort Benchmark / benchmark-bubble-sort-optimization (pull_request) Waiting to run
E2E - Bubble Sort Pytest (No Git) / bubble-sort-optimization-pytest-no-git (pull_request) Waiting to run
E2E - Bubble Sort Unittest / bubble-sort-optimization-unittest (pull_request) Waiting to run
Coverage E2E / end-to-end-test-coverage (pull_request) Waiting to run
E2E - Futurehouse Structure / futurehouse-structure (pull_request) Waiting to run
PR Labeler / label-workflow-changes (pull_request) Waiting to run
/ Run pr agent on every pull request, respond to user comments (pull_request) Waiting to run
Lint / Run pre-commit hooks (pull_request) Waiting to run
unit-tests / unit-tests (3.10) (pull_request) Waiting to run
unit-tests / unit-tests (3.11) (pull_request) Waiting to run
unit-tests / unit-tests (3.12) (pull_request) Waiting to run
unit-tests / unit-tests (3.13) (pull_request) Waiting to run
unit-tests / unit-tests (3.14) (pull_request) Waiting to run
E2E - Init Optimization / init-optimization (pull_request) Waiting to run
E2E - Topological Sort (Worktree) / topological-sort-worktree-optimization (pull_request) Waiting to run
E2E - Tracer Replay / tracer-replay (pull_request) Waiting to run
Mypy Type Checking for CLI / type-check-cli (pull_request) Waiting to run
unit-tests / unit-tests (3.9) (pull_request) Waiting to run
windows-unit-tests / windows-unit-tests (pull_request) Waiting to run
Some checks are pending
CodeFlash / Optimize new Python code (pull_request) Waiting to run
E2E - Async / async-optimization (pull_request) Waiting to run
E2E - Bubble Sort Benchmark / benchmark-bubble-sort-optimization (pull_request) Waiting to run
E2E - Bubble Sort Pytest (No Git) / bubble-sort-optimization-pytest-no-git (pull_request) Waiting to run
E2E - Bubble Sort Unittest / bubble-sort-optimization-unittest (pull_request) Waiting to run
Coverage E2E / end-to-end-test-coverage (pull_request) Waiting to run
E2E - Futurehouse Structure / futurehouse-structure (pull_request) Waiting to run
PR Labeler / label-workflow-changes (pull_request) Waiting to run
/ Run pr agent on every pull request, respond to user comments (pull_request) Waiting to run
Lint / Run pre-commit hooks (pull_request) Waiting to run
unit-tests / unit-tests (3.10) (pull_request) Waiting to run
unit-tests / unit-tests (3.11) (pull_request) Waiting to run
unit-tests / unit-tests (3.12) (pull_request) Waiting to run
unit-tests / unit-tests (3.13) (pull_request) Waiting to run
unit-tests / unit-tests (3.14) (pull_request) Waiting to run
E2E - Init Optimization / init-optimization (pull_request) Waiting to run
E2E - Topological Sort (Worktree) / topological-sort-worktree-optimization (pull_request) Waiting to run
E2E - Tracer Replay / tracer-replay (pull_request) Waiting to run
Mypy Type Checking for CLI / type-check-cli (pull_request) Waiting to run
unit-tests / unit-tests (3.9) (pull_request) Waiting to run
windows-unit-tests / windows-unit-tests (pull_request) Waiting to run
This commit is contained in:
commit
ba04f88496
11 changed files with 275 additions and 155 deletions
|
|
@ -52,10 +52,11 @@ def pytest_addoption(parser: pytest.Parser) -> None:
|
|||
parser.addoption(
|
||||
"--codeflash-trace", action="store_true", default=False, help="Enable CodeFlash tracing for benchmarks"
|
||||
)
|
||||
# These options are ignored when --codeflash-trace is used
|
||||
for option, action, default, help_text in benchmark_options:
|
||||
help_suffix = " (ignored when --codeflash-trace is used)"
|
||||
parser.addoption(option, action=action, default=default, help=help_text + help_suffix)
|
||||
# Only add benchmark options if pytest-benchmark is not installed for backward compatibility with existing pytest-benchmark setup
|
||||
if not PYTEST_BENCHMARK_INSTALLED:
|
||||
for option, action, default, help_text in benchmark_options:
|
||||
help_suffix = " (ignored when --codeflash-trace is used)"
|
||||
parser.addoption(option, action=action, default=default, help=help_text + help_suffix)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
|
|||
|
|
@ -80,9 +80,16 @@ def paneled_text(
|
|||
console.print(panel)
|
||||
|
||||
|
||||
def code_print(code_str: str, file_name: Optional[str] = None, function_name: Optional[str] = None) -> None:
|
||||
def code_print(
|
||||
code_str: str,
|
||||
file_name: Optional[str] = None,
|
||||
function_name: Optional[str] = None,
|
||||
lsp_message_id: Optional[str] = None,
|
||||
) -> None:
|
||||
if is_LSP_enabled():
|
||||
lsp_log(LspCodeMessage(code=code_str, file_name=file_name, function_name=function_name))
|
||||
lsp_log(
|
||||
LspCodeMessage(code=code_str, file_name=file_name, function_name=function_name, message_id=lsp_message_id)
|
||||
)
|
||||
return
|
||||
"""Print code with syntax highlighting."""
|
||||
from rich.syntax import Syntax
|
||||
|
|
|
|||
|
|
@ -360,3 +360,19 @@ def exit_with_message(message: str, *, error_on_exit: bool = False) -> None:
|
|||
paneled_text(message, panel_args={"style": "red"})
|
||||
|
||||
sys.exit(1 if error_on_exit else 0)
|
||||
|
||||
|
||||
def extract_unique_errors(pytest_output: str) -> set[str]:
    """Collect the distinct error messages reported in pytest output.

    A line counts as an error line when it starts with ``E`` followed by
    whitespace; the remainder of the line, stripped of surrounding
    whitespace, is the message. Empty messages are dropped.

    Args:
        pytest_output: Raw text captured from a pytest run.

    Returns:
        The set of non-empty, stripped error messages.
    """
    # 'E', at least one whitespace character, then the message up to end of line.
    error_line = re.compile(r"^E\s+(.*)$", re.MULTILINE)
    stripped_messages = (raw.strip() for raw in error_line.findall(pytest_output))
    return {message for message in stripped_messages if message}
|
||||
|
|
|
|||
|
|
@ -2,7 +2,9 @@ from __future__ import annotations
|
|||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import contextvars
|
||||
import os
|
||||
import threading
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
|
@ -27,8 +29,8 @@ from codeflash.discovery.functions_to_optimize import (
|
|||
get_functions_within_git_diff,
|
||||
)
|
||||
from codeflash.either import is_successful
|
||||
from codeflash.lsp.features.perform_optimization import sync_perform_optimization
|
||||
from codeflash.lsp.server import CodeflashLanguageServer
|
||||
from codeflash.lsp.features.perform_optimization import get_cancelled_reponse, sync_perform_optimization
|
||||
from codeflash.lsp.server import CodeflashLanguageServer, CodeflashLanguageServerProtocol
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from argparse import Namespace
|
||||
|
|
@ -47,6 +49,7 @@ class OptimizableFunctionsParams:
|
|||
class FunctionOptimizationInitParams:
|
||||
textDocument: types.TextDocumentIdentifier # noqa: N815
|
||||
functionName: str # noqa: N815
|
||||
task_id: str
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -84,30 +87,24 @@ class WriteConfigParams:
|
|||
config: any
|
||||
|
||||
|
||||
server = CodeflashLanguageServer("codeflash-language-server", "v1.0")
|
||||
server = CodeflashLanguageServer("codeflash-language-server", "v1.0", protocol_cls=CodeflashLanguageServerProtocol)
|
||||
|
||||
|
||||
@server.feature("getOptimizableFunctionsInCurrentDiff")
|
||||
def get_functions_in_current_git_diff(
|
||||
server: CodeflashLanguageServer, _params: OptimizableFunctionsParams
|
||||
) -> dict[str, str | dict[str, list[str]]]:
|
||||
def get_functions_in_current_git_diff(_params: OptimizableFunctionsParams) -> dict[str, str | dict[str, list[str]]]:
|
||||
functions = get_functions_within_git_diff(uncommitted_changes=True)
|
||||
file_to_qualified_names = _group_functions_by_file(server, functions)
|
||||
file_to_qualified_names = _group_functions_by_file(functions)
|
||||
return {"functions": file_to_qualified_names, "status": "success"}
|
||||
|
||||
|
||||
@server.feature("getOptimizableFunctionsInCommit")
|
||||
def get_functions_in_commit(
|
||||
server: CodeflashLanguageServer, params: OptimizableFunctionsInCommitParams
|
||||
) -> dict[str, str | dict[str, list[str]]]:
|
||||
def get_functions_in_commit(params: OptimizableFunctionsInCommitParams) -> dict[str, str | dict[str, list[str]]]:
|
||||
functions = get_functions_inside_a_commit(params.commit_hash)
|
||||
file_to_qualified_names = _group_functions_by_file(server, functions)
|
||||
file_to_qualified_names = _group_functions_by_file(functions)
|
||||
return {"functions": file_to_qualified_names, "status": "success"}
|
||||
|
||||
|
||||
def _group_functions_by_file(
|
||||
server: CodeflashLanguageServer, functions: dict[str, list[FunctionToOptimize]]
|
||||
) -> dict[str, list[str]]:
|
||||
def _group_functions_by_file(functions: dict[str, list[FunctionToOptimize]]) -> dict[str, list[str]]:
|
||||
file_to_funcs_to_optimize, _ = filter_functions(
|
||||
modified_functions=functions,
|
||||
tests_root=server.optimizer.test_cfg.tests_root,
|
||||
|
|
@ -123,9 +120,7 @@ def _group_functions_by_file(
|
|||
|
||||
|
||||
@server.feature("getOptimizableFunctions")
|
||||
def get_optimizable_functions(
|
||||
server: CodeflashLanguageServer, params: OptimizableFunctionsParams
|
||||
) -> dict[str, list[str]]:
|
||||
def get_optimizable_functions(params: OptimizableFunctionsParams) -> dict[str, list[str]]:
|
||||
document_uri = params.textDocument.uri
|
||||
document = server.workspace.get_text_document(document_uri)
|
||||
|
||||
|
|
@ -172,7 +167,7 @@ def _find_pyproject_toml(workspace_path: str) -> tuple[Path | None, bool]:
|
|||
|
||||
|
||||
@server.feature("writeConfig")
|
||||
def write_config(_server: CodeflashLanguageServer, params: WriteConfigParams) -> dict[str, any]:
|
||||
def write_config(params: WriteConfigParams) -> dict[str, any]:
|
||||
cfg = params.config
|
||||
cfg_file = Path(params.config_file) if params.config_file else None
|
||||
|
||||
|
|
@ -196,7 +191,7 @@ def write_config(_server: CodeflashLanguageServer, params: WriteConfigParams) ->
|
|||
|
||||
|
||||
@server.feature("getConfigSuggestions")
|
||||
def get_config_suggestions(_server: CodeflashLanguageServer, _params: any) -> dict[str, any]:
|
||||
def get_config_suggestions(_params: any) -> dict[str, any]:
|
||||
module_root_suggestions, default_module_root = get_suggestions(CommonSections.module_root)
|
||||
tests_root_suggestions, default_tests_root = get_suggestions(CommonSections.tests_root)
|
||||
test_framework_suggestions, default_test_framework = get_suggestions(CommonSections.test_framework)
|
||||
|
|
@ -212,9 +207,9 @@ def get_config_suggestions(_server: CodeflashLanguageServer, _params: any) -> di
|
|||
|
||||
# should be called the first thing to initialize and validate the project
|
||||
@server.feature("initProject")
|
||||
def init_project(server: CodeflashLanguageServer, params: ValidateProjectParams) -> dict[str, str]:
|
||||
def init_project(params: ValidateProjectParams) -> dict[str, str]:
|
||||
# Always process args in the init project, the extension can call
|
||||
server.args_processed_before = False
|
||||
server.initialized = False
|
||||
|
||||
pyproject_toml_path: Path | None = getattr(params, "config_file", None) or getattr(server.args, "config_file", None)
|
||||
if pyproject_toml_path is not None:
|
||||
|
|
@ -255,19 +250,16 @@ def init_project(server: CodeflashLanguageServer, params: ValidateProjectParams)
|
|||
"existingConfig": config,
|
||||
}
|
||||
|
||||
args = process_args(server)
|
||||
|
||||
args = _init()
|
||||
return {"status": "success", "moduleRoot": args.module_root, "pyprojectPath": pyproject_toml_path, "root": root}
|
||||
|
||||
|
||||
def _initialize_optimizer_if_api_key_is_valid(
|
||||
server: CodeflashLanguageServer, api_key: Optional[str] = None
|
||||
) -> dict[str, str]:
|
||||
def _initialize_optimizer_if_api_key_is_valid(api_key: Optional[str] = None) -> dict[str, str]:
|
||||
key_check_result = _check_api_key_validity(api_key)
|
||||
if key_check_result.get("status") != "success":
|
||||
return key_check_result
|
||||
|
||||
_initialize_optimizer(server)
|
||||
_init()
|
||||
return key_check_result
|
||||
|
||||
|
||||
|
|
@ -283,48 +275,56 @@ def _check_api_key_validity(api_key: Optional[str]) -> dict[str, str]:
|
|||
return {"status": "success", "user_id": user_id}
|
||||
|
||||
|
||||
def _initialize_optimizer(server: CodeflashLanguageServer) -> None:
|
||||
def _initialize_optimizer(args: Namespace) -> None:
|
||||
from codeflash.optimization.optimizer import Optimizer
|
||||
|
||||
new_args = process_args(server)
|
||||
if not server.optimizer:
|
||||
server.optimizer = Optimizer(new_args)
|
||||
server.optimizer = Optimizer(args)
|
||||
|
||||
|
||||
def process_args(server: CodeflashLanguageServer) -> Namespace:
|
||||
if server.args_processed_before:
|
||||
return server.args
|
||||
def process_args() -> Namespace:
|
||||
new_args = process_pyproject_config(server.args)
|
||||
server.args = new_args
|
||||
server.args_processed_before = True
|
||||
return new_args
|
||||
|
||||
|
||||
def _init() -> Namespace:
    """Process the arguments and set up the optimizer, at most once per initialization.

    When ``server.initialized`` is already set, the stored ``server.args`` are
    returned unchanged. Otherwise the arguments are processed, passed to
    ``_initialize_optimizer``, and the server is flagged as initialized.

    Returns:
        The processed arguments on first use, ``server.args`` afterwards.
    """
    if not server.initialized:
        processed_args = process_args()
        _initialize_optimizer(processed_args)
        server.initialized = True
        return processed_args
    return server.args
|
||||
|
||||
|
||||
@server.feature("apiKeyExistsAndValid")
|
||||
def check_api_key(server: CodeflashLanguageServer, _params: any) -> dict[str, str]:
|
||||
def check_api_key(_params: any) -> dict[str, str]:
|
||||
try:
|
||||
return _initialize_optimizer_if_api_key_is_valid(server)
|
||||
return _initialize_optimizer_if_api_key_is_valid()
|
||||
except Exception:
|
||||
return {"status": "error", "message": "something went wrong while validating the api key"}
|
||||
|
||||
|
||||
@server.feature("provideApiKey")
|
||||
def provide_api_key(server: CodeflashLanguageServer, params: ProvideApiKeyParams) -> dict[str, str]:
|
||||
def provide_api_key(params: ProvideApiKeyParams) -> dict[str, str]:
|
||||
try:
|
||||
api_key = params.api_key
|
||||
if not api_key.startswith("cf-"):
|
||||
return {"status": "error", "message": "Api key is not valid"}
|
||||
|
||||
# # clear cache to ensure the new api key is used
|
||||
# clear cache to ensure the new api key is used
|
||||
get_codeflash_api_key.cache_clear()
|
||||
get_user_id.cache_clear()
|
||||
|
||||
key_check_result = _check_api_key_validity(api_key)
|
||||
if key_check_result.get("status") != "success":
|
||||
return key_check_result
|
||||
|
||||
user_id = key_check_result["user_id"]
|
||||
result = save_api_key_to_rc(api_key)
|
||||
|
||||
# initialize optimizer with the new api key
|
||||
_initialize_optimizer(server)
|
||||
_init()
|
||||
if not is_successful(result):
|
||||
return {"status": "error", "message": result.failure()}
|
||||
return {"status": "success", "message": "Api key saved successfully", "user_id": user_id} # noqa: TRY300
|
||||
|
|
@ -332,85 +332,106 @@ def provide_api_key(server: CodeflashLanguageServer, params: ProvideApiKeyParams
|
|||
return {"status": "error", "message": "something went wrong while saving the api key"}
|
||||
|
||||
|
||||
@contextlib.contextmanager
def execution_context(**kwargs: str) -> None:
    """Overlay ``kwargs`` on the server's execution context for the duration of a ``with`` block.

    The current value of ``server.execution_context_vars`` is copied, updated
    with ``kwargs``, and installed; on exit the previous value is restored via
    the token returned by ``set``.
    """
    context_var = server.execution_context_vars
    # Build a new dict rather than mutating the one currently stored.
    merged = dict(context_var.get())
    merged.update(kwargs)
    reset_token = context_var.set(merged)
    try:
        yield
    finally:
        context_var.reset(reset_token)
|
||||
|
||||
|
||||
@server.feature("cleanupCurrentOptimizerSession")
def cleanup_optimizer(_params: any) -> dict[str, str]:
    """Ask the server to clean up the optimizer and report the outcome as a status dict."""
    cleaned_up = server.cleanup_the_optimizer()
    if cleaned_up:
        return {"status": "success"}
    return {"status": "error", "message": "Failed to cleanup optimizer"}
|
||||
|
||||
|
||||
@server.feature("initializeFunctionOptimization")
|
||||
def initialize_function_optimization(
|
||||
server: CodeflashLanguageServer, params: FunctionOptimizationInitParams
|
||||
) -> dict[str, str]:
|
||||
document_uri = params.textDocument.uri
|
||||
document = server.workspace.get_text_document(document_uri)
|
||||
file_path = Path(document.path)
|
||||
def initialize_function_optimization(params: FunctionOptimizationInitParams) -> dict[str, str]:
|
||||
with execution_context(task_id=getattr(params, "task_id", None)):
|
||||
document_uri = params.textDocument.uri
|
||||
document = server.workspace.get_text_document(document_uri)
|
||||
file_path = Path(document.path)
|
||||
|
||||
server.show_message_log(f"Initializing optimization for function: {params.functionName} in {document_uri}", "Info")
|
||||
server.show_message_log(
|
||||
f"Initializing optimization for function: {params.functionName} in {document_uri}", "Info"
|
||||
)
|
||||
|
||||
if server.optimizer is None:
|
||||
_initialize_optimizer_if_api_key_is_valid(server)
|
||||
if server.optimizer is None:
|
||||
_initialize_optimizer_if_api_key_is_valid()
|
||||
|
||||
server.optimizer.args.file = file_path
|
||||
server.optimizer.args.function = params.functionName
|
||||
server.optimizer.args.previous_checkpoint_functions = False
|
||||
server.optimizer.args.file = file_path
|
||||
server.optimizer.args.function = params.functionName
|
||||
server.optimizer.args.previous_checkpoint_functions = False
|
||||
|
||||
server.optimizer.worktree_mode()
|
||||
server.optimizer.worktree_mode()
|
||||
|
||||
server.show_message_log(
|
||||
f"Args set - function: {server.optimizer.args.function}, file: {server.optimizer.args.file}", "Info"
|
||||
)
|
||||
server.show_message_log(
|
||||
f"Args set - function: {server.optimizer.args.function}, file: {server.optimizer.args.file}", "Info"
|
||||
)
|
||||
|
||||
optimizable_funcs, count, _ = server.optimizer.get_optimizable_functions()
|
||||
optimizable_funcs, count, _ = server.optimizer.get_optimizable_functions()
|
||||
|
||||
if count == 0:
|
||||
server.show_message_log(f"No optimizable functions found for {params.functionName}", "Warning")
|
||||
server.cleanup_the_optimizer()
|
||||
return {"functionName": params.functionName, "status": "error", "message": "not found", "args": None}
|
||||
if count == 0:
|
||||
server.show_message_log(f"No optimizable functions found for {params.functionName}", "Warning")
|
||||
server.cleanup_the_optimizer()
|
||||
return {"functionName": params.functionName, "status": "error", "message": "not found", "args": None}
|
||||
|
||||
fto = optimizable_funcs.popitem()[1][0]
|
||||
fto = optimizable_funcs.popitem()[1][0]
|
||||
|
||||
module_prep_result = server.optimizer.prepare_module_for_optimization(fto.file_path)
|
||||
if not module_prep_result:
|
||||
return {
|
||||
"functionName": params.functionName,
|
||||
"status": "error",
|
||||
"message": "Failed to prepare module for optimization",
|
||||
}
|
||||
module_prep_result = server.optimizer.prepare_module_for_optimization(fto.file_path)
|
||||
if not module_prep_result:
|
||||
return {
|
||||
"functionName": params.functionName,
|
||||
"status": "error",
|
||||
"message": "Failed to prepare module for optimization",
|
||||
}
|
||||
|
||||
validated_original_code, original_module_ast = module_prep_result
|
||||
validated_original_code, original_module_ast = module_prep_result
|
||||
|
||||
function_optimizer = server.optimizer.create_function_optimizer(
|
||||
fto,
|
||||
function_to_optimize_source_code=validated_original_code[fto.file_path].source_code,
|
||||
original_module_ast=original_module_ast,
|
||||
original_module_path=fto.file_path,
|
||||
function_to_tests={},
|
||||
)
|
||||
function_optimizer = server.optimizer.create_function_optimizer(
|
||||
fto,
|
||||
function_to_optimize_source_code=validated_original_code[fto.file_path].source_code,
|
||||
original_module_ast=original_module_ast,
|
||||
original_module_path=fto.file_path,
|
||||
function_to_tests={},
|
||||
)
|
||||
|
||||
server.optimizer.current_function_optimizer = function_optimizer
|
||||
if not function_optimizer:
|
||||
return {"functionName": params.functionName, "status": "error", "message": "No function optimizer found"}
|
||||
server.optimizer.current_function_optimizer = function_optimizer
|
||||
if not function_optimizer:
|
||||
return {"functionName": params.functionName, "status": "error", "message": "No function optimizer found"}
|
||||
|
||||
initialization_result = function_optimizer.can_be_optimized()
|
||||
if not is_successful(initialization_result):
|
||||
return {"functionName": params.functionName, "status": "error", "message": initialization_result.failure()}
|
||||
initialization_result = function_optimizer.can_be_optimized()
|
||||
if not is_successful(initialization_result):
|
||||
return {"functionName": params.functionName, "status": "error", "message": initialization_result.failure()}
|
||||
|
||||
server.current_optimization_init_result = initialization_result.unwrap()
|
||||
server.show_message_log(f"Successfully initialized optimization for {params.functionName}", "Info")
|
||||
server.current_optimization_init_result = initialization_result.unwrap()
|
||||
server.show_message_log(f"Successfully initialized optimization for {params.functionName}", "Info")
|
||||
|
||||
files = [function_optimizer.function_to_optimize.file_path]
|
||||
files = [function_optimizer.function_to_optimize.file_path]
|
||||
|
||||
_, _, original_helpers = server.current_optimization_init_result
|
||||
files.extend([str(helper_path) for helper_path in original_helpers])
|
||||
_, _, original_helpers = server.current_optimization_init_result
|
||||
files.extend([str(helper_path) for helper_path in original_helpers])
|
||||
|
||||
return {"functionName": params.functionName, "status": "success", "files_inside_context": files}
|
||||
return {"functionName": params.functionName, "status": "success", "files_inside_context": files}
|
||||
|
||||
|
||||
@server.feature("performFunctionOptimization")
|
||||
async def perform_function_optimization(
|
||||
server: CodeflashLanguageServer, params: FunctionOptimizationParams
|
||||
) -> dict[str, str]:
|
||||
loop = asyncio.get_running_loop()
|
||||
try:
|
||||
result = await loop.run_in_executor(None, sync_perform_optimization, server, params)
|
||||
except asyncio.CancelledError:
|
||||
return {"status": "canceled", "message": "Task was canceled"}
|
||||
else:
|
||||
return result
|
||||
finally:
|
||||
server.cleanup_the_optimizer()
|
||||
async def perform_function_optimization(params: FunctionOptimizationParams) -> dict[str, str]:
|
||||
with execution_context(task_id=getattr(params, "task_id", None)):
|
||||
loop = asyncio.get_running_loop()
|
||||
cancel_event = threading.Event()
|
||||
|
||||
try:
|
||||
ctx = contextvars.copy_context()
|
||||
return await loop.run_in_executor(None, ctx.run, sync_perform_optimization, server, cancel_event, params)
|
||||
except asyncio.CancelledError:
|
||||
cancel_event.set()
|
||||
return get_cancelled_reponse()
|
||||
finally:
|
||||
server.cleanup_the_optimizer()
|
||||
|
|
|
|||
|
|
@ -1,13 +1,29 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import os
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from codeflash.cli_cmds.console import code_print
|
||||
from codeflash.code_utils.git_worktree_utils import create_diff_patch_from_worktree
|
||||
from codeflash.either import is_successful
|
||||
from codeflash.lsp.server import CodeflashLanguageServer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import threading
|
||||
|
||||
from codeflash.lsp.server import CodeflashLanguageServer
|
||||
|
||||
|
||||
def sync_perform_optimization(server: CodeflashLanguageServer, params) -> dict[str, str]: # noqa: ANN001
|
||||
def get_cancelled_reponse() -> dict[str, str]:
    """Build the response payload returned when a task was canceled."""
    response = {"status": "canceled"}
    response["message"] = "Task was canceled"
    return response
|
||||
|
||||
|
||||
def abort_if_cancelled(cancel_event: threading.Event) -> None:
    """Raise if cancellation has been requested.

    Args:
        cancel_event: Event that is set when the running task should stop.

    Raises:
        RuntimeError: With the message ``"cancelled"`` when the event is set.
    """
    if not cancel_event.is_set():
        return
    raise RuntimeError("cancelled")
|
||||
|
||||
|
||||
def sync_perform_optimization(server: CodeflashLanguageServer, cancel_event: threading.Event, params) -> dict[str, str]: # noqa
|
||||
server.show_message_log(f"Starting optimization for function: {params.functionName}", "Info")
|
||||
should_run_experiment, code_context, original_helper_code = server.current_optimization_init_result
|
||||
function_optimizer = server.optimizer.current_function_optimizer
|
||||
|
|
@ -18,6 +34,7 @@ def sync_perform_optimization(server: CodeflashLanguageServer, params) -> dict[s
|
|||
file_name=current_function.file_path,
|
||||
function_name=current_function.function_name,
|
||||
)
|
||||
abort_if_cancelled(cancel_event)
|
||||
|
||||
optimizable_funcs = {current_function.file_path: [current_function]}
|
||||
|
||||
|
|
@ -26,9 +43,11 @@ def sync_perform_optimization(server: CodeflashLanguageServer, params) -> dict[s
|
|||
function_to_tests, _num_discovered_tests = server.optimizer.discover_tests(optimizable_funcs)
|
||||
function_optimizer.function_to_tests = function_to_tests
|
||||
|
||||
abort_if_cancelled(cancel_event)
|
||||
test_setup_result = function_optimizer.generate_and_instrument_tests(
|
||||
code_context, should_run_experiment=should_run_experiment
|
||||
)
|
||||
abort_if_cancelled(cancel_event)
|
||||
if not is_successful(test_setup_result):
|
||||
return {"functionName": params.functionName, "status": "error", "message": test_setup_result.failure()}
|
||||
(
|
||||
|
|
@ -52,6 +71,7 @@ def sync_perform_optimization(server: CodeflashLanguageServer, params) -> dict[s
|
|||
original_conftest_content=original_conftest_content,
|
||||
)
|
||||
|
||||
abort_if_cancelled(cancel_event)
|
||||
if not is_successful(baseline_setup_result):
|
||||
return {"functionName": params.functionName, "status": "error", "message": baseline_setup_result.failure()}
|
||||
|
||||
|
|
@ -76,6 +96,7 @@ def sync_perform_optimization(server: CodeflashLanguageServer, params) -> dict[s
|
|||
concolic_test_str=concolic_test_str,
|
||||
)
|
||||
|
||||
abort_if_cancelled(cancel_event)
|
||||
if not best_optimization:
|
||||
server.show_message_log(
|
||||
f"No best optimizations found for function {function_to_optimize_qualified_name}", "Warning"
|
||||
|
|
@ -93,6 +114,7 @@ def sync_perform_optimization(server: CodeflashLanguageServer, params) -> dict[s
|
|||
server.optimizer.current_worktree, relative_file_paths, function_to_optimize_qualified_name
|
||||
)
|
||||
|
||||
abort_if_cancelled(cancel_event)
|
||||
if not patch_path:
|
||||
return {
|
||||
"functionName": params.functionName,
|
||||
|
|
|
|||
|
|
@ -3,13 +3,15 @@ from __future__ import annotations
|
|||
import logging
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
from codeflash.lsp.helpers import is_LSP_enabled
|
||||
from codeflash.lsp.lsp_message import LspTextMessage, message_delimiter
|
||||
from codeflash.lsp.lsp_message import LSPMessageId, LspTextMessage, message_delimiter
|
||||
|
||||
root_logger = None
|
||||
|
||||
message_id_prefix = "id:"
|
||||
|
||||
|
||||
@dataclass
|
||||
class LspMessageTags:
|
||||
|
|
@ -18,6 +20,7 @@ class LspMessageTags:
|
|||
lsp: bool = False # lsp (lsp only)
|
||||
force_lsp: bool = False # force_lsp (you can use this to force a message to be sent to the LSP even if the level is not supported)
|
||||
loading: bool = False # loading (you can use this to indicate that the message is a loading message)
|
||||
message_id: Optional[LSPMessageId] = None # example: id:best_candidate
|
||||
highlight: bool = False # highlight (you can use this to highlight the message by wrapping it in ``)
|
||||
h1: bool = False # h1
|
||||
h2: bool = False # h2
|
||||
|
|
@ -52,24 +55,27 @@ def extract_tags(msg: str) -> tuple[LspMessageTags, str]:
|
|||
tags = {tag.strip() for tag in tags_str.split(",")}
|
||||
message_tags = LspMessageTags()
|
||||
# manually check and set to avoid repeated membership tests
|
||||
if "lsp" in tags:
|
||||
message_tags.lsp = True
|
||||
if "!lsp" in tags:
|
||||
message_tags.not_lsp = True
|
||||
if "force_lsp" in tags:
|
||||
message_tags.force_lsp = True
|
||||
if "loading" in tags:
|
||||
message_tags.loading = True
|
||||
if "highlight" in tags:
|
||||
message_tags.highlight = True
|
||||
if "h1" in tags:
|
||||
message_tags.h1 = True
|
||||
if "h2" in tags:
|
||||
message_tags.h2 = True
|
||||
if "h3" in tags:
|
||||
message_tags.h3 = True
|
||||
if "h4" in tags:
|
||||
message_tags.h4 = True
|
||||
for tag in tags:
|
||||
if tag.startswith(message_id_prefix):
|
||||
message_tags.message_id = LSPMessageId(tag[len(message_id_prefix) :]).value
|
||||
elif tag == "lsp":
|
||||
message_tags.lsp = True
|
||||
elif tag == "!lsp":
|
||||
message_tags.not_lsp = True
|
||||
elif tag == "force_lsp":
|
||||
message_tags.force_lsp = True
|
||||
elif tag == "loading":
|
||||
message_tags.loading = True
|
||||
elif tag == "highlight":
|
||||
message_tags.highlight = True
|
||||
elif tag == "h1":
|
||||
message_tags.h1 = True
|
||||
elif tag == "h2":
|
||||
message_tags.h2 = True
|
||||
elif tag == "h3":
|
||||
message_tags.h3 = True
|
||||
elif tag == "h4":
|
||||
message_tags.h4 = True
|
||||
return message_tags, content
|
||||
|
||||
return LspMessageTags(), msg
|
||||
|
|
@ -110,11 +116,15 @@ def enhanced_log(
|
|||
actual_log_fn(clean_msg, *args, **kwargs)
|
||||
return
|
||||
|
||||
if not lsp_enabled:
|
||||
# it's for LSP and LSP is disabled
|
||||
return
|
||||
|
||||
# ---- LSP logging path ----
|
||||
if is_normal_text_message:
|
||||
clean_msg = add_heading_tags(clean_msg, tags)
|
||||
clean_msg = add_highlight_tags(clean_msg, tags)
|
||||
clean_msg = LspTextMessage(text=clean_msg, takes_time=tags.loading).serialize()
|
||||
clean_msg = LspTextMessage(text=clean_msg, takes_time=tags.loading, message_id=tags.message_id).serialize()
|
||||
|
||||
actual_log_fn(clean_msg, *args, **kwargs)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,11 +1,12 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import enum
|
||||
import json
|
||||
from dataclasses import asdict, dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
|
||||
from codeflash.lsp.helpers import replace_quotes_with_backticks, simplify_worktree_paths
|
||||
from codeflash.lsp.helpers import is_LSP_enabled, replace_quotes_with_backticks, simplify_worktree_paths
|
||||
|
||||
json_primitive_types = (str, float, int, bool)
|
||||
max_code_lines_before_collapse = 45
|
||||
|
|
@ -14,10 +15,17 @@ max_code_lines_before_collapse = 45
|
|||
message_delimiter = "\u241f"
|
||||
|
||||
|
||||
# allow the client to know which message it is receiving
|
||||
class LSPMessageId(enum.Enum):
    """Identifiers attached to LSP messages so the client can tell which message it is receiving."""

    BEST_CANDIDATE = "best_candidate"
    CANDIDATE = "candidate"
|
||||
|
||||
|
||||
@dataclass
|
||||
class LspMessage:
|
||||
# to show a loading indicator if the operation is taking time like generating candidates or tests
|
||||
takes_time: bool = False
|
||||
message_id: Optional[str] = None
|
||||
|
||||
def _loop_through(self, obj: Any) -> Any: # noqa: ANN401
|
||||
if isinstance(obj, list):
|
||||
|
|
@ -34,8 +42,14 @@ class LspMessage:
|
|||
raise NotImplementedError
|
||||
|
||||
def serialize(self) -> str:
|
||||
if not is_LSP_enabled():
|
||||
return ""
|
||||
from codeflash.lsp.beta import server
|
||||
|
||||
execution_ctx = server.execution_context_vars.get()
|
||||
current_task_id = execution_ctx.get("task_id", None)
|
||||
data = self._loop_through(asdict(self))
|
||||
ordered = {"type": self.type(), **data}
|
||||
ordered = {"type": self.type(), "task_id": current_task_id, **data}
|
||||
return message_delimiter + json.dumps(ordered) + message_delimiter
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
import contextvars
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from lsprotocol.types import LogMessageParams, MessageType
|
||||
from pygls.lsp.server import LanguageServer
|
||||
|
|
@ -18,12 +19,16 @@ class CodeflashLanguageServerProtocol(LanguageServerProtocol):
|
|||
|
||||
|
||||
class CodeflashLanguageServer(LanguageServer):
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None: # noqa: ANN401
|
||||
super().__init__(*args, **kwargs)
|
||||
def __init__(self, name: str, version: str, protocol_cls: type[LanguageServerProtocol]) -> None:
|
||||
super().__init__(name, version, protocol_cls=protocol_cls)
|
||||
self.initialized: bool = False
|
||||
self.optimizer: Optimizer | None = None
|
||||
self.args_processed_before: bool = False
|
||||
self.args = None
|
||||
self.current_optimization_init_result: tuple[bool, CodeOptimizationContext, dict[Path, str]] | None = None
|
||||
self.execution_context_vars: contextvars.ContextVar[dict[str, str]] = contextvars.ContextVar(
|
||||
"execution_context_vars",
|
||||
default={}, # noqa: B039
|
||||
)
|
||||
|
||||
def prepare_optimizer_arguments(self, config_file: Path) -> None:
|
||||
from codeflash.cli_cmds.cli import parse_args
|
||||
|
|
@ -58,10 +63,10 @@ class CodeflashLanguageServer(LanguageServer):
|
|||
log_params = LogMessageParams(type=lsp_message_type, message=message)
|
||||
self.protocol.notify("window/logMessage", log_params)
|
||||
|
||||
def cleanup_the_optimizer(self) -> None:
|
||||
def cleanup_the_optimizer(self) -> bool:
|
||||
self.current_optimization_init_result = None
|
||||
if not self.optimizer:
|
||||
return
|
||||
return False
|
||||
try:
|
||||
self.optimizer.cleanup_temporary_paths()
|
||||
# restore args and test cfg
|
||||
|
|
@ -72,6 +77,8 @@ class CodeflashLanguageServer(LanguageServer):
|
|||
self.optimizer.current_function_optimizer = None
|
||||
except Exception:
|
||||
self.show_message_log("Failed to cleanup optimizer", "Error")
|
||||
return False
|
||||
return True
|
||||
|
||||
def shutdown(self) -> None:
|
||||
"""Gracefully shutdown the server."""
|
||||
|
|
|
|||
|
|
@ -30,10 +30,10 @@ from codeflash.code_utils.code_replacer import (
|
|||
replace_function_definitions_in_module,
|
||||
)
|
||||
from codeflash.code_utils.code_utils import (
|
||||
ImportErrorPattern,
|
||||
cleanup_paths,
|
||||
create_rank_dictionary_compact,
|
||||
diff_length,
|
||||
extract_unique_errors,
|
||||
file_name_from_test_module_name,
|
||||
get_run_tmp_file,
|
||||
module_name_from_file_path,
|
||||
|
|
@ -66,7 +66,7 @@ from codeflash.context.unused_definition_remover import detect_unused_helper_fun
|
|||
from codeflash.discovery.functions_to_optimize import was_function_previously_optimized
|
||||
from codeflash.either import Failure, Success, is_successful
|
||||
from codeflash.lsp.helpers import is_LSP_enabled, report_to_markdown_table, tree_to_markdown
|
||||
from codeflash.lsp.lsp_message import LspCodeMessage, LspMarkdownMessage
|
||||
from codeflash.lsp.lsp_message import LspCodeMessage, LspMarkdownMessage, LSPMessageId
|
||||
from codeflash.models.ExperimentMetadata import ExperimentMetadata
|
||||
from codeflash.models.models import (
|
||||
BestOptimization,
|
||||
|
|
@ -532,7 +532,11 @@ class FunctionOptimizer:
|
|||
get_run_tmp_file(Path(f"test_return_values_{candidate_index}.bin")).unlink(missing_ok=True)
|
||||
get_run_tmp_file(Path(f"test_return_values_{candidate_index}.sqlite")).unlink(missing_ok=True)
|
||||
logger.info(f"h3|Optimization candidate {candidate_index}/{processor.candidate_len}:")
|
||||
code_print(candidate.source_code.flat, file_name=f"candidate_{candidate_index}.py")
|
||||
code_print(
|
||||
candidate.source_code.flat,
|
||||
file_name=f"candidate_{candidate_index}.py",
|
||||
lsp_message_id=LSPMessageId.CANDIDATE.value,
|
||||
)
|
||||
# map ast normalized code to diff len, unnormalized code
|
||||
# map opt id to the shortest unnormalized code
|
||||
try:
|
||||
|
|
@ -1344,6 +1348,7 @@ class FunctionOptimizer:
|
|||
best_optimization.candidate.source_code.flat,
|
||||
file_name="best_candidate.py",
|
||||
function_name=self.function_to_optimize.function_name,
|
||||
lsp_message_id=LSPMessageId.BEST_CANDIDATE.value,
|
||||
)
|
||||
processed_benchmark_info = None
|
||||
if self.args.benchmark:
|
||||
|
|
@ -1629,17 +1634,20 @@ class FunctionOptimizer:
|
|||
)
|
||||
if not behavioral_results:
|
||||
logger.warning(
|
||||
f"force_lsp|Couldn't run any tests for original function {self.function_to_optimize.function_name}. SKIPPING OPTIMIZING THIS FUNCTION."
|
||||
f"force_lsp|Couldn't run any tests for original function {self.function_to_optimize.function_name}. Skipping optimization."
|
||||
)
|
||||
console.rule()
|
||||
return Failure("Failed to establish a baseline for the original code - bevhavioral tests failed.")
|
||||
if not coverage_critic(coverage_results, self.args.test_framework):
|
||||
did_pass_all_tests = all(result.did_pass for result in behavioral_results)
|
||||
if not did_pass_all_tests:
|
||||
return Failure("Tests failed to pass for the original code.")
|
||||
return Failure(
|
||||
f"Test coverage is {coverage_results.coverage}%, which is below the required threshold of {COVERAGE_THRESHOLD}%."
|
||||
)
|
||||
|
||||
if test_framework == "pytest":
|
||||
with progress_bar("Running line profiling to identify performance bottlenecks..."):
|
||||
with progress_bar("Running line profiler to identify performance bottlenecks..."):
|
||||
line_profile_results = self.line_profiler_step(
|
||||
code_context=code_context, original_helper_code=original_helper_code, candidate_index=0
|
||||
)
|
||||
|
|
@ -1939,7 +1947,7 @@ class FunctionOptimizer:
|
|||
*,
|
||||
enable_coverage: bool = False,
|
||||
pytest_min_loops: int = 5,
|
||||
pytest_max_loops: int = 100_000,
|
||||
pytest_max_loops: int = 250,
|
||||
code_context: CodeOptimizationContext | None = None,
|
||||
unittest_loop_index: int | None = None,
|
||||
line_profiler_output_file: Path | None = None,
|
||||
|
|
@ -1997,12 +2005,19 @@ class FunctionOptimizer:
|
|||
f"stdout: {run_result.stdout}\n"
|
||||
f"stderr: {run_result.stderr}\n"
|
||||
)
|
||||
if "ModuleNotFoundError" in run_result.stdout:
|
||||
|
||||
unique_errors = extract_unique_errors(run_result.stdout)
|
||||
|
||||
if unique_errors:
|
||||
from rich.text import Text
|
||||
|
||||
match = ImportErrorPattern.search(run_result.stdout).group()
|
||||
panel = Panel(Text.from_markup(f"⚠️ {match} ", style="bold red"), expand=False)
|
||||
console.print(panel)
|
||||
for error in unique_errors:
|
||||
if is_LSP_enabled():
|
||||
lsp_log(LspCodeMessage(code=error, file_name="errors"))
|
||||
else:
|
||||
panel = Panel(Text.from_markup(f"⚠️ {error} ", style="bold red"), expand=False)
|
||||
console.print(panel)
|
||||
|
||||
if testing_type in {TestingMode.BEHAVIOR, TestingMode.PERFORMANCE}:
|
||||
results, coverage_results = parse_test_results(
|
||||
test_xml_path=result_file_path,
|
||||
|
|
@ -2101,6 +2116,13 @@ class FunctionOptimizer:
|
|||
self.write_code_and_helpers(
|
||||
self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
|
||||
)
|
||||
# this will happen when a timeoutexpired exception happens
|
||||
if isinstance(line_profile_results, TestResults) and not line_profile_results.test_results:
|
||||
logger.warning(
|
||||
f"Timeout occurred while running line profiler for original function {self.function_to_optimize.function_name}"
|
||||
)
|
||||
# set default value for line profiler results
|
||||
return {"timings": {}, "unit": 0, "str_out": ""}
|
||||
if line_profile_results["str_out"] == "":
|
||||
logger.warning(
|
||||
f"Couldn't run line profiler for original function {self.function_to_optimize.function_name}"
|
||||
|
|
|
|||
|
|
@ -98,8 +98,8 @@ def _apply_deterministic_patches() -> None:
|
|||
_original_random = random.random
|
||||
|
||||
# Fixed deterministic values
|
||||
fixed_timestamp = 1609459200.0 # 2021-01-01 00:00:00 UTC
|
||||
fixed_datetime = datetime.datetime(2021, 1, 1, 0, 0, 0, tzinfo=datetime.timezone.utc)
|
||||
fixed_timestamp = 1761717605.108106
|
||||
fixed_datetime = datetime.datetime(2021, 1, 1, 2, 5, 10, tzinfo=datetime.timezone.utc)
|
||||
fixed_uuid = uuid.UUID("12345678-1234-5678-9abc-123456789012")
|
||||
|
||||
# Counter for perf_counter to maintain relative timing
|
||||
|
|
|
|||
|
|
@ -48,8 +48,8 @@ class TestDeterministicPatches:
|
|||
original_os_urandom = os.urandom
|
||||
|
||||
# Create deterministic implementations (matching pytest_plugin.py)
|
||||
fixed_timestamp = 1609459200.0 # 2021-01-01 00:00:00 UTC
|
||||
fixed_datetime = datetime.datetime(2021, 1, 1, 0, 0, 0, tzinfo=datetime.timezone.utc)
|
||||
fixed_timestamp = 1761717605.108106
|
||||
fixed_datetime = datetime.datetime(2021, 1, 1, 2, 5, 10, tzinfo=datetime.timezone.utc)
|
||||
fixed_uuid = uuid.UUID("12345678-1234-5678-9abc-123456789012")
|
||||
|
||||
# Counter for perf_counter
|
||||
|
|
@ -159,7 +159,7 @@ class TestDeterministicPatches:
|
|||
|
||||
def test_time_time_deterministic(self, setup_deterministic_environment):
|
||||
"""Test that time.time() returns a fixed deterministic value."""
|
||||
expected_timestamp = 1609459200.0 # 2021-01-01 00:00:00 UTC
|
||||
expected_timestamp = 1761717605.108106
|
||||
|
||||
# Call multiple times and verify consistent results
|
||||
result1 = time.time()
|
||||
|
|
@ -311,7 +311,7 @@ class TestDeterministicPatches:
|
|||
result1 = mock_now()
|
||||
result2 = mock_utcnow()
|
||||
|
||||
expected_dt = datetime.datetime(2021, 1, 1, 0, 0, 0, tzinfo=datetime.timezone.utc)
|
||||
expected_dt = datetime.datetime(2021, 1, 1, 2, 5, 10, tzinfo=datetime.timezone.utc)
|
||||
assert result1 == expected_dt
|
||||
assert result2 == expected_dt
|
||||
|
||||
|
|
@ -355,7 +355,7 @@ class TestDeterministicPatches:
|
|||
def test_patches_applied_correctly(self, setup_deterministic_environment):
|
||||
"""Test that patches are applied correctly."""
|
||||
# Test that functions return expected deterministic values
|
||||
assert time.time() == 1609459200.0
|
||||
assert time.time() == 1761717605.108106
|
||||
assert uuid.uuid4() == uuid.UUID("12345678-1234-5678-9abc-123456789012")
|
||||
assert random.random() == 0.123456789
|
||||
assert os.urandom(4) == b"\x42\x42\x42\x42"
|
||||
|
|
@ -378,7 +378,7 @@ class TestDeterministicPatches:
|
|||
# Test with different timezone
|
||||
utc_tz = datetime.timezone.utc
|
||||
result_with_tz = mock_now(utc_tz)
|
||||
expected_with_tz = datetime.datetime(2021, 1, 1, 0, 0, 0, tzinfo=utc_tz)
|
||||
expected_with_tz = datetime.datetime(2021, 1, 1, 2, 5, 10, tzinfo=datetime.timezone.utc)
|
||||
assert result_with_tz == expected_with_tz
|
||||
|
||||
def test_integration_with_actual_optimization_scenario(self, setup_deterministic_environment):
|
||||
|
|
|
|||
Loading…
Reference in a new issue