from __future__ import annotations

import asyncio
import logging
from pathlib import Path
from typing import TYPE_CHECKING

import libcst as cst
import sentry_sdk
from aiservice.analytics.posthog import ph
from aiservice.common_utils import validate_trace_id
from aiservice.env_specific import create_llm_client, debug_log_sensitive_data, llm_clients
from aiservice.models.aimodels import REFINEMENT_MODEL, calculate_llm_cost
from log_features.log_event import update_optimization_cost
from log_features.log_features import log_features
from ninja import NinjaAPI, Schema
from openai.types.chat import ChatCompletionSystemMessageParam, ChatCompletionUserMessageParam
from pydantic import ValidationError

from optimizer.context_utils.refiner_context import BaseRefinerContext, RefinementContextData

if TYPE_CHECKING:
    from aiservice.models.aimodels import LLM
    from openai.types.chat import (
        ChatCompletionAssistantMessageParam,
        ChatCompletionFunctionMessageParam,
        ChatCompletionToolMessageParam,
    )

refinement_api = NinjaAPI(urls_namespace="refinement")

# Get the directory of the current file
current_dir = Path(__file__).parent

# Base system prompt handed to BaseRefinerContext.get_dynamic_context().
# NOTE(review): the wording is reproduced verbatim; only the line breaks inside the
# literal were reconstructed from a whitespace-collapsed copy - confirm against VCS.
SYSTEM_PROMPT = """You are an expert software engineer who writes really fast programs and is an expert in optimizing runtime and memory requirements of a program by rewriting it. You have deep expertise in modern programming best practices, and clean code principles.
You are part of a code optimization system called Codeflash. Codeflash analyzes user code to optimize it for performance. It does so by first establishing the original code baseline by generating regression tests and discovering existing test cases, which finds the runtime of the original code and the behavior. Then, Codeflash generates multiple candidate optimizations which are applied to the codebase and ran to find the new runtime and behavior. Codeflash discards any optimizations that are either slower or have different behavior, in order to find the real optimizations.
Even though the optimizations may be technically correct, we want the quality of the optimizations to be really high, so your task is to refine the code of the optimizations so that they are expert quality.
The goal of refining the quality is to make the optimizations more precise. You want to reduce the number of lines and characters different between the optimizations and the original code to deliver very similar optimizations. These precise optimizations are highly preferred by the user and makes it is easier to accept the optimizations.
The refinement process should NEVER change the behavior of the optimization and we should try to preserve the main optimizations.
You are provided the following information to succeed in the quality refinement process -
- original_source_code - This is the original code being optimized
- optimized_source_code - This is the previously found optimization source code.
- original_line_profiler_results - The results after running line_profiler on the original_source_code
- optimized_line_profiler_results - The results after running line_profiler on the optimized_source_code
- optimization_speedup_results - The runtime for the original_source_code and optimized_source_code over a series of tests
- optimization_explanation - The explanation generated by codeflash earlier for the optimization in the optimized_source_code
- read_only_dependency_code - The READ ONLY dependencies for the code provided, to help you better understand the code being provided. Do no modify the code here, it is only provided for your reference.
- python_version - The version of python the code would be executed on.
- function_references - Python markdown blocks with filename and references of some functions which call the function being optimized. The filenames and/or references could indicate if the function being optimized is in a hot path. The reference could have the function being called from a place that is important, for example in a loop, which means the effect of optimization might be important.
Rules to follow while refining the quality of the optimized code -
- Analyze the original code and the optimized code and look at the line profiler info and the explanation to understand how the optimization works
- Introduction of the `global` and `nonlocal` keywords in optimized_source_code is **HIGHLY DISCOURAGED** as it reduces code clarity and maintainability, introduces hidden dependencies, can cause subtle bugs and breaks modularity. Revert any such changes.
- If there are micro-optimizations like inlining a function call, or localizing variables or methods (not being used in a loop), especially if python_version is older than 3.11, revert any such changes.
- Figure out the code difference between the original_source_code and the optimized_source_code to see what part of the optimized_source_code is not contributing to the optimization. In such a case, we want to revert that part of the optimized_source_code to the original_source_code. It is okay to revert parts of the changes that aren't faster by at least 1%.
- If there are any changes in the optimized code that make that code section slower than the original then we want to revert such a change to the original.
- Revert the new comments in the optimized_source_code that are different from the original_source_code unless the new code is complex and requires additional context.
- The the variables names for the same logical variable in the optimized_source_code is different from the original_source_code then prefer the variable name in the original_source_code.
Request to replace sections of content in an existing file using SEARCH/REPLACE blocks that define exact changes to specific parts of the file.
TOOL USE
You have access to a set of tools that are don't need any approval to run and is required to use to succeed with the task.
Tool use is formatted using XML-style tags. The tool name is enclosed in opening and closing tags, and each parameter is similarly enclosed within its own set of tags. Here's the structure:
value1
value2
...
For example:
src/main.py
<<<<<<< SEARCH
a = 2
=======
a = 3
>>>>>>> REPLACE
Always adhere to this format for tool use to ensure proper parsing and execution.
# Tools
## replace_in_file
Description: Request to replace sections of content in an existing file using SEARCH/REPLACE blocks that define exact changes to specific parts of the file. This tool should be used when you need to make targeted changes to specific parts of a file.
Parameters:
- path: (required) The path of the file to modify
- diff: (required) One or more SEARCH/REPLACE blocks following this exact format:
```
<<<<<<< SEARCH
[exact content to find]
=======
[new content to replace with]
>>>>>>> REPLACE
```
Critical rules:
1. SEARCH content must match the associated file section to find EXACTLY:
* Match character-for-character including whitespace, indentation, line endings
* Include all comments, docstrings, etc.
2. SEARCH/REPLACE blocks will ONLY replace the first match occurrence.
* Including multiple unique SEARCH/REPLACE blocks if you need to make multiple changes.
* Include *just* enough lines in each SEARCH section to uniquely match each set of lines that need to change.
* When using multiple SEARCH/REPLACE blocks, list them in the order they appear in the file.
3. Keep SEARCH/REPLACE blocks concise:
* Break large SEARCH/REPLACE blocks into a series of smaller blocks that each change a small portion of the file.
* Include just the changing lines, and a few surrounding lines if needed for uniqueness.
* Do not include long runs of unchanging lines in SEARCH/REPLACE blocks.
* Each line must be complete. Never truncate lines mid-way through as this can cause matching failures.
4. Special operations:
* To move code: Use two SEARCH/REPLACE blocks (one to delete from original + one to insert at new location)
* To delete code: Use empty REPLACE section
Usage:
File path here
Search and replace blocks here
Please edit the optimized_source_code to implement the refinement process to improve the quality of the optimization.
"""

