Merge branch 'main' into chore/add-staging/docs

This commit is contained in:
HeshamHM28 2025-08-20 23:40:20 +03:00 committed by GitHub
commit 099cd00fbc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
41 changed files with 1179 additions and 540 deletions

View file

@ -3,9 +3,7 @@
<a href="https://github.com/codeflash-ai/codeflash">
<img src="https://img.shields.io/github/commit-activity/m/codeflash-ai/codeflash" alt="GitHub commit activity">
</a>
<a href="https://pypi.org/project/codeflash/">
<img src="https://img.shields.io/pypi/dm/codeflash" alt="PyPI Downloads">
</a>
<a href="https://pypi.org/project/codeflash/"><img src="https://static.pepy.tech/badge/codeflash" alt="PyPI Downloads"></a>
<a href="https://pypi.org/project/codeflash/">
<img src="https://img.shields.io/pypi/v/codeflash?label=PyPI%20version" alt="PyPI Downloads">
</a>
@ -83,4 +81,4 @@ Join our community for support and discussions. If you have any questions, feel
## License
Codeflash is licensed under the BSL-1.1 License. See the LICENSE file for details.
Codeflash is licensed under the BSL-1.1 License. See the [LICENSE](https://github.com/codeflash-ai/codeflash/blob/main/codeflash/LICENSE) file for details.

View file

@ -3,7 +3,7 @@ Business Source License 1.1
Parameters
Licensor: CodeFlash Inc.
Licensed Work: Codeflash Client version 0.15.x
Licensed Work: Codeflash Client version 0.16.x
The Licensed Work is (c) 2024 CodeFlash Inc.
Additional Use Grant: None. Production use of the Licensed Work is only permitted
@ -13,7 +13,7 @@ Additional Use Grant: None. Production use of the Licensed Work is only permitte
Platform. Please visit codeflash.ai for further
information.
Change Date: 2029-07-03
Change Date: 2029-08-14
Change License: MIT

View file

@ -10,8 +10,9 @@ import requests
from pydantic.json import pydantic_encoder
from codeflash.cli_cmds.console import console, logger
from codeflash.code_utils.env_utils import get_codeflash_api_key, is_LSP_enabled
from codeflash.code_utils.env_utils import get_codeflash_api_key
from codeflash.code_utils.git_utils import get_last_commit_author_if_pr_exists, get_repo_owner_and_name
from codeflash.lsp.helpers import is_LSP_enabled
from codeflash.models.ExperimentMetadata import ExperimentMetadata
from codeflash.models.models import AIServiceRefinerRequest, CodeStringsMarkdown, OptimizedCandidate
from codeflash.telemetry.posthog_cf import ph
@ -202,7 +203,7 @@ class AiServiceClient:
if response.status_code == 200:
optimizations_json = response.json()["optimizations"]
logger.info(f"Generated {len(optimizations_json)} candidate optimizations.")
logger.info(f"Generated {len(optimizations_json)} candidate optimizations using line profiler information.")
console.rule()
return [
OptimizedCandidate(
@ -248,7 +249,7 @@ class AiServiceClient:
}
for opt in request
]
logger.info(f"Refining {len(request)} optimizations…")
logger.debug(f"Refining {len(request)} optimizations…")
console.rule()
try:
response = self.make_ai_service_request("/refinement", payload=payload, timeout=600)
@ -259,7 +260,7 @@ class AiServiceClient:
if response.status_code == 200:
refined_optimizations = response.json()["refinements"]
logger.info(f"Generated {len(refined_optimizations)} candidate refinements.")
logger.debug(f"Generated {len(refined_optimizations)} candidate refinements.")
console.rule()
return [
OptimizedCandidate(
@ -339,7 +340,6 @@ class AiServiceClient:
if response.status_code == 200:
explanation: str = response.json()["explanation"]
logger.debug(f"New Explanation: {explanation}")
console.rule()
return explanation
try:

View file

@ -14,8 +14,9 @@ from pydantic.json import pydantic_encoder
from codeflash.cli_cmds.console import console, logger
from codeflash.code_utils.env_utils import ensure_codeflash_api_key, get_codeflash_api_key, get_pr_number
from codeflash.code_utils.git_utils import get_current_branch, get_repo_owner_and_name, git_root_dir
from codeflash.code_utils.git_utils import get_current_branch, get_repo_owner_and_name
from codeflash.github.PrComment import FileDiffContent, PrComment
from codeflash.lsp.helpers import is_LSP_enabled
from codeflash.version import __version__
if TYPE_CHECKING:
@ -101,7 +102,7 @@ def get_user_id() -> Optional[str]:
if min_version and version.parse(min_version) > version.parse(__version__):
msg = "Your Codeflash CLI version is outdated. Please update to the latest version using `pip install --upgrade codeflash`."
console.print(f"[bold red]{msg}[/bold red]")
if console.quiet: # lsp
if is_LSP_enabled():
logger.debug(msg)
return f"Error: {msg}"
sys.exit(1)
@ -203,6 +204,9 @@ def create_staging(
generated_original_test_source: str,
function_trace_id: str,
coverage_message: str,
replay_tests: str,
concolic_tests: str,
root_dir: Path,
) -> Response:
"""Create a staging pull request, targeting the specified branch. (usually 'staging').
@ -215,12 +219,10 @@ def create_staging(
:param coverage_message: Coverage report or summary.
:return: The response object from the backend.
"""
relative_path = explanation.file_path.relative_to(git_root_dir()).as_posix()
relative_path = explanation.file_path.relative_to(root_dir).as_posix()
build_file_changes = {
Path(p).relative_to(git_root_dir()).as_posix(): FileDiffContent(
oldContent=original_code[p], newContent=new_code[p]
)
Path(p).relative_to(root_dir).as_posix(): FileDiffContent(oldContent=original_code[p], newContent=new_code[p])
for p in original_code
}
@ -243,6 +245,8 @@ def create_staging(
"generatedTests": generated_original_test_source,
"traceId": function_trace_id,
"coverage_message": coverage_message,
"replayTests": replay_tests,
"concolicTests": concolic_tests,
}
return make_cfapi_request(endpoint="/create-staging", method="POST", payload=payload)

View file

@ -94,6 +94,7 @@ def parse_args() -> Namespace:
help="Path to the directory of the project, where all the pytest-benchmark tests are located.",
)
parser.add_argument("--no-draft", default=False, action="store_true", help="Skip optimization for draft PRs")
parser.add_argument("--worktree", default=False, action="store_true", help="Use worktree for optimization")
args, unknown_args = parser.parse_known_args()
sys.argv[:] = [sys.argv[0], *unknown_args]

View file

@ -1,6 +1,7 @@
from __future__ import annotations
import logging
import os
from contextlib import contextmanager
from itertools import cycle
from typing import TYPE_CHECKING
@ -28,6 +29,10 @@ if TYPE_CHECKING:
DEBUG_MODE = logging.getLogger().getEffectiveLevel() == logging.DEBUG
console = Console()
if os.getenv("CODEFLASH_LSP"):
console.quiet = True
logging.basicConfig(
level=logging.INFO,
handlers=[RichHandler(rich_tracebacks=True, markup=False, console=console, show_path=False, show_time=False)],

View file

@ -195,6 +195,79 @@ class LastImportFinder(cst.CSTVisitor):
self.last_import_line = self.current_line
class DottedImportCollector(cst.CSTVisitor):
    """Collects all top-level imports from a Python module in normalized dotted format, including top-level conditional imports like `if TYPE_CHECKING:`.

    Examples
    --------
    import os ==> "os"
    import dbt.adapters.factory ==> "dbt.adapters.factory"
    from pathlib import Path ==> "pathlib.Path"
    from recce.adapter.base import BaseAdapter ==> "recce.adapter.base.BaseAdapter"
    from typing import Any, List, Optional ==> "typing.Any", "typing.List", "typing.Optional"
    from recce.util.lineage import ( build_column_key, filter_dependency_maps) ==> "recce.util.lineage.build_column_key", "recce.util.lineage.filter_dependency_maps"
    """

    def __init__(self) -> None:
        # Normalized dotted names of every top-level import found.
        self.imports: set[str] = set()
        # Nesting depth inside function/class bodies; 0 means module top level.
        self.depth = 0  # top-level

    def get_full_dotted_name(self, expr: cst.BaseExpression) -> str:
        # Flatten a Name/Attribute chain (e.g. `a.b.c`) into a dotted string.
        if isinstance(expr, cst.Name):
            return expr.value
        if isinstance(expr, cst.Attribute):
            return f"{self.get_full_dotted_name(expr.value)}.{expr.attr.value}"
        # Any other expression kind (unexpected for import targets) yields "".
        return ""

    def _collect_imports_from_block(self, block: cst.IndentedBlock) -> None:
        # Scan the direct statements of a block (Module body or an
        # IndentedBlock) for `import ...` / `from ... import ...` statements.
        for statement in block.body:
            if isinstance(statement, cst.SimpleStatementLine):
                for child in statement.body:
                    if isinstance(child, cst.Import):
                        for alias in child.names:
                            module = self.get_full_dotted_name(alias.name)
                            # NOTE: for `import a.b.c` without an alias,
                            # `alias.name.value` is a CST node (Attribute), not
                            # a string — the isinstance check below handles it.
                            asname = alias.asname.name.value if alias.asname else alias.name.value
                            if isinstance(asname, cst.Attribute):
                                self.imports.add(module)
                            else:
                                self.imports.add(module if module == asname else f"{module}.{asname}")
                    elif isinstance(child, cst.ImportFrom):
                        if child.module is None:
                            # Relative imports like `from . import x` are skipped.
                            continue
                        module = self.get_full_dotted_name(child.module)
                        for alias in child.names:
                            # ImportStar is not an ImportAlias, so `import *` is ignored.
                            if isinstance(alias, cst.ImportAlias):
                                name = alias.name.value
                                # Record under the bound alias name when present.
                                asname = alias.asname.name.value if alias.asname else name
                                self.imports.add(f"{module}.{asname}")

    def visit_Module(self, node: cst.Module) -> None:
        self.depth = 0
        # A Module exposes `.body` like an IndentedBlock, so it can be scanned directly.
        self._collect_imports_from_block(node)

    # Depth tracking: imports inside function/class bodies are not "top-level"
    # and must be ignored, so bump the depth on entry and restore it on exit.
    def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
        self.depth += 1

    def leave_FunctionDef(self, node: cst.FunctionDef) -> None:
        self.depth -= 1

    def visit_ClassDef(self, node: cst.ClassDef) -> None:
        self.depth += 1

    def leave_ClassDef(self, node: cst.ClassDef) -> None:
        self.depth -= 1

    def visit_If(self, node: cst.If) -> None:
        # Collect from top-level conditional imports (e.g. `if TYPE_CHECKING:`).
        # NOTE(review): only the `if` body is scanned, not `elif`/`else` branches.
        if self.depth == 0:
            self._collect_imports_from_block(node.body)

    def visit_Try(self, node: cst.Try) -> None:
        # Collect from top-level `try:` import blocks (optional-dependency pattern).
        # NOTE(review): only the `try` body is scanned, not handlers/else/finally.
        if self.depth == 0:
            self._collect_imports_from_block(node.body)
class ImportInserter(cst.CSTTransformer):
"""Transformer that inserts global statements after the last import."""
@ -329,9 +402,19 @@ def add_needed_imports_from_module(
except Exception as e:
logger.error(f"Error parsing source module code: {e}")
return dst_module_code
dotted_import_collector = DottedImportCollector()
try:
parsed_dst_module = cst.parse_module(dst_module_code)
parsed_dst_module.visit(dotted_import_collector)
except cst.ParserSyntaxError as e:
logger.exception(f"Syntax error in destination module code: {e}")
return dst_module_code # Return the original code if there's a syntax error
try:
for mod in gatherer.module_imports:
AddImportsVisitor.add_needed_import(dst_context, mod)
if mod not in dotted_import_collector.imports:
AddImportsVisitor.add_needed_import(dst_context, mod)
RemoveImportsVisitor.remove_unused_import(dst_context, mod)
for mod, obj_seq in gatherer.object_mapping.items():
for obj in obj_seq:
@ -339,28 +422,29 @@ def add_needed_imports_from_module(
f"{mod}.{obj}" in helper_functions_fqn or dst_context.full_module_name == mod # avoid circular deps
):
continue # Skip adding imports for helper functions already in the context
AddImportsVisitor.add_needed_import(dst_context, mod, obj)
if f"{mod}.{obj}" not in dotted_import_collector.imports:
AddImportsVisitor.add_needed_import(dst_context, mod, obj)
RemoveImportsVisitor.remove_unused_import(dst_context, mod, obj)
except Exception as e:
logger.exception(f"Error adding imports to destination module code: {e}")
return dst_module_code
for mod, asname in gatherer.module_aliases.items():
AddImportsVisitor.add_needed_import(dst_context, mod, asname=asname)
if f"{mod}.{asname}" not in dotted_import_collector.imports:
AddImportsVisitor.add_needed_import(dst_context, mod, asname=asname)
RemoveImportsVisitor.remove_unused_import(dst_context, mod, asname=asname)
for mod, alias_pairs in gatherer.alias_mapping.items():
for alias_pair in alias_pairs:
if f"{mod}.{alias_pair[0]}" in helper_functions_fqn:
continue
AddImportsVisitor.add_needed_import(dst_context, mod, alias_pair[0], asname=alias_pair[1])
if f"{mod}.{alias_pair[1]}" not in dotted_import_collector.imports:
AddImportsVisitor.add_needed_import(dst_context, mod, alias_pair[0], asname=alias_pair[1])
RemoveImportsVisitor.remove_unused_import(dst_context, mod, alias_pair[0], asname=alias_pair[1])
try:
parsed_module = cst.parse_module(dst_module_code)
except cst.ParserSyntaxError as e:
logger.exception(f"Syntax error in destination module code: {e}")
return dst_module_code # Return the original code if there's a syntax error
try:
transformed_module = AddImportsVisitor(dst_context).transform_module(parsed_module)
transformed_module = AddImportsVisitor(dst_context).transform_module(parsed_dst_module)
transformed_module = RemoveImportsVisitor(dst_context).transform_module(transformed_module)
return transformed_module.code.lstrip("\n")
except Exception as e:

View file

@ -39,16 +39,20 @@ def build_fully_qualified_name(function_name: str, code_context: CodeOptimizatio
return full_name
def generate_candidates(source_code_path: Path) -> list[str]:
def generate_candidates(source_code_path: Path) -> set[str]:
"""Generate all the possible candidates for coverage data based on the source code path."""
candidates = [source_code_path.name]
candidates = set()
candidates.add(source_code_path.name)
current_path = source_code_path.parent
last_added = source_code_path.name
while current_path != current_path.parent:
candidate_path = str(Path(current_path.name) / candidates[-1])
candidates.append(candidate_path)
candidate_path = str(Path(current_path.name) / last_added)
candidates.add(candidate_path)
last_added = candidate_path
current_path = current_path.parent
candidates.add(str(source_code_path))
return candidates

View file

@ -7,10 +7,11 @@ from functools import lru_cache
from pathlib import Path
from typing import Any, Optional
from codeflash.cli_cmds.console import console, logger
from codeflash.cli_cmds.console import logger
from codeflash.code_utils.code_utils import exit_with_message
from codeflash.code_utils.formatter import format_code
from codeflash.code_utils.shell_utils import read_api_key_from_shell_config
from codeflash.lsp.helpers import is_LSP_enabled
def check_formatter_installed(formatter_cmds: list[str], exit_on_failure: bool = True) -> bool: # noqa
@ -34,11 +35,12 @@ def check_formatter_installed(formatter_cmds: list[str], exit_on_failure: bool =
@lru_cache(maxsize=1)
def get_codeflash_api_key() -> str:
if console.quiet: # lsp
# prefer shell config over env var in lsp mode
api_key = read_api_key_from_shell_config()
else:
api_key = os.environ.get("CODEFLASH_API_KEY") or read_api_key_from_shell_config()
# prefer shell config over env var in lsp mode
api_key = (
read_api_key_from_shell_config()
if is_LSP_enabled()
else os.environ.get("CODEFLASH_API_KEY") or read_api_key_from_shell_config()
)
api_secret_docs_message = "For more information, refer to the documentation at [https://docs.codeflash.ai/getting-started/codeflash-github-actions#add-your-api-key-to-your-repository-secrets]." # noqa
if not api_key:
@ -125,11 +127,6 @@ def is_ci() -> bool:
return bool(os.environ.get("CI") or os.environ.get("GITHUB_ACTIONS"))
@lru_cache(maxsize=1)
def is_LSP_enabled() -> bool:
return console.quiet
def is_pr_draft() -> bool:
"""Check if the PR is draft. in the github action context."""
event = get_cached_gh_event_data()

View file

@ -13,6 +13,7 @@ from typing import Optional, Union
import isort
from codeflash.cli_cmds.console import console, logger
from codeflash.lsp.helpers import is_LSP_enabled
def generate_unified_diff(original: str, modified: str, from_file: str, to_file: str) -> str:
@ -44,9 +45,7 @@ def apply_formatter_cmds(
test_dir_str: Optional[str],
print_status: bool, # noqa
exit_on_failure: bool = True, # noqa
) -> tuple[Path, str]:
# TODO: Only allow a particular whitelist of formatters here to prevent arbitrary code execution
formatter_name = cmds[0].lower()
) -> tuple[Path, str, bool]:
should_make_copy = False
file_path = path
@ -54,9 +53,6 @@ def apply_formatter_cmds(
should_make_copy = True
file_path = Path(test_dir_str) / "temp.py"
if not cmds or formatter_name == "disabled":
return path, path.read_text(encoding="utf8")
if not path.exists():
msg = f"File {path} does not exist. Cannot apply formatter commands."
raise FileNotFoundError(msg)
@ -66,6 +62,7 @@ def apply_formatter_cmds(
file_token = "$file" # noqa: S105
changed = False
for command in cmds:
formatter_cmd_list = shlex.split(command, posix=os.name != "nt")
formatter_cmd_list = [file_path.as_posix() if chunk == file_token else chunk for chunk in formatter_cmd_list]
@ -74,6 +71,7 @@ def apply_formatter_cmds(
if result.returncode == 0:
if print_status:
console.rule(f"Formatted Successfully with: {command.replace('$file', path.name)}")
changed = True
else:
logger.error(f"Failed to format code with {' '.join(formatter_cmd_list)}")
except FileNotFoundError as e:
@ -88,7 +86,7 @@ def apply_formatter_cmds(
if exit_on_failure:
raise e from None
return file_path, file_path.read_text(encoding="utf8")
return file_path, file_path.read_text(encoding="utf8"), changed
def get_diff_lines_count(diff_output: str) -> int:
@ -109,13 +107,18 @@ def format_code(
print_status: bool = True, # noqa
exit_on_failure: bool = True, # noqa
) -> str:
if console.quiet:
# lsp mode
if is_LSP_enabled():
exit_on_failure = False
with tempfile.TemporaryDirectory() as test_dir_str:
if isinstance(path, str):
path = Path(path)
if isinstance(path, str):
path = Path(path)
# TODO: Only allow a particular whitelist of formatters here to prevent arbitrary code execution
formatter_name = formatter_cmds[0].lower() if formatter_cmds else "disabled"
if formatter_name == "disabled":
return path.read_text(encoding="utf8")
with tempfile.TemporaryDirectory() as test_dir_str:
original_code = path.read_text(encoding="utf8")
original_code_lines = len(original_code.split("\n"))
@ -126,10 +129,16 @@ def format_code(
original_temp = Path(test_dir_str) / "original_temp.py"
original_temp.write_text(original_code_without_opfunc, encoding="utf8")
formatted_temp, formatted_code = apply_formatter_cmds(
formatter_cmds, original_temp, test_dir_str, print_status=False
formatted_temp, formatted_code, changed = apply_formatter_cmds(
formatter_cmds, original_temp, test_dir_str, print_status=False, exit_on_failure=exit_on_failure
)
if not changed:
logger.warning(
f"No changes detected in {path} after formatting, are you sure you have valid formatter commands?"
)
return original_code
diff_output = generate_unified_diff(
original_code_without_opfunc, formatted_code, from_file=str(original_temp), to_file=str(formatted_temp)
)
@ -137,15 +146,22 @@ def format_code(
max_diff_lines = min(int(original_code_lines * 0.3), 50)
if diff_lines_count > max_diff_lines and max_diff_lines != -1:
logger.debug(
if diff_lines_count > max_diff_lines:
logger.warning(
f"Skipping formatting {path}: {diff_lines_count} lines would change (max: {max_diff_lines})"
)
return original_code
# TODO : We can avoid formatting the whole file again and only formatting the optimized code standalone and replace in formatted file above.
_, formatted_code = apply_formatter_cmds(
_, formatted_code, changed = apply_formatter_cmds(
formatter_cmds, path, test_dir_str=None, print_status=print_status, exit_on_failure=exit_on_failure
)
if not changed:
logger.warning(
f"No changes detected in {path} after formatting, are you sure you have valid formatter commands?"
)
return original_code
logger.debug(f"Formatted {path} with commands: {formatter_cmds}")
return formatted_code

View file

@ -9,13 +9,14 @@ import time
from functools import cache
from io import StringIO
from pathlib import Path
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional
import git
from rich.prompt import Confirm
from unidiff import PatchSet
from codeflash.cli_cmds.console import logger
from codeflash.code_utils.compat import codeflash_cache_dir
from codeflash.code_utils.config_consts import N_CANDIDATES
if TYPE_CHECKING:
@ -192,3 +193,80 @@ def get_last_commit_author_if_pr_exists(repo: Repo | None = None) -> str | None:
return None
else:
return last_commit.author.name
worktree_dirs = codeflash_cache_dir / "worktrees"
patches_dir = codeflash_cache_dir / "patches"
def create_worktree_snapshot_commit(worktree_dir: Path, commit_message: str) -> None:
    """Commit all tracked modifications in the worktree as a snapshot, bypassing git hooks."""
    repo = git.Repo(worktree_dir, search_parent_directories=True)
    # -a stages tracked changes; --no-verify skips pre-commit/commit-msg hooks.
    repo.git.commit("-am", commit_message, "--no-verify")
def create_detached_worktree(module_root: Path) -> Optional[Path]:
    """Create a detached git worktree mirroring the repo's current state.

    Uncommitted changes (including untracked files, via `git add -N`) are
    carried into the worktree by applying a patch there and committing it
    as an "Initial Snapshot".

    :param module_root: A path inside the repository to snapshot.
    :return: The new worktree directory, or None if not inside a git repo.
    """
    if not check_running_in_git_repo(module_root):
        logger.warning("Module is not in a git repository. Skipping worktree creation.")
        return None
    git_root = git_root_dir()
    current_time_str = time.strftime("%Y%m%d-%H%M%S")
    worktree_dir = worktree_dirs / f"{git_root.name}-{current_time_str}"
    repository = git.Repo(git_root, search_parent_directories=True)
    # -d creates the worktree detached at HEAD (no branch checkout).
    repository.git.worktree("add", "-d", str(worktree_dir))

    # Get uncommitted diff from the original repo
    repository.git.add("-N", ".")  # add the index for untracked files to be included in the diff
    uni_diff_text = repository.git.diff(None, "HEAD", ignore_blank_lines=True, ignore_space_at_eol=True)

    if not uni_diff_text.strip():
        logger.info("No uncommitted changes to copy to worktree.")
        return worktree_dir

    # Write the diff to a temporary file
    with tempfile.NamedTemporaryFile(mode="w", suffix=".codeflash.patch", delete=False) as tmp_patch_file:
        tmp_patch_file.write(uni_diff_text + "\n")  # the trailing newline is required, otherwise the last hunk is invalid
        tmp_patch_file.flush()
        patch_path = Path(tmp_patch_file.name).resolve()

    # Apply the patch inside the worktree, then snapshot it as a commit.
    try:
        subprocess.run(
            ["git", "apply", "--ignore-space-change", "--ignore-whitespace", patch_path],
            cwd=worktree_dir,
            check=True,
        )
        create_worktree_snapshot_commit(worktree_dir, "Initial Snapshot")
    except subprocess.CalledProcessError as e:
        logger.error(f"Failed to apply patch to worktree: {e}")
    finally:
        # Fix: the patch file was created with delete=False and previously leaked.
        patch_path.unlink(missing_ok=True)
    return worktree_dir
def remove_worktree(worktree_dir: Path) -> None:
    """Force-remove the given git worktree; failures are logged, never raised."""
    try:
        repo = git.Repo(worktree_dir, search_parent_directories=True)
        repo.git.worktree("remove", "--force", worktree_dir)
    except Exception:
        # Best-effort cleanup: log with traceback and keep going.
        logger.exception(f"Failed to remove worktree: {worktree_dir}")
def create_diff_patch_from_worktree(worktree_dir: Path, files: list[str], fto_name: str) -> Optional[Path]:
    """Write a patch file with the worktree's diff against HEAD for the given files.

    :param worktree_dir: The worktree to diff.
    :param files: Paths (relative to the worktree) to include in the diff.
    :param fto_name: Function-to-optimize name, used in the patch file name.
    :return: Path to the written patch, or None when there are no changes.
    """
    repository = git.Repo(worktree_dir, search_parent_directories=True)
    uni_diff_text = repository.git.diff(None, "HEAD", *files, ignore_blank_lines=True, ignore_space_at_eol=True)
    if not uni_diff_text:
        logger.warning("No changes found in worktree.")
        # Fix: the return annotation previously said `Path` although this
        # branch returns None; it is now Optional[Path].
        return None
    if not uni_diff_text.endswith("\n"):
        uni_diff_text += "\n"  # a trailing newline is required for the last hunk to be valid
    # write to patches_dir
    patches_dir.mkdir(parents=True, exist_ok=True)
    patch_path = patches_dir / f"{worktree_dir.name}.{fto_name}.patch"
    with patch_path.open("w", encoding="utf8") as f:
        f.write(uni_diff_text)
    return patch_path

View file

@ -208,7 +208,7 @@ def get_functions_to_optimize(
logger.info("Finding all functions modified in the current git diff ...")
console.rule()
ph("cli-optimizing-git-diff")
functions = get_functions_within_git_diff()
functions = get_functions_within_git_diff(uncommitted_changes=False)
filtered_modified_functions, functions_count = filter_functions(
functions, test_cfg.tests_root, ignore_paths, project_root, module_root, previous_checkpoint_functions
)
@ -224,8 +224,8 @@ def get_functions_to_optimize(
return filtered_modified_functions, functions_count, trace_file_path
def get_functions_within_git_diff() -> dict[str, list[FunctionToOptimize]]:
modified_lines: dict[str, list[int]] = get_git_diff(uncommitted_changes=False)
def get_functions_within_git_diff(uncommitted_changes: bool) -> dict[str, list[FunctionToOptimize]]: # noqa: FBT001
modified_lines: dict[str, list[int]] = get_git_diff(uncommitted_changes=uncommitted_changes)
modified_functions: dict[str, list[FunctionToOptimize]] = {}
for path_str, lines_in_file in modified_lines.items():
path = Path(path_str)

View file

@ -1,4 +0,0 @@
# Silence the console module to prevent stdout pollution
from codeflash.cli_cmds.console import console
console.quiet = True

View file

@ -9,15 +9,15 @@ from typing import TYPE_CHECKING
from pygls import uris
from codeflash.api.cfapi import get_codeflash_api_key, get_user_id
from codeflash.code_utils.git_utils import create_diff_patch_from_worktree
from codeflash.code_utils.shell_utils import save_api_key_to_rc
from codeflash.discovery.functions_to_optimize import filter_functions, get_functions_within_git_diff
from codeflash.either import is_successful
from codeflash.lsp.server import CodeflashLanguageServer, CodeflashLanguageServerProtocol
if TYPE_CHECKING:
from lsprotocol import types
from codeflash.models.models import GeneratedTestsList, OptimizationSet
@dataclass
class OptimizableFunctionsParams:
@ -38,6 +38,23 @@ class ProvideApiKeyParams:
server = CodeflashLanguageServer("codeflash-language-server", "v1.0", protocol_cls=CodeflashLanguageServerProtocol)
@server.feature("getOptimizableFunctionsInCurrentDiff")
def get_functions_in_current_git_diff(
    server: CodeflashLanguageServer, _params: OptimizableFunctionsParams
) -> dict[str, str | list[str]]:
    """Return the qualified names of optimizable functions touched by uncommitted changes."""
    modified = get_functions_within_git_diff(uncommitted_changes=True)
    filtered, _ = filter_functions(
        modified_functions=modified,
        tests_root=server.optimizer.test_cfg.tests_root,
        ignore_paths=[],
        project_root=server.optimizer.args.project_root,
        module_root=server.optimizer.args.module_root,
        previous_checkpoint_functions={},
    )
    # Flatten the per-file mapping into a single list of qualified names.
    qualified_names: list[str] = []
    for funcs in filtered.values():
        qualified_names.extend(func.qualified_name for func in funcs)
    return {"functions": qualified_names, "status": "success"}
@server.feature("getOptimizableFunctions")
def get_optimizable_functions(
server: CodeflashLanguageServer, params: OptimizableFunctionsParams
@ -45,44 +62,21 @@ def get_optimizable_functions(
file_path = Path(uris.to_fs_path(params.textDocument.uri))
server.show_message_log(f"Getting optimizable functions for: {file_path}", "Info")
# Save original args to restore later
original_file = getattr(server.optimizer.args, "file", None)
original_function = getattr(server.optimizer.args, "function", None)
original_checkpoint = getattr(server.optimizer.args, "previous_checkpoint_functions", None)
server.optimizer.args.file = file_path
server.optimizer.args.function = None # Always get ALL functions, not just one
server.optimizer.args.previous_checkpoint_functions = False
server.show_message_log(f"Original args - file: {original_file}, function: {original_function}", "Info")
server.show_message_log(f"Calling get_optimizable_functions for {server.optimizer.args.file}...", "Info")
optimizable_funcs, _, _ = server.optimizer.get_optimizable_functions()
try:
# Set temporary args for this request only
server.optimizer.args.file = file_path
server.optimizer.args.function = None # Always get ALL functions, not just one
server.optimizer.args.previous_checkpoint_functions = False
path_to_qualified_names = {}
for functions in optimizable_funcs.values():
path_to_qualified_names[file_path] = [func.qualified_name for func in functions]
server.show_message_log("Calling get_optimizable_functions...", "Info")
optimizable_funcs, _, _ = server.optimizer.get_optimizable_functions()
path_to_qualified_names = {}
for path, functions in optimizable_funcs.items():
path_to_qualified_names[path.as_posix()] = [func.qualified_name for func in functions]
server.show_message_log(
f"Found {len(path_to_qualified_names)} files with functions: {path_to_qualified_names}", "Info"
)
return path_to_qualified_names
finally:
# Restore original args to prevent state corruption
if original_file is not None:
server.optimizer.args.file = original_file
if original_function is not None:
server.optimizer.args.function = original_function
else:
server.optimizer.args.function = None
if original_checkpoint is not None:
server.optimizer.args.previous_checkpoint_functions = original_checkpoint
server.show_message_log(
f"Restored args - file: {server.optimizer.args.file}, function: {server.optimizer.args.function}", "Info"
)
server.show_message_log(
f"Found {len(path_to_qualified_names)} files with functions: {path_to_qualified_names}", "Info"
)
return path_to_qualified_names
@server.feature("initializeFunctionOptimization")
@ -91,10 +85,15 @@ def initialize_function_optimization(
) -> dict[str, str]:
file_path = Path(uris.to_fs_path(params.textDocument.uri))
server.show_message_log(f"Initializing optimization for function: {params.functionName} in {file_path}", "Info")
if server.optimizer is None:
_initialize_optimizer_if_valid(server)
server.optimizer.worktree_mode()
original_args, _ = server.optimizer.original_args_and_test_cfg
# IMPORTANT: Store the specific function for optimization, but don't corrupt global state
server.optimizer.args.function = params.functionName
server.optimizer.args.file = file_path
original_relative_file_path = file_path.relative_to(original_args.project_root)
server.optimizer.args.file = server.optimizer.current_worktree / original_relative_file_path
server.optimizer.args.previous_checkpoint_functions = False
server.show_message_log(
f"Args set - function: {server.optimizer.args.function}, file: {server.optimizer.args.file}", "Info"
@ -103,7 +102,12 @@ def initialize_function_optimization(
optimizable_funcs, _, _ = server.optimizer.get_optimizable_functions()
if not optimizable_funcs:
server.show_message_log(f"No optimizable functions found for {params.functionName}", "Warning")
return {"functionName": params.functionName, "status": "not found", "args": None}
return {
"functionName": params.functionName,
"status": "error",
"message": "function is not found or not optimizable",
"args": None,
}
fto = optimizable_funcs.popitem()[1][0]
server.optimizer.current_function_being_optimized = fto
@ -172,177 +176,135 @@ def provide_api_key(server: CodeflashLanguageServer, params: ProvideApiKeyParams
return {"status": "error", "message": "something went wrong while saving the api key"}
@server.feature("prepareOptimization")
def prepare_optimization(server: CodeflashLanguageServer, params: FunctionOptimizationParams) -> dict[str, str]:
current_function = server.optimizer.current_function_being_optimized
module_prep_result = server.optimizer.prepare_module_for_optimization(current_function.file_path)
validated_original_code, original_module_ast = module_prep_result
function_optimizer = server.optimizer.create_function_optimizer(
current_function,
function_to_optimize_source_code=validated_original_code[current_function.file_path].source_code,
original_module_ast=original_module_ast,
original_module_path=current_function.file_path,
)
server.optimizer.current_function_optimizer = function_optimizer
if not function_optimizer:
return {"functionName": params.functionName, "status": "error", "message": "No function optimizer found"}
initialization_result = function_optimizer.can_be_optimized()
if not is_successful(initialization_result):
return {"functionName": params.functionName, "status": "error", "message": initialization_result.failure()}
return {"functionName": params.functionName, "status": "success", "message": "Optimization preparation completed"}
@server.feature("generateTests")
def generate_tests(server: CodeflashLanguageServer, params: FunctionOptimizationParams) -> dict[str, str]:
function_optimizer = server.optimizer.current_function_optimizer
if not function_optimizer:
return {"functionName": params.functionName, "status": "error", "message": "No function optimizer found"}
initialization_result = function_optimizer.can_be_optimized()
if not is_successful(initialization_result):
return {"functionName": params.functionName, "status": "error", "message": initialization_result.failure()}
should_run_experiment, code_context, original_helper_code = initialization_result.unwrap()
test_setup_result = function_optimizer.generate_and_instrument_tests(
code_context, should_run_experiment=should_run_experiment
)
if not is_successful(test_setup_result):
return {"functionName": params.functionName, "status": "error", "message": test_setup_result.failure()}
generated_tests_list: GeneratedTestsList
optimizations_set: OptimizationSet
generated_tests_list, _, concolic__test_str, optimizations_set = test_setup_result.unwrap()
generated_tests: list[str] = [
generated_test.generated_original_test_source for generated_test in generated_tests_list.generated_tests
]
optimizations_dict = {
candidate.optimization_id: {"source_code": candidate.source_code.markdown, "explanation": candidate.explanation}
for candidate in optimizations_set.control + optimizations_set.experiment
}
return {
"functionName": params.functionName,
"status": "success",
"message": {"generated_tests": generated_tests, "optimizations": optimizations_dict},
}
@server.feature("performFunctionOptimization")
def perform_function_optimization( # noqa: PLR0911
server: CodeflashLanguageServer, params: FunctionOptimizationParams
) -> dict[str, str]:
server.show_message_log(f"Starting optimization for function: {params.functionName}", "Info")
current_function = server.optimizer.current_function_being_optimized
try:
server.show_message_log(f"Starting optimization for function: {params.functionName}", "Info")
current_function = server.optimizer.current_function_being_optimized
if not current_function:
server.show_message_log(f"No current function being optimized for {params.functionName}", "Error")
return {
"functionName": params.functionName,
"status": "error",
"message": "No function currently being optimized",
}
if not current_function:
server.show_message_log(f"No current function being optimized for {params.functionName}", "Error")
return {
"functionName": params.functionName,
"status": "error",
"message": "No function currently being optimized",
}
module_prep_result = server.optimizer.prepare_module_for_optimization(current_function.file_path)
module_prep_result = server.optimizer.prepare_module_for_optimization(current_function.file_path)
validated_original_code, original_module_ast = module_prep_result
validated_original_code, original_module_ast = module_prep_result
function_optimizer = server.optimizer.create_function_optimizer(
current_function,
function_to_optimize_source_code=validated_original_code[current_function.file_path].source_code,
original_module_ast=original_module_ast,
original_module_path=current_function.file_path,
function_to_tests=server.optimizer.discovered_tests or {},
)
server.optimizer.current_function_optimizer = function_optimizer
if not function_optimizer:
return {"functionName": params.functionName, "status": "error", "message": "No function optimizer found"}
initialization_result = function_optimizer.can_be_optimized()
if not is_successful(initialization_result):
return {"functionName": params.functionName, "status": "error", "message": initialization_result.failure()}
should_run_experiment, code_context, original_helper_code = initialization_result.unwrap()
test_setup_result = function_optimizer.generate_and_instrument_tests(
code_context, should_run_experiment=should_run_experiment
)
if not is_successful(test_setup_result):
return {"functionName": params.functionName, "status": "error", "message": test_setup_result.failure()}
(
generated_tests,
function_to_concolic_tests,
concolic_test_str,
optimizations_set,
generated_test_paths,
generated_perf_test_paths,
instrumented_unittests_created_for_function,
original_conftest_content,
) = test_setup_result.unwrap()
baseline_setup_result = function_optimizer.setup_and_establish_baseline(
code_context=code_context,
original_helper_code=original_helper_code,
function_to_concolic_tests=function_to_concolic_tests,
generated_test_paths=generated_test_paths,
generated_perf_test_paths=generated_perf_test_paths,
instrumented_unittests_created_for_function=instrumented_unittests_created_for_function,
original_conftest_content=original_conftest_content,
)
if not is_successful(baseline_setup_result):
return {"functionName": params.functionName, "status": "error", "message": baseline_setup_result.failure()}
(
function_to_optimize_qualified_name,
function_to_all_tests,
original_code_baseline,
test_functions_to_remove,
file_path_to_helper_classes,
) = baseline_setup_result.unwrap()
best_optimization = function_optimizer.find_and_process_best_optimization(
optimizations_set=optimizations_set,
code_context=code_context,
original_code_baseline=original_code_baseline,
original_helper_code=original_helper_code,
file_path_to_helper_classes=file_path_to_helper_classes,
function_to_optimize_qualified_name=function_to_optimize_qualified_name,
function_to_all_tests=function_to_all_tests,
generated_tests=generated_tests,
test_functions_to_remove=test_functions_to_remove,
concolic_test_str=concolic_test_str,
)
if not best_optimization:
server.show_message_log(
f"No best optimizations found for function {function_to_optimize_qualified_name}", "Warning"
function_optimizer = server.optimizer.create_function_optimizer(
current_function,
function_to_optimize_source_code=validated_original_code[current_function.file_path].source_code,
original_module_ast=original_module_ast,
original_module_path=current_function.file_path,
function_to_tests=server.optimizer.discovered_tests or {},
)
server.optimizer.current_function_optimizer = function_optimizer
if not function_optimizer:
return {"functionName": params.functionName, "status": "error", "message": "No function optimizer found"}
initialization_result = function_optimizer.can_be_optimized()
if not is_successful(initialization_result):
return {"functionName": params.functionName, "status": "error", "message": initialization_result.failure()}
should_run_experiment, code_context, original_helper_code = initialization_result.unwrap()
test_setup_result = function_optimizer.generate_and_instrument_tests(
code_context, should_run_experiment=should_run_experiment
)
if not is_successful(test_setup_result):
return {"functionName": params.functionName, "status": "error", "message": test_setup_result.failure()}
(
generated_tests,
function_to_concolic_tests,
concolic_test_str,
optimizations_set,
generated_test_paths,
generated_perf_test_paths,
instrumented_unittests_created_for_function,
original_conftest_content,
) = test_setup_result.unwrap()
baseline_setup_result = function_optimizer.setup_and_establish_baseline(
code_context=code_context,
original_helper_code=original_helper_code,
function_to_concolic_tests=function_to_concolic_tests,
generated_test_paths=generated_test_paths,
generated_perf_test_paths=generated_perf_test_paths,
instrumented_unittests_created_for_function=instrumented_unittests_created_for_function,
original_conftest_content=original_conftest_content,
)
if not is_successful(baseline_setup_result):
return {"functionName": params.functionName, "status": "error", "message": baseline_setup_result.failure()}
(
function_to_optimize_qualified_name,
function_to_all_tests,
original_code_baseline,
test_functions_to_remove,
file_path_to_helper_classes,
) = baseline_setup_result.unwrap()
best_optimization = function_optimizer.find_and_process_best_optimization(
optimizations_set=optimizations_set,
code_context=code_context,
original_code_baseline=original_code_baseline,
original_helper_code=original_helper_code,
file_path_to_helper_classes=file_path_to_helper_classes,
function_to_optimize_qualified_name=function_to_optimize_qualified_name,
function_to_all_tests=function_to_all_tests,
generated_tests=generated_tests,
test_functions_to_remove=test_functions_to_remove,
concolic_test_str=concolic_test_str,
)
if not best_optimization:
server.show_message_log(
f"No best optimizations found for function {function_to_optimize_qualified_name}", "Warning"
)
return {
"functionName": params.functionName,
"status": "error",
"message": f"No best optimizations found for function {function_to_optimize_qualified_name}",
}
# generate a patch for the optimization
relative_file_paths = [code_string.file_path for code_string in code_context.read_writable_code.code_strings]
patch_file = create_diff_patch_from_worktree(
server.optimizer.current_worktree,
relative_file_paths,
server.optimizer.current_function_optimizer.function_to_optimize.qualified_name,
)
optimized_source = best_optimization.candidate.source_code.markdown
speedup = original_code_baseline.runtime / best_optimization.runtime
server.show_message_log(f"Optimization completed for {params.functionName} with {speedup:.2f}x speedup", "Info")
return {
"functionName": params.functionName,
"status": "error",
"message": f"No best optimizations found for function {function_to_optimize_qualified_name}",
"status": "success",
"message": "Optimization completed successfully",
"extra": f"Speedup: {speedup:.2f}x faster",
"optimization": optimized_source,
"patch_file": str(patch_file),
"explanation": best_optimization.explanation_v2,
}
finally:
cleanup_the_optimizer(server)
optimized_source = best_optimization.candidate.source_code.markdown
speedup = original_code_baseline.runtime / best_optimization.runtime
server.show_message_log(f"Optimization completed for {params.functionName} with {speedup:.2f}x speedup", "Info")
# CRITICAL: Clear the function filter after optimization to prevent state corruption
def cleanup_the_optimizer(server: CodeflashLanguageServer) -> None:
server.optimizer.cleanup_temporary_paths()
# restore args and test cfg
if server.optimizer.original_args_and_test_cfg:
server.optimizer.args, server.optimizer.test_cfg = server.optimizer.original_args_and_test_cfg
server.optimizer.args.function = None
server.show_message_log("Cleared function filter to prevent state corruption", "Info")
return {
"functionName": params.functionName,
"status": "success",
"message": "Optimization completed successfully",
"extra": f"Speedup: {speedup:.2f}x faster",
"optimization": optimized_source,
}
server.optimizer.current_worktree = None
server.optimizer.current_function_optimizer = None

7
codeflash/lsp/helpers.py Normal file
View file

@ -0,0 +1,7 @@
import os
from functools import lru_cache
@lru_cache(maxsize=1)
def is_LSP_enabled() -> bool:
return os.getenv("CODEFLASH_LSP", default="false").lower() == "true"

View file

@ -11,6 +11,8 @@ from pygls.server import LanguageServer
if TYPE_CHECKING:
from lsprotocol.types import InitializeParams, InitializeResult
from codeflash.optimization.optimizer import Optimizer
class CodeflashLanguageServerProtocol(LanguageServerProtocol):
_server: CodeflashLanguageServer
@ -26,7 +28,6 @@ class CodeflashLanguageServerProtocol(LanguageServerProtocol):
pyproject_toml_path = self._find_pyproject_toml(workspace_path)
if pyproject_toml_path:
server.prepare_optimizer_arguments(pyproject_toml_path)
server.show_message(f"Found pyproject.toml at: {pyproject_toml_path}")
else:
server.show_message("No pyproject.toml found in workspace.")
else:
@ -44,7 +45,7 @@ class CodeflashLanguageServerProtocol(LanguageServerProtocol):
class CodeflashLanguageServer(LanguageServer):
def __init__(self, *args: Any, **kwargs: Any) -> None: # noqa: ANN401
super().__init__(*args, **kwargs)
self.optimizer = None
self.optimizer: Optimizer | None = None
self.args = None
def prepare_optimizer_arguments(self, config_file: Path) -> None:
@ -53,6 +54,7 @@ class CodeflashLanguageServer(LanguageServer):
args = parse_args()
args.config_file = config_file
args.no_pr = True # LSP server should not create PRs
args.worktree = True
args = process_pyproject_config(args)
self.args = args
# avoid initializing the optimizer during initialization, because it can cause an error if the api key is invalid

View file

@ -21,7 +21,8 @@ def setup_logging() -> logging.Logger:
# Set up stderr handler for VS Code output channel with [LSP-Server] prefix
handler = logging.StreamHandler(sys.stderr)
handler.setFormatter(logging.Formatter("[LSP-Server] %(asctime)s [%(levelname)s]: %(message)s"))
# adding the :::: here for the client to easily extract the message from the log
handler.setFormatter(logging.Formatter("[LSP-Server] %(asctime)s [%(levelname)s]::::%(message)s"))
# Configure root logger
root_logger.addHandler(handler)

View file

@ -91,6 +91,7 @@ class FunctionSource:
class BestOptimization(BaseModel):
candidate: OptimizedCandidate
explanation_v2: Optional[str] = None
helper_functions: list[FunctionSource]
code_context: CodeOptimizationContext
runtime: int

View file

@ -380,6 +380,7 @@ class FunctionOptimizer:
console.rule()
candidates = deque(candidates)
refinement_done = False
line_profiler_done = False
future_all_refinements: list[concurrent.futures.Future] = []
ast_code_to_id = {}
valid_optimizations = []
@ -400,19 +401,45 @@ class FunctionOptimizer:
if self.experiment_id
else None,
)
try:
candidate_index = 0
original_len = len(candidates)
while candidates:
candidate_index = 0
original_len = len(candidates)
# TODO : We need to rewrite this candidate loop as a class, the container which has candidates receives new candidates at unknown times due to the async nature of lp and refinement calls,
# TODO : in addition, the refinement calls depend on line profiler calls being complete so we need to check that reliably
while True:
try:
if len(candidates) > 0:
candidate = candidates.popleft()
else:
if not line_profiler_done:
logger.debug("all candidates processed, await candidates from line profiler")
concurrent.futures.wait([future_line_profile_results])
line_profile_results = future_line_profile_results.result()
candidates.extend(line_profile_results)
original_len += len(line_profile_results)
logger.info(
f"Added results from line profiler to candidates, total candidates now: {original_len}"
)
line_profiler_done = True
continue
if line_profiler_done and not refinement_done:
concurrent.futures.wait(future_all_refinements)
refinement_response = []
for future_refinement in future_all_refinements:
possible_refinement = future_refinement.result()
if len(possible_refinement) > 0: # if the api returns a valid response
refinement_response.append(possible_refinement[0])
candidates.extend(refinement_response)
original_len += len(refinement_response)
logger.info(
f"Added {len(refinement_response)} candidates from refinement, total candidates now: {original_len}"
)
refinement_done = True
continue
if line_profiler_done and refinement_done:
logger.debug("everything done, exiting")
break
candidate_index += 1
line_profiler_done = True if future_line_profile_results is None else future_line_profile_results.done()
if line_profiler_done and (future_line_profile_results is not None):
line_profile_results = future_line_profile_results.result()
candidates.extend(line_profile_results)
original_len += len(line_profile_results)
logger.info(f"Added results from line profiler to candidates, total candidates now: {original_len}")
future_line_profile_results = None
candidate = candidates.popleft()
get_run_tmp_file(Path(f"test_return_values_{candidate_index}.bin")).unlink(missing_ok=True)
get_run_tmp_file(Path(f"test_return_values_{candidate_index}.sqlite")).unlink(missing_ok=True)
logger.info(f"Optimization candidate {candidate_index}/{original_len}:")
@ -440,6 +467,9 @@ class FunctionOptimizer:
# check if this code has been evaluated before by checking the ast normalized code string
normalized_code = ast.unparse(ast.parse(candidate.source_code.flat.strip()))
if normalized_code in ast_code_to_id:
logger.warning(
"Current candidate has been encountered before in testing, Skipping optimization candidate."
)
past_opt_id = ast_code_to_id[normalized_code]["optimization_id"]
# update speedup ratio, is_correct, optimizations_post, optimized_line_profiler_results, optimized_runtimes
speedup_ratios[candidate.optimization_id] = speedup_ratios[past_opt_id]
@ -471,7 +501,6 @@ class FunctionOptimizer:
file_path_to_helper_classes=file_path_to_helper_classes,
)
console.rule()
if not is_successful(run_results):
optimized_runtimes[candidate.optimization_id] = None
is_correct[candidate.optimization_id] = False
@ -525,7 +554,6 @@ class FunctionOptimizer:
optimized_runtime_ns=candidate_replay_runtime,
)
benchmark_tree.add(f"{benchmark_key}: {replay_perf_gain[benchmark_key] * 100:.1f}%")
best_optimization = BestOptimization(
candidate=candidate,
helper_functions=code_context.helper_functions,
@ -568,36 +596,12 @@ class FunctionOptimizer:
self.write_code_and_helpers(
self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
)
if (
(not len(candidates)) and (not line_profiler_done)
): # all original candidates processed but lp results haven't been processed, doesn't matter at the moment if we're done refining or not
concurrent.futures.wait([future_line_profile_results])
line_profile_results = future_line_profile_results.result()
candidates.extend(line_profile_results)
original_len += len(line_profile_results)
logger.info(f"Added results from line profiler to candidates, total candidates now: {original_len}")
future_line_profile_results = None
# all original candidates and lp candidates processed, collect refinement candidates and append to candidate list
if (not len(candidates)) and line_profiler_done and not refinement_done:
# waiting just in case not all calls are finished, nothing else to do
concurrent.futures.wait(future_all_refinements)
refinement_response = []
for future_refinement in future_all_refinements:
possible_refinement = future_refinement.result()
if len(possible_refinement) > 0: # if the api returns a valid response
refinement_response.append(possible_refinement[0])
candidates.extend(refinement_response)
logger.info(f"Added {len(refinement_response)} candidates from refinement")
original_len += len(refinement_response)
refinement_done = True
except KeyboardInterrupt as e:
self.write_code_and_helpers(
self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
)
logger.exception(f"Optimization interrupted: {e}")
raise
except KeyboardInterrupt as e:
self.write_code_and_helpers(
self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
)
logger.exception(f"Optimization interrupted: {e}")
raise
if not valid_optimizations:
return None
# need to figure out the best candidate here before we return best_optimization
@ -741,7 +745,9 @@ class FunctionOptimizer:
file_to_code_context = optimized_context.file_to_path()
optimized_code = file_to_code_context.get(str(path.relative_to(self.project_root)), "")
new_code = format_code(self.args.formatter_cmds, path, optimized_code=optimized_code, check_diff=True)
new_code = format_code(
self.args.formatter_cmds, path, optimized_code=optimized_code, check_diff=True, exit_on_failure=False
)
if should_sort_imports:
new_code = sort_imports(new_code)
@ -750,7 +756,11 @@ class FunctionOptimizer:
module_abspath = hp.file_path
hp_source_code = hp.source_code
formatted_helper_code = format_code(
self.args.formatter_cmds, module_abspath, optimized_code=hp_source_code, check_diff=True
self.args.formatter_cmds,
module_abspath,
optimized_code=hp_source_code,
check_diff=True,
exit_on_failure=False,
)
if should_sort_imports:
formatted_helper_code = sort_imports(formatted_helper_code)
@ -1099,11 +1109,6 @@ class FunctionOptimizer:
if best_optimization:
logger.info("Best candidate:")
code_print(best_optimization.candidate.source_code.flat)
console.print(
Panel(
best_optimization.candidate.explanation, title="Best Candidate Explanation", border_style="blue"
)
)
processed_benchmark_info = None
if self.args.benchmark:
processed_benchmark_info = process_benchmark_data(
@ -1153,7 +1158,6 @@ class FunctionOptimizer:
original_helper_code,
code_context,
)
self.log_successful_optimization(explanation, generated_tests, exp_type)
return best_optimization
def process_review(
@ -1227,6 +1231,10 @@ class FunctionOptimizer:
file_path=explanation.file_path,
benchmark_details=explanation.benchmark_details,
)
self.log_successful_optimization(new_explanation, generated_tests, exp_type)
best_optimization.explanation_v2 = new_explanation.explanation_message()
data = {
"original_code": original_code_combined,
"new_code": new_code_combined,
@ -1239,25 +1247,62 @@ class FunctionOptimizer:
"coverage_message": coverage_message,
"replay_tests": replay_tests,
"concolic_tests": concolic_tests,
"root_dir": self.project_root,
}
if not self.args.no_pr and not self.args.staging_review:
raise_pr = not self.args.no_pr
if raise_pr and not self.args.staging_review:
data["git_remote"] = self.args.git_remote
check_create_pr(**data)
elif self.args.staging_review:
create_staging(**data)
response = create_staging(**data)
if response.status_code == 200:
staging_url = f"https://app.codeflash.ai/review-optimizations/{self.function_trace_id[:-4] + exp_type if self.experiment_id else self.function_trace_id}"
console.print(
Panel(
f"[bold green]✅ Staging created:[/bold green]\n[link={staging_url}]{staging_url}[/link]",
title="Staging Link",
border_style="green",
)
)
else:
console.print(
Panel(
f"[bold red]❌ Failed to create staging[/bold red]\nStatus: {response.status_code}",
title="Staging Error",
border_style="red",
)
)
else:
# Mark optimization success since no PR will be created
mark_optimization_success(
trace_id=self.function_trace_id, is_optimization_found=best_optimization is not None
)
if ((not self.args.no_pr) or not self.args.staging_review) and (
self.args.all or env_utils.get_pr_number() or (self.args.file and not self.args.function)
# If worktree mode, do not revert code and helpers,, otherwise we would have an empty diff when writing the patch in the lsp
if self.args.worktree:
return
if raise_pr and (
self.args.all
or env_utils.get_pr_number()
or self.args.replay_test
or (self.args.file and not self.args.function)
):
self.write_code_and_helpers(
self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
)
self.revert_code_and_helpers(original_helper_code)
return
if self.args.staging_review:
# always revert code and helpers when staging review
self.revert_code_and_helpers(original_helper_code)
return
def revert_code_and_helpers(self, original_helper_code: dict[Path, str]) -> None:
self.write_code_and_helpers(
self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
)
def establish_original_code_baseline(
self,

View file

@ -1,6 +1,7 @@
from __future__ import annotations
import ast
import copy
import os
import tempfile
import time
@ -14,6 +15,13 @@ from codeflash.cli_cmds.console import console, logger, progress_bar
from codeflash.code_utils import env_utils
from codeflash.code_utils.code_utils import cleanup_paths, get_run_tmp_file
from codeflash.code_utils.env_utils import get_pr_number, is_pr_draft
from codeflash.code_utils.git_utils import (
check_running_in_git_repo,
create_detached_worktree,
create_diff_patch_from_worktree,
create_worktree_snapshot_commit,
remove_worktree,
)
from codeflash.either import is_successful
from codeflash.models.models import ValidCode
from codeflash.telemetry.posthog_cf import ph
@ -48,6 +56,9 @@ class Optimizer:
self.functions_checkpoint: CodeflashRunCheckpoint | None = None
self.current_function_being_optimized: FunctionToOptimize | None = None # current only for the LSP
self.current_function_optimizer: FunctionOptimizer | None = None
self.current_worktree: Path | None = None
self.original_args_and_test_cfg: tuple[Namespace, TestConfig] | None = None
self.patch_files: list[Path] = []
def run_benchmarks(
self, file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]], num_optimizable_functions: int
@ -252,6 +263,10 @@ class Optimizer:
if self.args.no_draft and is_pr_draft():
logger.warning("PR is in draft mode, skipping optimization")
return
if self.args.worktree:
self.worktree_mode()
cleanup_paths(Optimizer.find_leftover_instrumented_test_files(self.test_cfg.tests_root))
function_optimizer = None
@ -260,7 +275,6 @@ class Optimizer:
file_to_funcs_to_optimize, num_optimizable_functions
)
optimizations_found: int = 0
function_iterator_count: int = 0
if self.args.test_framework == "pytest":
self.test_cfg.concolic_test_root_dir = Path(
tempfile.mkdtemp(dir=self.args.tests_root, prefix="codeflash_concolic_")
@ -296,8 +310,8 @@ class Optimizer:
except Exception as e:
logger.debug(f"Could not rank functions in {original_module_path}: {e}")
for function_to_optimize in functions_to_optimize:
function_iterator_count += 1
for i, function_to_optimize in enumerate(functions_to_optimize):
function_iterator_count = i + 1
logger.info(
f"Optimizing function {function_iterator_count} of {num_optimizable_functions}: "
f"{function_to_optimize.qualified_name}"
@ -327,6 +341,23 @@ class Optimizer:
)
if is_successful(best_optimization):
optimizations_found += 1
# create a diff patch for successful optimization
if self.current_worktree:
read_writable_code = best_optimization.unwrap().code_context.read_writable_code
relative_file_paths = [
code_string.file_path for code_string in read_writable_code.code_strings
]
patch_path = create_diff_patch_from_worktree(
self.current_worktree,
relative_file_paths,
self.current_function_optimizer.function_to_optimize.qualified_name,
)
self.patch_files.append(patch_path)
if i < len(functions_to_optimize) - 1:
create_worktree_snapshot_commit(
self.current_worktree,
f"Optimizing {functions_to_optimize[i + 1].qualified_name}",
)
else:
logger.warning(best_optimization.failure())
console.rule()
@ -337,6 +368,10 @@ class Optimizer:
function_optimizer.cleanup_generated_files()
ph("cli-optimize-run-finished", {"optimizations_found": optimizations_found})
if len(self.patch_files) > 0:
logger.info(
f"Created {len(self.patch_files)} patch(es) ({[str(patch_path) for patch_path in self.patch_files]})"
)
if self.functions_checkpoint:
self.functions_checkpoint.cleanup()
if hasattr(self.args, "command") and self.args.command == "optimize":
@ -382,14 +417,60 @@ class Optimizer:
cleanup_paths([self.replay_tests_dir])
def cleanup_temporary_paths(self) -> None:
if self.current_function_optimizer:
self.current_function_optimizer.cleanup_generated_files()
if hasattr(get_run_tmp_file, "tmpdir"):
get_run_tmp_file.tmpdir.cleanup()
del get_run_tmp_file.tmpdir
if self.current_worktree:
remove_worktree(self.current_worktree)
return
if self.current_function_optimizer:
self.current_function_optimizer.cleanup_generated_files()
cleanup_paths([self.test_cfg.concolic_test_root_dir, self.replay_tests_dir])
def worktree_mode(self) -> None:
if self.current_worktree:
return
if check_running_in_git_repo(self.args.module_root):
worktree_dir = create_detached_worktree(self.args.module_root)
if worktree_dir is None:
logger.warning("Failed to create worktree. Skipping optimization.")
return
self.current_worktree = worktree_dir
self.mutate_args_for_worktree_mode(worktree_dir)
def mutate_args_for_worktree_mode(self, worktree_dir: Path) -> None:
saved_args = copy.deepcopy(self.args)
saved_test_cfg = copy.deepcopy(self.test_cfg)
self.original_args_and_test_cfg = (saved_args, saved_test_cfg)
project_root = self.args.project_root
module_root = self.args.module_root
relative_module_root = module_root.relative_to(project_root)
relative_optimized_file = self.args.file.relative_to(project_root) if self.args.file else None
relative_tests_root = self.test_cfg.tests_root.relative_to(project_root)
relative_benchmarks_root = (
self.args.benchmarks_root.relative_to(project_root) if self.args.benchmarks_root else None
)
self.args.module_root = worktree_dir / relative_module_root
self.args.project_root = worktree_dir
self.args.test_project_root = worktree_dir
self.args.tests_root = worktree_dir / relative_tests_root
if relative_benchmarks_root:
self.args.benchmarks_root = worktree_dir / relative_benchmarks_root
self.test_cfg.project_root_path = worktree_dir
self.test_cfg.tests_project_rootdir = worktree_dir
self.test_cfg.tests_root = worktree_dir / relative_tests_root
if relative_benchmarks_root:
self.test_cfg.benchmark_tests_root = worktree_dir / relative_benchmarks_root
if relative_optimized_file is not None:
self.args.file = worktree_dir / relative_optimized_file
def run_with_args(args: Namespace) -> None:
optimizer = None

View file

@ -10,12 +10,7 @@ from codeflash.api import cfapi
from codeflash.cli_cmds.console import console, logger
from codeflash.code_utils import env_utils
from codeflash.code_utils.code_replacer import is_zero_diff
from codeflash.code_utils.git_utils import (
check_and_push_branch,
get_current_branch,
get_repo_owner_and_name,
git_root_dir,
)
from codeflash.code_utils.git_utils import check_and_push_branch, get_current_branch, get_repo_owner_and_name
from codeflash.code_utils.github_utils import github_pr_url
from codeflash.code_utils.tabulate import tabulate
from codeflash.code_utils.time_utils import format_perf, format_time
@ -188,6 +183,7 @@ def check_create_pr(
coverage_message: str,
replay_tests: str,
concolic_tests: str,
root_dir: Path,
git_remote: Optional[str] = None,
) -> None:
pr_number: Optional[int] = env_utils.get_pr_number()
@ -196,9 +192,9 @@ def check_create_pr(
if pr_number is not None:
logger.info(f"Suggesting changes to PR #{pr_number} ...")
owner, repo = get_repo_owner_and_name(git_repo)
relative_path = explanation.file_path.relative_to(git_root_dir()).as_posix()
relative_path = explanation.file_path.relative_to(root_dir).as_posix()
build_file_changes = {
Path(p).relative_to(git_root_dir()).as_posix(): FileDiffContent(
Path(p).relative_to(root_dir).as_posix(): FileDiffContent(
oldContent=original_code[p], newContent=new_code[p]
)
for p in original_code
@ -247,10 +243,10 @@ def check_create_pr(
if not check_and_push_branch(git_repo, git_remote, wait_for_push=True):
logger.warning("⏭️ Branch is not pushed, skipping PR creation...")
return
relative_path = explanation.file_path.relative_to(git_root_dir()).as_posix()
relative_path = explanation.file_path.relative_to(root_dir).as_posix()
base_branch = get_current_branch()
build_file_changes = {
Path(p).relative_to(git_root_dir()).as_posix(): FileDiffContent(
Path(p).relative_to(root_dir).as_posix(): FileDiffContent(
oldContent=original_code[p], newContent=new_code[p]
)
for p in original_code

View file

@ -10,9 +10,9 @@ from typing import TYPE_CHECKING
from codeflash.cli_cmds.console import console, logger
from codeflash.code_utils.compat import SAFE_SYS_EXECUTABLE
from codeflash.code_utils.concolic_utils import clean_concolic_tests
from codeflash.code_utils.env_utils import is_LSP_enabled
from codeflash.code_utils.static_analysis import has_typed_parameters
from codeflash.discovery.discover_unit_tests import discover_unit_tests
from codeflash.lsp.helpers import is_LSP_enabled
from codeflash.telemetry.posthog_cf import ph
from codeflash.verification.verification_utils import TestConfig

View file

@ -1,2 +1,2 @@
# These version placeholders will be replaced by uv-dynamic-versioning during build.
__version__ = "0.16.3"
__version__ = "0.16.6"

View file

@ -57,7 +57,7 @@ Codeflash runs tests for the target function using either pytest or unittest fra
#### Performance benchmarking
Codeflash implements [several techniques](/codeflash-concepts/benchmarking.md) to measure code performance accurately. In particular, it runs multiple iterations of the code in a loop to determine the best performance with the minimum runtime. Codeflash compares the performance of the original code against the optimization, requiring at least a 10% speed improvement before considering it to be faster. This approach eliminates most runtime measurement variability, even on noisy CI systems and virtual machines. The final runtime Codeflash reports is the minimum total time it took to run all the test cases.
Codeflash implements [several techniques](/codeflash-concepts/benchmarking) to measure code performance accurately. In particular, it runs multiple iterations of the code in a loop to determine the best performance with the minimum runtime. Codeflash compares the performance of the original code against the optimization, requiring at least a 10% speed improvement before considering it to be faster. This approach eliminates most runtime measurement variability, even on noisy CI systems and virtual machines. The final runtime Codeflash reports is the minimum total time it took to run all the test cases.
## Creating Pull Requests

View file

@ -23,24 +23,25 @@
"pages": ["index"]
},
{
"group": "🚀 Getting Started",
"pages": ["getting-started/local-installation"]
},
{
"group": "🔧 Continuous Optimization",
"group": "🚀 Quickstart",
"pages": [
"getting-started/codeflash-github-actions",
"optimizing-with-codeflash/optimize-prs",
"optimizing-with-codeflash/benchmarking"
]
},
"getting-started/local-installation"
] },
{
"group": "⚡ Optimization Workflows",
"group": "⚡ Optimizing with Codeflash",
"pages": [
"optimizing-with-codeflash/one-function",
"optimizing-with-codeflash/trace-and-optimize",
"optimizing-with-codeflash/codeflash-all",
"optimizing-with-codeflash/staging"
"optimizing-with-codeflash/codeflash-all"
]
},
{
"group": "✨ Continuous Optimization",
"pages": [
"optimizing-with-codeflash/review-optimizations",
"optimizing-with-codeflash/codeflash-github-actions",
"optimizing-with-codeflash/benchmarking"
]
},
{
@ -53,21 +54,7 @@
}
]
}
],
"global": {
"anchors": [
{
"anchor": "GitHub",
"href": "https://github.com/codeflash-ai",
"icon": "github"
},
{
"anchor": "Discord",
"href": "https://www.codeflash.ai/discord",
"icon": "discord"
}
]
}
]
},
"logo": {
"light": "/images/codeflash_light.svg",
@ -76,8 +63,14 @@
"navbar": {
"links": [
{
"label": "Contact",
"href": "mailto:contact@codeflash.ai"
"label": "Discord",
"href": "https://www.codeflash.ai/discord",
"icon": "discord"
},
{
"label": "GitHub",
"href": "https://github.com/codeflash-ai/codeflash",
"icon": "github"
},
{
"label": "Blog",
@ -102,7 +95,7 @@
"footer": {
"socials": {
"discord": "https://www.codeflash.ai/discord",
"github": "https://github.com/codeflash-ai",
"github": "https://github.com/codeflash-ai/codeflash",
"linkedin": "https://www.linkedin.com/company/codeflash-ai"
},
"links": [

View file

@ -6,8 +6,6 @@ icon: "download"
Codeflash is installed and configured on a per-project basis.
You can install Codeflash locally for a project by running the following command in the project's virtual environment:
### Prerequisites
Before installing Codeflash, ensure you have:
@ -15,7 +13,9 @@ Before installing Codeflash, ensure you have:
1. **Python 3.9 or above** installed
2. **A Python project** with a virtual environment
3. **Project dependencies installed** in your virtual environment
4. **Tests** (optional) for your code (Codeflash uses tests to verify optimizations)
Good to have (optional):
1. **Unit Tests** that Codeflash uses to ensure correctness of the optimizations
<Warning>
**Virtual Environment Required**
@ -23,14 +23,15 @@ Before installing Codeflash, ensure you have:
Always install Codeflash in your project's virtual environment, not globally. Make sure your virtual environment is activated before proceeding.
```bash
# Example: Activate your virtual environment
source venv/bin/activate # On Linux/Mac
# or
#venv\Scripts\activate # On Windows
venv\Scripts\activate # On Windows
```
</Warning>
<Steps>
<Step title="Install Codeflash">
You can install Codeflash locally for a project by running the following command in the project's virtual environment:
```bash
pip install codeflash
```
@ -39,37 +40,32 @@ pip install codeflash
**Codeflash is a Development Dependency**
We recommend installing Codeflash as a development dependency.
It doesn't need to be installed as part of your package requirements.
Codeflash is intended to be used locally and as part of development workflows such as CI.
Codeflash is intended to be used in development workflows locally and as part of CI.
Try to always use the latest version of Codeflash as it improves quickly.
<CodeGroup>
```toml Poetry
[tool.poetry.dependencies.dev]
codeflash = "^latest"
```
```bash uv
uv add --dev codeflash
```
```bash pip
pip install --dev codeflash
```bash poetry
poetry add codeflash@latest --group dev
```
</CodeGroup>
</Tip>
</Step>
<Step title="Generate a Codeflash API Key">
Codeflash uses cloud-hosted AI models to optimize your code. You'll need an API key to use it.
Codeflash uses cloud-hosted AI models and integrations with GitHub. You'll need an API key to authorize your access.
1. Visit the [Codeflash Web App](https://app.codeflash.ai/)
1. Visit the [Codeflash Web App](https://app.codeflash.ai/)
2. Sign up with your GitHub account (free)
3. Navigate to the [API Key](https://app.codeflash.ai/app/apikeys) page to generate your API key
<Note>
**Free Tier Available**
Codeflash offers a **free tier** with a limited number of optimizations per month. Perfect for trying it out or small projects!
Codeflash offers a **free tier** with a limited number of optimizations. Perfect for trying it out on small projects!
</Note>
</Step>
@ -77,10 +73,6 @@ Codeflash offers a **free tier** with a limited number of optimizations per mont
Navigate to your project's root directory (where your `pyproject.toml` file is or should be) and run:
```bash
# Make sure you're in your project root
cd /path/to/your/project
# Run the initialization
codeflash init
```
@ -89,30 +81,31 @@ If you don't have a pyproject.toml file yet, the codeflash init command will ask
<Info>
**What's pyproject.toml?**
`pyproject.toml` is a configuration file that is used to specify build tool settings for Python projects.
pyproject.toml is the modern replacement for setup.py and requirements.txt files.
It's the new standard for Python package metadata.
`pyproject.toml` is a configuration file that is used to specify build and tool settings for Python projects.
`pyproject.toml` is the modern replacement for setup.py and requirements.txt files.
</Info>
When running `codeflash init`, you will see the following prompts:
```text
1. Enter your Codeflash API key:
2. Which Python module do you want me to optimize going forward? (e.g. my_module)
3. Where are your tests located? (e.g. tests/)
4. Which test framework do you use? (pytest/unittest)
1. Enter your Codeflash API key:
2. Install the GitHub app.
3. Which Python module do you want me to optimize going forward? (e.g. my_module)
4. Where are your tests located? (e.g. tests/)
5. Which test framework do you use? (pytest/unittest)
6. Install GitHub actions for Continuous optimization?
```
</Step>
</Steps>
After you have answered these questions, Codeflash will be configured for your project.
The configuration will be saved in the `pyproject.toml` file in the root directory of your project.
To understand the configuration options, and set more advanced options, see the [Configuration](/configuration) page.
After you have answered these questions, the Codeflash configuration will be saved in the `pyproject.toml` file.
To understand the configuration options, and set more advanced options, see the [Manual Configuration](/configuration) page.
### Step 4: Install the Codeflash GitHub App
{/* TODO: Justify to users Why we need the user to install Github App even in local Installation or local optimization? */}
Finally, if you have not done so already, Codeflash will ask you to install the Github App in your repository. The Codeflash GitHub App allows access to your repository to the codeflash-ai bot to open PRs, review code, and provide optimization suggestions.
Finally, if you have not done so already, Codeflash will ask you to install the GitHub App in your repository.
The Codeflash GitHub App allows access to your repository to the codeflash-ai bot to open PRs, review code, and provide optimization suggestions.
Please [install the Codeflash GitHub
app](https://github.com/apps/codeflash-ai/installations/select_target) by choosing the repository you want to install
@ -128,30 +121,29 @@ Once configured, you can start optimizing your code immediately:
# Optimize a specific function
codeflash --file path/to/your/file.py --function function_name
# Or optimize locally without creating a PR
codeflash --file path/to/your/file.py --function function_name --no-pr
# Or optimize all functions in your codebase
codeflash --all
```
<Tip>
**Pro tip**: Start with a single function to see how Codeflash works before running it on your entire codebase.
</Tip>
</Tab>
<Tab title="Example Project">
<Card title="🚀 Try Our Example Repository" icon="github" href="https://github.com/codeflash-ai/optimize-me">
Want to see Codeflash in action? Check out our **optimize-me** repository with real examples.
<Tab title="Optimize Example Project">
<Card title="🚀 Try optimizing our example repository" icon="github" href="https://github.com/codeflash-ai/optimize-me">
Want to see Codeflash in action and don't know what code to optimize? Check out our **optimize-me** repository with code ready to optimize.
**What's included:**
- Sample Python code with performance issues
- Tests for verification
- Pre-configured `pyproject.toml`
- Before/after optimization examples in PRs
</Card>
<Steps>
<Step title="Clone the Repository">
<Step title="Fork the Repository">
Fork the [optimize-me](https://github.com/codeflash-ai/optimize-me) repo to your GitHub account by clicking "Fork" on the top of the page. This allows Codeflash to open Pull Requests with the optimizations it found on your forked repo.
</Step>
<Step title="Clone the Forked Repository">
```bash
git clone https://github.com/codeflash-ai/optimize-me.git
git clone https://github.com/your_github_username/optimize-me.git
cd optimize-me
```
</Step>
@ -159,7 +151,7 @@ cd optimize-me
<Step title="Set Up Environment">
```bash
python -m venv .venv
source .venv/bin/activate # or venv\Scripts\activate on Windows
source .venv/bin/activate
pip install -r requirements.txt
pip install codeflash
```
@ -193,7 +185,9 @@ codeflash --all # optimize the entire repo
</Accordion>
<Accordion title="🧪 No optimizations found or debugging issues">
Use the `--verbose` flag for detailed output:
Note that not all functions can be optimized, since some simply have no optimization opportunities. This is fine and expected.
To investigate further, use the `--verbose` flag for detailed output:
```bash
codeflash optimize --verbose
```
@ -203,13 +197,7 @@ codeflash --all # optimize the entire repo
- 🚫 Why certain functions were skipped
- ⚠️ Detailed error messages
- 📊 Performance analysis results
<Tip>
**Common Reasons Functions Are Skipped:**
- Function is too simple (less than 3 lines)
- Function has no clear performance bottleneck
- Function contains unsupported constructs
</Tip>
</Accordion>
<Accordion title="🔍 No tests found errors">
@ -223,7 +211,7 @@ codeflash --all # optimize the entire repo
pytest --collect-only
# Check your pyproject.toml configuration
cat pyproject.toml | grep -A 5 "\[tool.codeflash\]"
cat pyproject.toml | grep -A 8 "\[tool.codeflash\]"
```
</Accordion>
</AccordionGroup>
@ -233,5 +221,5 @@ codeflash --all # optimize the entire repo
- Learn about [Codeflash Concepts](/codeflash-concepts/how-codeflash-works)
- Explore [Optimization workflows](/optimizing-with-codeflash/one-function)
- Set up [GitHub Actions integration](/getting-started/codeflash-github-actions)
- Set up [Pull Request Optimization](/optimizing-with-codeflash/codeflash-github-actions)
- Read [configuration options](/configuration) for advanced setups

View file

@ -6,14 +6,11 @@ sidebarTitle: "Best Practices"
keywords: ["best practices", "tips", "github actions", "tracer", "optimization", "workflow"]
---
# Getting the best out of Codeflash
Codeflash is a powerful tool; here are our recommendations based on how the Codeflash team and our customers use Codeflash.
Codeflash is a powerful tool; here are our recommendations based on how the Codeflash team uses Codeflash.
### Install the GitHub App and actions workflow
### Install the Github App and actions workflow
After you install Codeflash on an actively developed project, [installing the GitHub App](getting-started/codeflash-github-actions) and setting up the
GitHub Actions workflow will automatically optimize your code whenever new pull requests are opened. This ensures you get the best version of any changes you make to your code without any extra effort. We find that PRs are also the best time to review these changes, because the code is fresh in your mind.
After you install Codeflash on an actively developed project, [installing the GitHub Actions](optimizing-with-codeflash/codeflash-github-actions) will automatically optimize your code whenever new pull requests are opened. This ensures you get the best version of any changes you make to your code without any extra effort. We find that PRs are also the best time to review these changes, because the code is fresh in your mind.
### Find and optimize entire scripts with the Codeflash Tracer
@ -21,13 +18,18 @@ Find the best results by running [Codeflash Optimize](optimizing-with-codeflash/
This internally runs a profiler, captures inputs to all the functions your script calls, and uses those inputs to create Replay tests and benchmarks.
The optimizations you get with this method, show you how much faster your workflow will get plus guarantee that your workflow won't break if you merge in the optimizations.
###
### Find optimizations on your whole codebase with `codeflash --all`
If you have a lot of existing code, run [`codeflash --all`](optimizing-with-codeflash/codeflash-all) to discover and fix any
slow code in your project. Codeflash will open new pull requests for any optimizations it finds, and you can review and merge them at your own pace.
To achieve higher-quality optimizations with this approach, we recommend first tracing your tests:
```bash
codeflash optimize --trace-only -m pytest tests/ ; codeflash --all
```
### Review the PRs Codeflash opens

Binary file not shown.

After

Width:  |  Height:  |  Size: 527 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 34 KiB

View file

@ -1,17 +1,17 @@
---
title: "What is Codeflash?"
description: "AI-powered Python performance optimizer that automatically speeds up your code while verifying correctness"
title: "Codeflash is an AI performance optimizer for Python code"
icon: "rocket"
sidebarTitle: "Overview"
keywords: ["python", "performance", "optimization", "AI", "code analysis", "benchmarking"]
---
Codeflash speeds up Python code by figuring out the best way to rewrite your code while verifying that the behavior of the code is unchanged.
Codeflash speeds up any Python code by figuring out the best way to rewrite it while verifying that the behavior of the code is unchanged, and verifying real speed
gains through performance benchmarking.
The optimizations Codeflash finds are generally better algorithms, opportunities to remove wasteful compute, better logic, utilizing caching and utilization of more efficient library methods. Codeflash
does not modify the system architecture of your code, but it tries to find the most efficient implementation of that architecture.
does not modify the system architecture of your code, but it tries to find the most efficient implementation of your current architecture.
### Features
### How to use Codeflash
<CardGroup cols={1}>
<Card title="Optimize a Single Function" icon="bullseye" href="/optimizing-with-codeflash/one-function">
@ -21,6 +21,13 @@ does not modify the system architecture of your code, but it tries to find the m
```
</Card>
<Card title="Optimize Pull Requests" icon="code-pull-request" href="/optimizing-with-codeflash/codeflash-github-actions">
Automatically find optimizations for Pull Requests with GitHub Actions integration.
```bash
codeflash init-actions
```
</Card>
<Card title="Optimize Workflows with Tracing" icon="route" href="/optimizing-with-codeflash/trace-and-optimize">
End-to-end optimization of entire Python workflows with execution tracing.
```bash
@ -35,12 +42,7 @@ does not modify the system architecture of your code, but it tries to find the m
```
</Card>
<Card title="Optimize Pull Requests" icon="git-pull-request" href="/optimizing-with-codeflash/optimize-prs">
Automatically find optimization code changes in Pull Requests with GitHub Actions integration.
```bash
codeflash init-actions
```
</Card>
</CardGroup>
### How does Codeflash verify correctness?

View file

@ -1,21 +1,21 @@
---
title: "Using Benchmarks in CI"
title: "Optimize Performance Benchmarks with every Pull Request"
description: "Configure and use pytest-benchmark integration for performance-critical code optimization"
icon: "chart-line"
sidebarTitle: "CI Benchmarks"
sidebarTitle: Setup Benchmarks to Optimize
keywords: ["benchmarks", "CI", "pytest-benchmark", "performance testing", "github actions", "benchmark mode"]
---
<Info>
**Performance-critical optimization** - Define benchmarks for your most important functions and let Codeflash measure the real-world impact of every optimization on your performance metrics.
**Performance-critical optimization** - Define benchmarks for your most important code sections and let Codeflash optimize and measure the real-world impact of every optimization on your performance metrics.
</Info>
Benchmark mode is an easy way for users to define workflows that are performance-critical and need to be optimized.
For example, if a user has an important function that requires minimal latency, the user can define a benchmark for that function.
Codeflash will then calculate the impact (if any) of any optimization on the performance of that function.
Benchmark mode is an easy way to define workflows that are performance-critical and need to be optimized and run fast.
Codeflash will run the benchmark and understand how the current code change in the Pull Request is affecting it.
It will then try to optimize the new code for the benchmark and calculate the impact of any optimization on the speed of that benchmark.
## Using Codeflash in Benchmark Mode
1. **Create a benchmarks root**
1. **Create a benchmarks root:**
Create a directory for benchmarks if it does not already exist.
@ -31,7 +31,7 @@ Codeflash will then calculate the impact (if any) of any optimization on the per
formatter-cmds = ["disabled"]
```
2. **Define your benchmarks**
2. **Define your benchmarks:**
Currently, Codeflash only supports benchmarks written as pytest-benchmarks. Check out the [pytest-benchmark](https://pytest-benchmark.readthedocs.io/en/stable/index.html) documentation for more information on syntax.
@ -50,7 +50,7 @@ Codeflash will then calculate the impact (if any) of any optimization on the per
The pytest-benchmark format is simply used as an interface. The plugin is actually not used - Codeflash will run these benchmarks with its own pytest plugin
3. **Run Codeflash**
3. **Run and Test Codeflash:**
Run Codeflash with the `--benchmark` flag. Note that benchmark mode cannot be used with `--all`.
@ -65,13 +65,15 @@ Codeflash will then calculate the impact (if any) of any optimization on the per
```
4. **Run Codeflash in CI**
4. **Run Codeflash:**
Benchmark mode is best used together with Codeflash as a Github Action. This way, with every PR, you will know the impact of Codeflash's optimizations on your benchmarks.
Benchmark mode is best used together with Codeflash as a GitHub Action. This way,
Codeflash will trace through your benchmark and optimize the functions modified in your Pull Request to speed up the benchmark.
It will also report the impact of Codeflash's optimizations on your benchmarks.
Use `codeflash init` for an easy way to set up Codeflash as a Github Action (with the option to enable benchmark mode).
Use `codeflash init` for an easy way to set up Codeflash as a GitHub Action.
Otherwise, you can run the following command in your Codeflash GitHub Action:
After that, you can add the `--benchmark` argument to codeflash to enable benchmarks optimization.
```bash
codeflash --benchmark
@ -84,7 +86,7 @@ Codeflash will then calculate the impact (if any) of any optimization on the per
1. Codeflash identifies benchmarks in the benchmarks-root directory.
2. The benchmarks are run so that runtime statistics and information can be recorded.
2. The benchmarks are run so that runtime statistics and inputs can be recorded.
3. Replay tests are generated so the performance of optimization candidates on the exact inputs used in the benchmarks can be measured.
@ -97,5 +99,3 @@ Codeflash will then calculate the impact (if any) of any optimization on the per
Using Codeflash with benchmarks is a great way to find optimizations that really matter.
Codeflash is actively working on this feature and will be adding new capabilities in the near future!

View file

@ -1,8 +1,8 @@
---
title: "Optimize Your Entire Codebase"
description: "Automatically optimize all functions in your project with Codeflash's comprehensive analysis"
description: "Automatically optimize all codepaths in your project with Codeflash's comprehensive analysis"
icon: "database"
sidebarTitle: "Entire Codebase"
sidebarTitle: "Optimize Entire Codebase"
keywords: ["codebase optimization", "all functions", "batch optimization", "github app", "checkpoint", "recovery"]
---
@ -19,13 +19,34 @@ codeflash --all
This requires the Codeflash GitHub App to be installed in your repository.
This is a powerful feature that can help you optimize your entire codebase in one go.
This is a powerful feature that can help you optimize your entire codebase in one go. It also discovers and runs any unit tests covering the function under optimization.
Since it runs on all the functions in your codebase, it can take some time to complete, please be patient.
As this runs you will see Codeflash opening pull requests for each function it successfully optimizes.
If you only want to optimize a subdirectory you can run:
```bash
codeflash --all path/to/dir
```
<Tip>
If your project has a good number of unit tests, we can trace those to achieve higher quality results.
The following approach is recommended instead:
```bash
codeflash optimize --trace-only -m pytest tests/ ; codeflash --all
```
This will run your test suite, trace all the code covered by your tests, ensuring higher correctness guarantees
and better performance benchmarking, and help create optimizations for code where the LLMs struggle to generate and run tests.
Even though `codeflash --all` discovers any existing unit tests, it currently can only discover tests that directly call the
function under optimization. Tracing all the tests helps ensure correctness for code that may be called indirectly by your tests.
</Tip>
## Important considerations
- **Dedicated Optimization Machine:** Optimizing the entire codebase may require considerable time—up to one day. It's recommended to allocate a dedicated machine specifically for this long-running optimization task.
- **Minimize Background Processes:** To achieve optimal results, avoid running other processes on the optimization machine. Additional processes can introduce noise into Codeflash's runtime measurements, reducing the quality of the optimizations. Although Codeflash tolerates some runtime fluctuations, minimizing noise ensures the highest optimization quality.
- **Checkpoint and Recovery:** Codeflash automatically creates checkpoints as it identifies optimizations. If the optimization process is interrupted or encounters issues, you can resume the process by re-running `codeflash --all`. The command will prompt you to continue from the most recent checkpoint.

View file

@ -1,21 +1,25 @@
---
title: "GitHub Actions Integration"
description: "Automatically optimize pull requests with Codeflash GitHub Actions workflow"
title: "Auto Optimize Pull Requests"
description: "Automatically optimize new code in pull requests with Codeflash GitHub Actions workflow"
icon: "github"
---
{/* TODO: Add more pictures to guide better */}
Codeflash can automatically optimize your code when new pull requests are opened.
Optimizing new code in Pull Requests is the best way to ensure that all code you and your team ship is performant
in the future. Automating optimization in the Pull Request stage is how most teams use Codeflash to
continuously find optimizations for their new code.
To be able to scan new code for performance optimizations, Codeflash requires a GitHub action workflow to
be installed which runs the Codeflash optimization logic on every new pull request.
To scan new code for performance optimizations, Codeflash uses a GitHub Action workflow which runs
the Codeflash optimization logic on the new code in every pull request.
If the action workflow finds an optimization, it communicates with the Codeflash GitHub
App through our secure servers and asks it to suggest new changes to the pull request.
App and asks it to suggest new changes to the pull request.
This is the most useful way of using Codeflash, where you set it up once and all your new code gets optimized.
So setting this up is highly recommended.
## Pull Request Optimization 30 seconds demo
<iframe width="640" height="400" src="https://www.youtube.com/embed/nqa-uewizkU?si=H1wb1RvPp-JqvKPh" title="YouTube video player" frameborder="0" allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture; web-share" allowfullscreen></iframe>
## Prerequisites
<Warning>
@ -38,18 +42,18 @@ So setting this up is highly recommended.
codeflash init-actions
```
This command will automatically create the GitHub Actions workflow file and guide you through the setup process.
Alternatively running `codeflash init` also asks to setup the github actions.
</Step>
<Step title="Test Your Setup">
<Step title="Customize and Test Your Setup">
Open a new pull request to your GitHub project. You'll see:
- ✅ A new Codeflash workflow running in GitHub Actions
- 🤖 The codeflash-ai bot commenting with optimization suggestions (if any are found)
Ensure that your Python environment installation works correctly and codeflash is able to run.
</Step>
</Steps>
<Note>
**Recommended approach** - This automated setup ensures you get the latest workflow configuration with all best practices included.
</Note>
</Tab>
<Tab title="Manual Setup">
@ -153,4 +157,16 @@ Customize the dependency installation based on your Python package manager:
</Steps>
</Tab>
</Tabs>
## How the Pull Request Optimization Suggestion looks
Codeflash creates a new dependent Pull Request for you to review with the reported speedups, helpful explanation for the optimization
and the proof of correctness. The pull request has the code change for you to review and accept.
![Codeflash PR Review](/images/codeflash_pr_suggestion_1.png)
Sometimes it also makes an inline suggestion with the optimization.
![Codeflash PR Suggestion](/images/code-suggestion.png)
We hope you enjoy the performance unlock that Pull Request optimization enables.

View file

@ -2,7 +2,7 @@
title: "Optimize a Single Function"
description: "Target and optimize individual Python functions for maximum performance gains"
icon: "bullseye"
sidebarTitle: "Single Function"
sidebarTitle: "Optimize Single Function"
keywords: ["function optimization", "single function", "class methods", "performance", "targeted optimization"]
---

View file

@ -1,33 +0,0 @@
---
title: "Optimize Pull Requests"
description: "Automatically optimize code changes in pull requests with GitHub Actions integration"
icon: "code-merge"
sidebarTitle: "PR Optimization"
keywords: ["pull requests", "github actions", "code review", "automated optimization", "dependent PR", "suggestions"]
---
<Info>
**Continuous optimization** - After initial setup, Codeflash will automatically review every new pull request and suggest performance improvements through comments and dependent PRs.
</Info>
## How to optimize a pull request
After following the setup steps in the [Automate Code Optimization with GitHub Actions](/getting-started/codeflash-github-actions) guide,
Codeflash will automatically optimize your pull requests when they are opened.
If Codeflash finds any successful optimizations, it will comment on the pull request asking you to review the changes.
![Codeflash PR Comment](/images/review-comment.png)
Codeflash can ask you to review the changes in two ways:
### Opening a dependent pull request
Codeflash will open a new pull request with the optimized code.
You can review the changes in this pull request, make changes if you want, and merge it if you are satisfied with the optimizations.
The changes will be merged back into the original pull request as a new commit.
![Codeflash PR Review](/images/dependent-pr.png)
### Reviewing the changes in the original pull request
If the suggested changes are small and only affect the modified lines, Codeflash will suggest the changes in the original pull request itself.
You can choose to accept or reject the changes directly in the original pull request.
The changes can be added to a batch of changes in the original pull request as a new commit.
![Codeflash PR Review](/images/code-suggestion.png)

View file

@ -1,26 +1,28 @@
---
title: "Trace & Optimize Workflows"
title: "Trace & Optimize E2E Workflows"
description: "End-to-end optimization of entire Python workflows with execution tracing"
icon: "route"
sidebarTitle: "Trace & Optimize"
sidebarTitle: "Optimize E2E Workflows"
keywords: ["tracing", "workflow optimization", "replay tests", "end-to-end", "script optimization", "context manager"]
---
Codeflash supports optimizing an entire Python script end-to-end by tracing the script's execution and generating Replay Tests. Tracing follows the execution of a script, profiles it and captures inputs to all called functions, allowing them to be replayed during optimization. Codeflash uses these Replay Tests to optimize all functions called in the script, starting from the most important ones.
Codeflash supports optimizing an entire Python script end-to-end by tracing the script's execution and generating Replay Tests.
Tracing follows the execution of a script, profiles it and captures inputs to all functions it called, allowing them to be replayed during optimization.
Codeflash uses these Replay Tests to optimize all functions called in the script, starting from the most important ones.
To optimize a script, `python myscript.py`, replace `python` with `codeflash optimize` and run the following command:
To optimize a script, `python myscript.py`, simply replace `python` with `codeflash optimize` and run the following command:
```bash
codeflash optimize myscript.py
```
To optimize code called by pytest tests that you could normally run like `python -m pytest tests/`, use this command:
You can also optimize code called by pytest tests that you could normally run like `python -m pytest tests/`, this provides for a good workload to optimize. Run this command:
```bash
codeflash optimize -m pytest tests/
```
This powerful command creates high-quality optimizations, making it ideal when you need to optimize a workflow or script. The initial tracing process can be slow, so try to limit your script's runtime to under 1 minute for best results. If your workflow is longer, consider tracing it into smaller sections by using the Codeflash tracer as a context manager (point 3 below).
The powerful `codeflash optimize` command creates high-quality optimizations, making it ideal when you need to optimize a workflow or script. The initial tracing process can be slow, so try to limit your script's runtime to under 1 minute for best results. If your workflow is longer, consider tracing it into smaller sections by using the Codeflash tracer as a context manager (point 3 below).
The generated replay tests and the trace file are for the immediate optimization use, don't add them to git.
@ -32,10 +34,11 @@ The generated replay tests and the trace file are for the immediate optimization
## What is the codeflash optimize command?
`codeflash optimize` tries to do everything that an expert engineer would do while optimizing a workflow. It profiles your code, traces the execution of your workflow and generates a set of test cases that are derived from how your code is actually run.
Codeflash Tracer works by recording the inputs of your functions as they are called in your codebase. These inputs are then used to generate test cases that are representative of the real-world usage of your functions.
Codeflash Tracer works by recording the inputs of your functions as they are called in your codebase, and generating
regression tests with those inputs.
We call these generated test cases "Replay Tests" because they replay the inputs that were recorded during the tracing phase.
These replay tests are representative of the real-world usage of your functions.
Then, Codeflash Optimizer can use these replay tests to verify correctness and calculate accurate performance gains for the optimized functions.
Using Replay Tests, Codeflash can verify that the optimized functions produce the same output as the original function and also measure the performance gains of the optimized function on the real-world inputs.
This way you can be *sure* that the optimized function causes no changes of behavior for the traced workflow and also, that it is faster than the original function. To get more confidence on the correctness of the code, we also generate several LLM generated test cases and discover any existing unit cases you may have.
@ -57,15 +60,16 @@ Codeflash script optimizer can be used in three ways:
codeflash optimize path/to/your/file.py --your_options
```
The above command should suffice in most situations. You can add a argument like `codeflash optimize -o trace_file_path.trace` if you want to customize the trace file location. Otherwise, it defaults to `codeflash.trace` in the current working directory.
The above command should suffice in most situations.
To customize the trace file location you can specify it like `codeflash optimize -o trace_file_path.trace`. Otherwise, it defaults to `codeflash.trace` in the current working directory.
2. **Trace and optimize as two separate steps**
If you want more control over the tracing and optimization process, you can trace first and then optimize with the replay tests later. Each replay test is associated with a trace file.
To first create just the trace file, run
To create just the trace file first, run
```python
```bash
codeflash optimize -o trace_file.trace --trace-only path/to/your/file.py --your_options
```
@ -79,7 +83,7 @@ Codeflash script optimizer can be used in three ways:
- `--tracer-timeout`: The maximum time in seconds to trace the entire workflow. Default is indefinite. This is useful while tracing really long workflows.
3. **As a Context Manager -**
To trace only very specific sections of your codeflash, You can also use the Codeflash Tracer as a context manager.
To trace only specific sections of your code, you can also use the Codeflash Tracer as a context manager.
You can wrap the code you want to trace in a `with` statement as follows -
```python
@ -89,7 +93,7 @@ Codeflash script optimizer can be used in three ways:
model.predict() # Your code here
```
This is much faster than tracing the whole script. Sometimes, if tracing the whole script fails, then the Context Manager can also be used to trace the code sections.
This is much faster than tracing the whole script. It can also help if tracing the whole script fails.
After this finishes, you can optimize using the generated replay tests.
@ -97,7 +101,7 @@ Codeflash script optimizer can be used in three ways:
codeflash --replay-test /path/to/test_replay_test_0.py
```
More Options for the Tracer:
More Options for the Tracer Context Manager:
- `disable`: If set to `True`, the tracer will not trace the code. Default is `False`.
- `max_function_count`: The maximum number of times to trace a single function. More calls to a function will not be traced. Default is 100.

View file

@ -200,7 +200,7 @@ target-version = "py39"
line-length = 120
fix = true
show-fixes = true
exclude = ["code_to_optimize/", "pie_test_set/", "tests/"]
extend-exclude = ["code_to_optimize/", "pie_test_set/", "tests/"]
[tool.ruff.lint]
select = ["ALL"]

View file

@ -1,6 +1,7 @@
from pathlib import Path
from codeflash.code_utils.code_extractor import add_needed_imports_from_module
from codeflash.code_utils.code_extractor import add_needed_imports_from_module, find_preexisting_objects
from codeflash.code_utils.code_replacer import replace_functions_and_add_imports
def test_add_needed_imports_from_module0() -> None:
@ -121,3 +122,230 @@ def belongs_to_function(name: Name, function_name: str) -> bool:
project_root = Path("/home/roger/repos/codeflash")
new_module = add_needed_imports_from_module(src_module, dst_module, src_path, dst_path, project_root)
assert new_module == expected
# Regression test: when the optimized code re-imports names that the original
# module already imports (dataclass, BaseAdapter, typing members), the merged
# output must contain each import exactly once — no duplicated import lines.
def test_duplicated_imports() -> None:
# Optimized replacement for DbtAdapter.build_parent_map. It carries its own
# (redundant) imports, which the replacer must merge into the original module
# rather than append as duplicates.
optim_code = '''from dataclasses import dataclass
from recce.adapter.base import BaseAdapter
from typing import Dict, List, Optional
@dataclass
class DbtAdapter(BaseAdapter):
def build_parent_map(self, nodes: Dict, base: Optional[bool] = False) -> Dict[str, List[str]]:
manifest = self.curr_manifest if base is False else self.base_manifest
try:
parent_map_source = manifest.parent_map
except AttributeError:
parent_map_source = manifest.to_dict()["parent_map"]
node_ids = set(nodes)
parent_map = {}
for k, parents in parent_map_source.items():
if k not in node_ids:
continue
parent_map[k] = [parent for parent in parents if parent in node_ids]
return parent_map
'''
# Original module: a large, realistic import preamble (plain imports,
# parenthesized from-imports, relative imports, and a try/except import
# guard) followed by the unoptimized method body.
original_code = '''import json
import logging
import os
import uuid
from contextlib import contextmanager
from copy import deepcopy
from dataclasses import dataclass, fields
from errno import ENOENT
from functools import lru_cache
from pathlib import Path
from typing import (
Any,
Callable,
Dict,
Iterator,
List,
Literal,
Optional,
Set,
Tuple,
Type,
Union,
)
from recce.event import log_performance
from recce.exceptions import RecceException
from recce.util.cll import CLLPerformanceTracking, cll
from recce.util.lineage import (
build_column_key,
filter_dependency_maps,
find_downstream,
find_upstream,
)
from recce.util.perf_tracking import LineagePerfTracker
from ...tasks.profile import ProfileTask
from ...util.breaking import BreakingPerformanceTracking, parse_change_category
try:
import agate
import dbt.adapters.factory
from dbt.contracts.state import PreviousState
except ImportError as e:
print("Error: dbt module not found. Please install it by running:")
print("pip install dbt-core dbt-<adapter>")
raise e
from watchdog.events import FileSystemEventHandler
from watchdog.observers import Observer
from recce.adapter.base import BaseAdapter
from recce.state import ArtifactsRoot
from ...models import RunType
from ...models.types import (
CllColumn,
CllData,
CllNode,
LineageDiff,
NodeChange,
NodeDiff,
)
from ...tasks import (
HistogramDiffTask,
ProfileDiffTask,
QueryBaseTask,
QueryDiffTask,
QueryTask,
RowCountDiffTask,
RowCountTask,
Task,
TopKDiffTask,
ValueDiffDetailTask,
ValueDiffTask,
)
from .dbt_version import DbtVersion
@dataclass
class DbtAdapter(BaseAdapter):
def build_parent_map(self, nodes: Dict, base: Optional[bool] = False) -> Dict[str, List[str]]:
manifest = self.curr_manifest if base is False else self.base_manifest
manifest_dict = manifest.to_dict()
node_ids = nodes.keys()
parent_map = {}
for k, parents in manifest_dict["parent_map"].items():
if k not in node_ids:
continue
parent_map[k] = [parent for parent in parents if parent in node_ids]
return parent_map
'''
# Expected result: the original import preamble byte-for-byte (nothing added,
# nothing duplicated) with only the method body swapped for the optimized one.
expected = '''import json
import logging
import os
import uuid
from contextlib import contextmanager
from copy import deepcopy
from dataclasses import dataclass, fields
from errno import ENOENT
from functools import lru_cache
from pathlib import Path
from typing import (
Any,
Callable,
Dict,
Iterator,
List,
Literal,
Optional,
Set,
Tuple,
Type,
Union,
)
from recce.event import log_performance
from recce.exceptions import RecceException
from recce.util.cll import CLLPerformanceTracking, cll
from recce.util.lineage import (
build_column_key,
filter_dependency_maps,
find_downstream,
find_upstream,
)
from recce.util.perf_tracking import LineagePerfTracker
from ...tasks.profile import ProfileTask
from ...util.breaking import BreakingPerformanceTracking, parse_change_category
try:
import agate
import dbt.adapters.factory
from dbt.contracts.state import PreviousState
except ImportError as e:
print("Error: dbt module not found. Please install it by running:")
print("pip install dbt-core dbt-<adapter>")
raise e
from watchdog.events import FileSystemEventHandler
from watchdog.observers import Observer
from recce.adapter.base import BaseAdapter
from recce.state import ArtifactsRoot
from ...models import RunType
from ...models.types import (
CllColumn,
CllData,
CllNode,
LineageDiff,
NodeChange,
NodeDiff,
)
from ...tasks import (
HistogramDiffTask,
ProfileDiffTask,
QueryBaseTask,
QueryDiffTask,
QueryTask,
RowCountDiffTask,
RowCountTask,
Task,
TopKDiffTask,
ValueDiffDetailTask,
ValueDiffTask,
)
from .dbt_version import DbtVersion
@dataclass
class DbtAdapter(BaseAdapter):
def build_parent_map(self, nodes: Dict, base: Optional[bool] = False) -> Dict[str, List[str]]:
manifest = self.curr_manifest if base is False else self.base_manifest
try:
parent_map_source = manifest.parent_map
except AttributeError:
parent_map_source = manifest.to_dict()["parent_map"]
node_ids = set(nodes)
parent_map = {}
for k, parents in parent_map_source.items():
if k not in node_ids:
continue
parent_map[k] = [parent for parent in parents if parent in node_ids]
return parent_map
'''
# Qualified "Class.method" name selects which function gets replaced.
function_name: str = "DbtAdapter.build_parent_map"
# Pre-existing objects let the replacer know which definitions already live
# in the original module, so it only rewrites — never re-adds — them.
preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] = find_preexisting_objects(original_code)
new_code: str = replace_functions_and_add_imports(
source_code=original_code,
function_names=[function_name],
optimized_code=optim_code,
module_abspath=Path(__file__).resolve(),
preexisting_objects=preexisting_objects,
project_root_path=Path(__file__).resolve().parent.resolve(),
)
# Exact-match assertion: any duplicated import would break byte equality.
assert new_code == expected

View file

@ -1,4 +1,5 @@
from __future__ import annotations
import re
import libcst as cst
from codeflash.code_utils.code_replacer import AutouseFixtureModifier, PytestMarkAdder, AddRequestArgument
import dataclasses
@ -3091,3 +3092,139 @@ def my_fixture(request):
modified_module = module.visit(transformer)
assert modified_module.code.strip() == expected.strip()
# Regression test: imports that are only conditionally available in the
# original module (inside try/except ImportError or an `if` block) must NOT
# be hoisted to unconditional top-level imports when merging optimized code,
# even though the optimized snippet imports them unconditionally.
def test_type_checking_imports():
# Optimized code imports huggingface_hub / requests / aiohttp / math names
# at top level; the original module guards those same names, so the replacer
# must drop these unconditional re-imports.
optim_code = """from dataclasses import dataclass
from pydantic_ai.providers import Provider, infer_provider
from pydantic_ai_slim.pydantic_ai.models import Model
from pydantic_ai_slim.pydantic_ai.tools import ToolDefinition
from typing import Literal
#### problamatic imports ####
from huggingface_hub import AsyncInferenceClient, ChatCompletionInputTool
import requests
import aiohttp as aiohttp_
from math import pi as PI, sin as sine
@dataclass(init=False)
class HuggingFaceModel(Model):
def __init__(
self,
model_name: str,
*,
provider: Literal['huggingface'] | Provider[AsyncInferenceClient] = 'huggingface',
):
print(requests.__name__)
print(aiohttp_.__name__)
print(PI)
print(sine)
# Fast branch: avoid repeating provider assignment
if isinstance(provider, str):
provider_obj = infer_provider(provider)
else:
provider_obj = provider
self._provider = provider
self._model_name = model_name
self.client = provider_obj.client
@staticmethod
def _map_tool_definition(f: ToolDefinition) -> ChatCompletionInputTool:
# Inline dict creation and single pass for possible strict attribute
tool_dict = {
'type': 'function',
'function': {
'name': f.name,
'description': f.description,
'parameters': f.parameters_json_schema,
},
}
if f.strict is not None:
tool_dict['function']['strict'] = f.strict
return ChatCompletionInputTool.parse_obj_as_instance(tool_dict) # type: ignore
"""
# Original module: the same names are imported under a try/except ImportError
# guard (aliases included) and `requests` under an `if True:` block — i.e.
# all conditional, so none may be re-added unconditionally.
original_code = """from dataclasses import dataclass
from pydantic_ai.providers import Provider, infer_provider
from pydantic_ai_slim.pydantic_ai.models import Model
from pydantic_ai_slim.pydantic_ai.tools import ToolDefinition
from typing import Literal
try:
import aiohttp as aiohttp_
from math import pi as PI, sin as sine
from huggingface_hub import (
AsyncInferenceClient,
ChatCompletionInputMessage,
ChatCompletionInputMessageChunk,
ChatCompletionInputTool,
ChatCompletionInputToolCall,
ChatCompletionInputURL,
ChatCompletionOutput,
ChatCompletionOutputMessage,
ChatCompletionStreamOutput,
)
from huggingface_hub.errors import HfHubHTTPError
except ImportError as _import_error:
raise ImportError(
'Please install `huggingface_hub` to use Hugging Face Inference Providers, '
'you can use the `huggingface` optional group — `pip install "pydantic-ai-slim[huggingface]"`'
) from _import_error
if True:
import requests
__all__ = (
'HuggingFaceModel',
'HuggingFaceModelSettings',
)
@dataclass(init=False)
class HuggingFaceModel(Model):
def __init__(
self,
model_name: str,
*,
provider: Literal['huggingface'] | Provider[AsyncInferenceClient] = 'huggingface',
):
self._model_name = model_name
self._provider = provider
if isinstance(provider, str):
provider = infer_provider(provider)
self.client = provider.client
@staticmethod
def _map_tool_definition(f: ToolDefinition) -> ChatCompletionInputTool:
tool_param: ChatCompletionInputTool = ChatCompletionInputTool.parse_obj_as_instance( # type: ignore
{
'type': 'function',
'function': {
'name': f.name,
'description': f.description,
'parameters': f.parameters_json_schema,
},
}
)
if f.strict is not None:
tool_param['function']['strict'] = f.strict
return tool_param
"""
# Only the static method is replaced; everything else must survive intact.
function_name: str = "HuggingFaceModel._map_tool_definition"
preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] = find_preexisting_objects(original_code)
new_code: str = replace_functions_and_add_imports(
source_code=original_code,
function_names=[function_name],
optimized_code=optim_code,
module_abspath=Path(__file__).resolve(),
preexisting_objects=preexisting_objects,
project_root_path=Path(__file__).resolve().parent.resolve(),
)
# Negative assertions: none of the conditionally-guarded imports may appear
# as new unconditional top-of-module imports (MULTILINE anchors line starts).
assert not re.search(r"^import requests\b", new_code, re.MULTILINE) # conditional simple import: import <name>
assert not re.search(r"^import aiohttp as aiohttp_\b", new_code, re.MULTILINE) # conditional alias import: import <name> as <alias>
assert not re.search(r"^from math import pi as PI, sin as sine\b", new_code, re.MULTILINE) # conditional multiple aliases imports
assert "from huggingface_hub import AsyncInferenceClient, ChatCompletionInputTool" not in new_code # conditional from import

View file

@ -357,7 +357,7 @@ def test_cleanup_paths(multiple_existing_and_non_existing_files: list[Path]) ->
def test_generate_candidates() -> None:
source_code_path = Path("/Users/krrt7/Desktop/work/codeflash/cli/codeflash/code_utils/coverage_utils.py")
expected_candidates = [
expected_candidates = {
"coverage_utils.py",
"code_utils/coverage_utils.py",
"codeflash/code_utils/coverage_utils.py",
@ -367,7 +367,8 @@ def test_generate_candidates() -> None:
"Desktop/work/codeflash/cli/codeflash/code_utils/coverage_utils.py",
"krrt7/Desktop/work/codeflash/cli/codeflash/code_utils/coverage_utils.py",
"Users/krrt7/Desktop/work/codeflash/cli/codeflash/code_utils/coverage_utils.py",
]
"/Users/krrt7/Desktop/work/codeflash/cli/codeflash/code_utils/coverage_utils.py"
}
assert generate_candidates(source_code_path) == expected_candidates

View file

@ -208,13 +208,15 @@ import os
import sys
def foo():
return os.path.join(sys.path[0], 'bar')"""
expected = original_code
with tempfile.NamedTemporaryFile("w") as tmp:
tmp.write(original_code)
tmp.flush()
tmp_path = tmp.name
with pytest.raises(FileNotFoundError):
format_code(formatter_cmds=["exit 1"], path=Path(tmp_path))
try:
new_code = format_code(formatter_cmds=["exit 1"], path=Path(tmp_path), exit_on_failure=False)
assert new_code == original_code
except Exception as e:
assert False, f"Shouldn't throw an exception even if the formatter is not found: {e}"
def _run_formatting_test(source_code: str, should_content_change: bool, expected = None, optimized_function: str = ""):
@ -570,12 +572,12 @@ if __name__=='__main__':main()
def test_formatting_edge_case_exactly_100_diffs():
"""Test behavior when exactly at the threshold of 100 changes."""
# Create a file with exactly 100 minor formatting issues
source_code = '''import json\n''' + '''
def func{}():
snippet = '''import json\n''' + '''
def func_{i}():
x=1;y=2;z=3
return x+y+z
'''.replace('{}', '_{i}').format(i='{i}') * 33 # This creates exactly 100 potential formatting fixes
'''
source_code = "".join([snippet.format(i=i) for i in range(100)])
_run_formatting_test(source_code, False)