# Base user prompt template; the {placeholders} mirror the RefinementContextData fields
# populated in refine(). Wording is verbatim, line breaks reconstructed (see note above).
USER_PROMPT = """
Please edit the optimized_source_code to implement the refinement process to improve the quality of the optimization given the following information.
{original_source_code}
Here is the line profiler information for the original_source_code
{original_line_profiler_results}
Here is the optimized_source_code
{optimized_source_code}
Here is the explanation generated by Codeflash for the optimized_source_code -
{optimized_explanation}
Here is the line profiler information for the optimized_source_code
{optimized_line_profiler_results}
The original_source_code takes {original_code_runtime} to run and the optimized_source_code takes {optimized_code_runtime} to run, making it {speedup} faster.
Here is the read_only_dependency_code
{read_only_dependency_code}
Here is the python version
{python_version}
Here is the function_references
{function_references}
"""


async def refinement(  # noqa: D417
    user_id: str, optimization_id: str, ctx: BaseRefinerContext, optimize_model: LLM = REFINEMENT_MODEL
) -> RefinementIntermediateResponseItemschema | OptimizeErrorResponseSchema:
    """Ask the LLM to refine one optimization candidate and apply its SEARCH/REPLACE patches.

    Parameters
    ----------
    :param user_id: user the analytics event is attributed to
    :param optimization_id: the optimization id of the original candidate; echoed back in the result
    :param ctx: the refiner context for this candidate. It supplies the system/user prompts,
        extracts the diff patches from the LLM reply, applies them to the candidate and
        validates the outcome. ``ctx.data`` is read for ``original_source_code``,
        ``optimized_source_code`` and ``optimized_explanation``.
    :param optimize_model: LLM used for getting the refinements

    Returns
    -------
    RefinementIntermediateResponseItemschema
        ``source_code`` is the refined code, or ``""`` when the reply had no text content,
        the patches could not be extracted/applied, or ``ctx.is_valid_refinement`` rejected
        the result. ``explanation`` is the raw LLM reply, ``original_explanation`` the
        candidate's explanation and ``llm_cost`` the cost of the call.
    OptimizeErrorResponseSchema
        when the LLM call (or its cost calculation) raised.

    """
    system_prompt = ctx.get_system_prompt()
    user_prompt = ctx.get_user_prompt()
    system_message = ChatCompletionSystemMessageParam(role="system", content=system_prompt)
    user_message = ChatCompletionUserMessageParam(role="user", content=user_prompt)
    messages: list[
        ChatCompletionSystemMessageParam
        | ChatCompletionUserMessageParam
        | ChatCompletionAssistantMessageParam
        | ChatCompletionToolMessageParam
        | ChatCompletionFunctionMessageParam
    ] = [system_message, user_message]
    debug_log_sensitive_data(f"This was the user prompt\n {user_prompt}\n")

    llm_client = llm_clients[optimize_model.model_type]
    try:
        output = await llm_client.with_options(max_retries=2).chat.completions.create(
            model=optimize_model.name, messages=messages, n=1
        )
        llm_cost = calculate_llm_cost(output, optimize_model)
    except Exception as e:
        # Boundary around the external LLM call: report and turn into an error item so one
        # failing candidate does not abort the whole asyncio.gather() batch in refine().
        logging.exception("Claude Code Generation error in refinement")
        sentry_sdk.capture_exception(e)
        debug_log_sensitive_data(f"Failed to generate code for source:\n{ctx.data.original_source_code}")
        return OptimizeErrorResponseSchema(error=str(e))

    debug_log_sensitive_data(f"ClaudeClient optimization response:\n{output.model_dump_json(indent=2)}")
    if output.usage is not None:
        ph(user_id, "refinement-usage", properties={"model": optimize_model.name, "usage": output.usage.json()})

    # n=1 is requested, so at most one entry; choices without text content are dropped.
    results = [content for op in output.choices if (content := op.message.content)]
    if not results:
        # The model produced no text content: there is nothing to patch. Return an empty
        # refinement (filtered out by refine()) instead of raising IndexError, and keep the
        # cost so it is still accounted for.
        return RefinementIntermediateResponseItemschema(
            optimization_id=optimization_id,
            source_code="",
            explanation="",
            original_explanation=ctx.data.optimized_explanation,
            llm_cost=llm_cost,
        )

    # Regex doesn't work yet in extracting everything else other than the search replace block
    refined_explanation = results[0]
    refined_optimization = ""
    try:
        diff_patches = ctx.extract_diff_patches_from_llm_res(results[0])
        refined_optimization = ctx.apply_patches_to_optimized_code(diff_patches)
    except (ValueError, ValidationError) as exc:
        sentry_sdk.capture_exception(exc)
        debug_log_sensitive_data(f"{type(exc).__name__} for source:\n{ctx.data.optimized_source_code}")
        debug_log_sensitive_data(f"Traceback: {exc}")
        refined_optimization = ""

    if not ctx.is_valid_refinement(refined_optimization):
        refined_optimization = ""

    return RefinementIntermediateResponseItemschema(
        optimization_id=optimization_id,
        source_code=refined_optimization,
        explanation=refined_explanation,
        original_explanation=ctx.data.optimized_explanation,
        llm_cost=llm_cost,
    )


class RefinementRequestSchema(Schema):
    # One optimization candidate to refine; refine() receives a list of these.
    trace_id: str
    optimization_id: str
    original_source_code: str
    original_line_profiler_results: str = ""
    read_only_dependency_code: str
    optimized_source_code: str
    optimized_line_profiler_results: str = ""
    optimized_explanation: str
    original_code_runtime: str = ""
    optimized_code_runtime: str = ""
    speedup: str = ""
    python_version: str | None = None
    function_references: str | None = None


class OptimizeErrorResponseSchema(Schema):
    # Error payload: returned by refinement() on LLM failure and by refine() for bad requests.
    error: str


class RefinementIntermediateResponseItemschema(Schema):
    # Internal result of refinement() for one candidate; never sent to the client as-is.
    explanation: str  # raw LLM reply
    optimization_id: str
    source_code: str  # refined code, "" when the refinement was rejected or failed
    original_explanation: str  # explanation of the candidate that was refined
    llm_cost: float


class RefinementResponseItemschema(Schema):
    # One refined candidate as returned to the client.
    explanation: str
    optimization_id: str
    source_code: str


class Refinementschema(Schema):
    # Successful response body of the refinement endpoint.
    refinements: list[RefinementResponseItemschema]


@refinement_api.post(
    "/", response={200: Refinementschema, 400: OptimizeErrorResponseSchema, 500: Refinementschema}
)
async def refine(
    request,  # noqa: ANN001
    data: list[RefinementRequestSchema],
) -> tuple[int, Refinementschema | OptimizeErrorResponseSchema]:
    """Refine a batch of optimization candidates concurrently and return the usable ones.

    Each candidate gets its own refiner context and is sent through ``refinement()``.
    Results are dropped when the LLM call failed, the refined code is empty, it does not
    pass ``validate_python_module``, or it duplicates an earlier result (compared after
    stripping surrounding whitespace). The summed LLM cost is recorded against the trace id
    of the first candidate.

    Returns ``(400, OptimizeErrorResponseSchema)`` for an empty batch or an invalid trace
    id, otherwise ``(200, Refinementschema)``; the returned ``explanation`` is the
    candidate's original explanation, not the raw LLM reply.
    """
    ph(request.user, "aiservice-refinement-called")

    if not data:
        return 400, OptimizeErrorResponseSchema(error="No refinement candidates provided.")

    trace_id = data[0].trace_id
    if not validate_trace_id(trace_id):
        return 400, OptimizeErrorResponseSchema(error="Invalid trace ID. Please provide a valid UUIDv4.")

    # One context per candidate. The coroutines created below do not start running until
    # asyncio.gather() schedules them, so a single shared context whose ``data`` is
    # reassigned in the loop would make every candidate see the data of the last one.
    # NOTE(review): assumes get_dynamic_context() returns a fresh instance per call.
    contexts = [
        BaseRefinerContext.get_dynamic_context(
            ctx_data=RefinementContextData(
                original_source_code=opt.original_source_code,
                original_line_profiler_results=opt.original_line_profiler_results,
                original_code_runtime=opt.original_code_runtime,
                optimized_source_code=opt.optimized_source_code,
                read_only_dependency_code=opt.read_only_dependency_code,
                optimized_line_profiler_results=opt.optimized_line_profiler_results,
                optimized_code_runtime=opt.optimized_code_runtime,
                speedup=opt.speedup,
                optimized_explanation=opt.optimized_explanation,
                python_version=opt.python_version,
                function_references=opt.function_references,
            ),
            base_system_prompt=SYSTEM_PROMPT,
            base_user_prompt=USER_PROMPT,
        )
        for opt in data
    ]

    refinement_data = await asyncio.gather(
        *(
            refinement(user_id=request.user, optimization_id=opt.optimization_id, ctx=ctx)
            for opt, ctx in zip(data, contexts)
        )
    )

    # simple filtering mechanism, remove empty strings and remove duplicates after removing
    # trailing and leading whitespaces, validate with libcst
    filtered_refined_optimizations = []
    source_code_set = set()
    total_llm_cost = 0.0
    for ctx, elem in zip(contexts, refinement_data):
        if isinstance(elem, OptimizeErrorResponseSchema):
            continue
        total_llm_cost += elem.llm_cost
        try:
            ctx.validate_python_module(elem.source_code)
        except cst.ParserSyntaxError as e:
            # log exception with sentry
            sentry_sdk.capture_exception(e)
            debug_log_sensitive_data(f"ParserSyntaxError for source:\n{elem.source_code}")
            debug_log_sensitive_data(f"Traceback: {e}")
            continue
        except (ValueError, ValidationError) as exc:
            # Another one bites the Pydantic validation dust
            sentry_sdk.capture_exception(exc)
            debug_log_sensitive_data(f"{type(exc).__name__} for source:\n{elem.source_code}")
            debug_log_sensitive_data(f"Traceback: {exc}")
            continue
        stripped_source = elem.source_code.strip()
        # Whitespace-only code counts as empty, matching the intent stated above.
        if stripped_source and stripped_source not in source_code_set:
            source_code_set.add(stripped_source)
            filtered_refined_optimizations.append(elem)

    if getattr(request, "should_log_features", False):
        successful = [cei for cei in refinement_data if not isinstance(cei, OptimizeErrorResponseSchema)]
        # The last four characters of the candidate id are swapped for "refi" to build the
        # key under which the refinement is logged.
        await log_features(
            trace_id=trace_id,
            user_id=request.user,
            optimizations_raw={cei.optimization_id[:-4] + "refi": cei.source_code for cei in successful},
            optimizations_post={
                cei.optimization_id[:-4] + "refi": cei.source_code for cei in filtered_refined_optimizations
            },
            explanations_raw={cei.optimization_id[:-4] + "refi": cei.explanation for cei in successful},
            explanations_post={
                cei.optimization_id[:-4] + "refi": cei.explanation for cei in filtered_refined_optimizations
            },
        )

    await update_optimization_cost(trace_id=trace_id, cost=total_llm_cost)

    return 200, Refinementschema(
        refinements=[
            RefinementResponseItemschema(
                source_code=x.source_code, explanation=x.original_explanation, optimization_id=x.optimization_id
            )
            for x in filtered_refined_optimizations
        ]
    )