saving the API key correctly for windows PowerShell (#940)

* remove emoji

* save cf api key correctly for powershell

* fix test_shell_utils test

* fix linting and tests

* FIX ALL TESTS

* revert tests/test_trace_benchmarks.py

---------

Co-authored-by: HeshamHM28 <HeshamMohamedFathy@outlook.com>
Co-authored-by: Kevin Turcios <106575910+KRRT7@users.noreply.github.com>
This commit is contained in:
mashraf-222 2025-11-26 09:02:31 +02:00 committed by GitHub
parent 024ef1a680
commit 63b6e77b7f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 202 additions and 39 deletions

View file

@ -32,7 +32,7 @@ from codeflash.code_utils.env_utils import check_formatter_installed, get_codefl
from codeflash.code_utils.git_utils import get_git_remotes, get_repo_owner_and_name
from codeflash.code_utils.github_utils import get_github_secrets_page_url
from codeflash.code_utils.oauth_handler import perform_oauth_signin
from codeflash.code_utils.shell_utils import get_shell_rc_path, save_api_key_to_rc
from codeflash.code_utils.shell_utils import get_shell_rc_path, is_powershell, save_api_key_to_rc
from codeflash.either import is_successful
from codeflash.lsp.helpers import is_LSP_enabled
from codeflash.telemetry.posthog_cf import ph
@ -136,7 +136,10 @@ def init_codeflash() -> None:
completion_message += (
"\n\n🐚 Don't forget to restart your shell to load the CODEFLASH_API_KEY environment variable!"
)
reload_cmd = f"call {get_shell_rc_path()}" if os.name == "nt" else f"source {get_shell_rc_path()}"
if os.name == "nt":
reload_cmd = f". {get_shell_rc_path()}" if is_powershell() else f"call {get_shell_rc_path()}"
else:
reload_cmd = f"source {get_shell_rc_path()}"
completion_message += f"\nOr run: {reload_cmd}"
completion_panel = Panel(
@ -1087,7 +1090,7 @@ def configure_pyproject_toml(
with toml_path.open("w", encoding="utf8") as pyproject_file:
pyproject_file.write(tomlkit.dumps(pyproject_data))
click.echo(f"Added Codeflash configuration to {toml_path}")
click.echo(f"Added Codeflash configuration to {toml_path}")
click.echo()
return True
@ -1264,7 +1267,8 @@ def enter_api_key_and_save_to_rc() -> None:
browser_launched = True # This does not work on remote consoles
shell_rc_path = get_shell_rc_path()
if not shell_rc_path.exists() and os.name == "nt":
# On Windows, create a batch file in the user's home directory (not auto-run, just used to store api key)
# On Windows, create the appropriate file (PowerShell .ps1 or CMD .bat) in the user's home directory
shell_rc_path.parent.mkdir(parents=True, exist_ok=True)
shell_rc_path.touch()
click.echo(f"✅ Created {shell_rc_path}")
get_user_id(api_key=api_key) # Used to verify whether the API key is valid.

View file

@ -59,17 +59,26 @@ def get_codeflash_api_key() -> str:
# Check environment variable first
env_api_key = os.environ.get("CODEFLASH_API_KEY")
shell_api_key = read_api_key_from_shell_config()
logger.debug(
f"env_utils.py:get_codeflash_api_key - env_api_key: {'***' + env_api_key[-4:] if env_api_key else None}, shell_api_key: {'***' + shell_api_key[-4:] if shell_api_key else None}"
)
# If we have an env var but it's not in shell config, save it for persistence
if env_api_key and not shell_api_key:
try:
from codeflash.either import is_successful
logger.debug("env_utils.py:get_codeflash_api_key - Saving API key from environment to shell config")
result = save_api_key_to_rc(env_api_key)
if is_successful(result):
logger.debug(f"Automatically saved API key from environment to shell config: {result.unwrap()}")
logger.debug(
f"env_utils.py:get_codeflash_api_key - Automatically saved API key from environment to shell config: {result.unwrap()}"
)
else:
logger.debug(f"env_utils.py:get_codeflash_api_key - Failed to save API key: {result.failure()}")
except Exception as e:
logger.debug(f"Failed to automatically save API key to shell config: {e}")
logger.debug(
f"env_utils.py:get_codeflash_api_key - Failed to automatically save API key to shell config: {e}"
)
# Prefer the shell configuration over environment variables for lsp,
# as the API key may change in the RC file during lsp runtime. Since the LSP client (extension) can restart

View file

@ -1,40 +1,107 @@
from __future__ import annotations
import contextlib
import os
import re
from pathlib import Path
from typing import TYPE_CHECKING, Optional
from codeflash.cli_cmds.console import logger
from codeflash.code_utils.compat import LF
from codeflash.either import Failure, Success
if TYPE_CHECKING:
from codeflash.either import Result
if os.name == "nt": # Windows
SHELL_RC_EXPORT_PATTERN = re.compile(r"^set CODEFLASH_API_KEY=(cf-.*)$", re.MULTILINE)
SHELL_RC_EXPORT_PREFIX = "set CODEFLASH_API_KEY="
else:
SHELL_RC_EXPORT_PATTERN = re.compile(
r'^(?!#)export CODEFLASH_API_KEY=(?:"|\')?(cf-[^\s"\']+)(?:"|\')?$', re.MULTILINE
)
SHELL_RC_EXPORT_PREFIX = "export CODEFLASH_API_KEY="
# PowerShell patterns and prefixes
POWERSHELL_RC_EXPORT_PATTERN = re.compile(
r'^\$env:CODEFLASH_API_KEY\s*=\s*(?:"|\')?(cf-[^\s"\']+)(?:"|\')?\s*$', re.MULTILINE
)
POWERSHELL_RC_EXPORT_PREFIX = "$env:CODEFLASH_API_KEY = "
# CMD/Batch patterns and prefixes
CMD_RC_EXPORT_PATTERN = re.compile(r"^set CODEFLASH_API_KEY=(cf-.*)$", re.MULTILINE)
CMD_RC_EXPORT_PREFIX = "set CODEFLASH_API_KEY="
# Unix shell patterns and prefixes
UNIX_RC_EXPORT_PATTERN = re.compile(r'^(?!#)export CODEFLASH_API_KEY=(?:"|\')?(cf-[^\s"\']+)(?:"|\')?$', re.MULTILINE)
UNIX_RC_EXPORT_PREFIX = "export CODEFLASH_API_KEY="
def is_powershell() -> bool:
"""Detect if we're running in PowerShell on Windows.
Uses multiple heuristics:
1. PSModulePath environment variable (PowerShell always sets this)
2. COMSPEC pointing to powershell.exe
3. TERM_PROGRAM indicating Windows Terminal (often uses PowerShell)
"""
if os.name != "nt":
return False
# Primary check: PSMODULEPATH is set by PowerShell
# This is the most reliable indicator as PowerShell always sets this
ps_module_path = os.environ.get("PSMODULEPATH")
if ps_module_path:
logger.debug("shell_utils.py:is_powershell - Detected PowerShell via PSModulePath")
return True
# Secondary check: COMSPEC points to PowerShell
comspec = os.environ.get("COMSPEC", "").lower()
if "powershell" in comspec:
logger.debug(f"shell_utils.py:is_powershell - Detected PowerShell via COMSPEC: {comspec}")
return True
# Tertiary check: Windows Terminal often uses PowerShell by default
# But we only use this if other indicators are ambiguous
term_program = os.environ.get("TERM_PROGRAM", "").lower()
# Check if we can find evidence of CMD (cmd.exe in COMSPEC)
# If not, assume PowerShell for Windows Terminal
if "windows" in term_program and "terminal" in term_program and "cmd.exe" not in comspec:
logger.debug(f"shell_utils.py:is_powershell - Detected PowerShell via Windows Terminal (COMSPEC: {comspec})")
return True
logger.debug(f"shell_utils.py:is_powershell - Not PowerShell (COMSPEC: {comspec}, TERM_PROGRAM: {term_program})")
return False
def read_api_key_from_shell_config() -> Optional[str]:
"""Read API key from shell configuration file."""
shell_rc_path = get_shell_rc_path()
# Ensure shell_rc_path is a Path object for consistent handling
if not isinstance(shell_rc_path, Path):
shell_rc_path = Path(shell_rc_path)
# Determine the correct pattern to use based on the file extension and platform
if os.name == "nt": # Windows
pattern = POWERSHELL_RC_EXPORT_PATTERN if shell_rc_path.suffix == ".ps1" else CMD_RC_EXPORT_PATTERN
else: # Unix-like
pattern = UNIX_RC_EXPORT_PATTERN
try:
shell_rc_path = get_shell_rc_path()
with open(shell_rc_path, encoding="utf8") as shell_rc: # noqa: PTH123
# Convert Path to string using as_posix() for cross-platform path compatibility
shell_rc_path_str = shell_rc_path.as_posix() if isinstance(shell_rc_path, Path) else str(shell_rc_path)
with open(shell_rc_path_str, encoding="utf8") as shell_rc: # noqa: PTH123
shell_contents = shell_rc.read()
matches = SHELL_RC_EXPORT_PATTERN.findall(shell_contents)
return matches[-1] if matches else None
matches = pattern.findall(shell_contents)
if matches:
logger.debug(f"shell_utils.py:read_api_key_from_shell_config - Found API key in file: {shell_rc_path}")
return matches[-1]
logger.debug(f"shell_utils.py:read_api_key_from_shell_config - No API key found in file: {shell_rc_path}")
return None
except FileNotFoundError:
logger.debug(f"shell_utils.py:read_api_key_from_shell_config - File not found: {shell_rc_path}")
return None
except Exception as e:
logger.debug(f"shell_utils.py:read_api_key_from_shell_config - Error reading file: {e}")
return None
def get_shell_rc_path() -> Path:
"""Get the path to the user's shell configuration file."""
if os.name == "nt": # on Windows, we use a batch file in the user's home directory
if os.name == "nt": # Windows
if is_powershell():
return Path.home() / "codeflash_env.ps1"
return Path.home() / "codeflash_env.bat"
shell = os.environ.get("SHELL", "/bin/bash").split("/")[-1]
shell_rc_filename = {"zsh": ".zshrc", "ksh": ".kshrc", "csh": ".cshrc", "tcsh": ".cshrc", "dash": ".profile"}.get(
@ -44,40 +111,123 @@ def get_shell_rc_path() -> Path:
def get_api_key_export_line(api_key: str) -> str:
return f'{SHELL_RC_EXPORT_PREFIX}"{api_key}"'
"""Get the appropriate export line based on the shell type."""
if os.name == "nt": # Windows
if is_powershell():
return f'{POWERSHELL_RC_EXPORT_PREFIX}"{api_key}"'
return f'{CMD_RC_EXPORT_PREFIX}"{api_key}"'
# Unix-like
return f'{UNIX_RC_EXPORT_PREFIX}"{api_key}"'
def save_api_key_to_rc(api_key: str) -> Result[str, str]:
"""Save API key to the appropriate shell configuration file."""
shell_rc_path = get_shell_rc_path()
# Ensure shell_rc_path is a Path object for consistent handling
if not isinstance(shell_rc_path, Path):
shell_rc_path = Path(shell_rc_path)
api_key_line = get_api_key_export_line(api_key)
try:
with open(shell_rc_path, "r+", encoding="utf8") as shell_file: # noqa: PTH123
shell_contents = shell_file.read()
if os.name == "nt" and not shell_contents: # on windows we're writing to a batch file
shell_contents = "@echo off"
existing_api_key = read_api_key_from_shell_config()
if existing_api_key:
# Replace the existing API key line
updated_shell_contents = re.sub(SHELL_RC_EXPORT_PATTERN, api_key_line, shell_contents)
action = "Updated CODEFLASH_API_KEY in"
else:
logger.debug(f"shell_utils.py:save_api_key_to_rc - Saving API key to: {shell_rc_path}")
logger.debug(f"shell_utils.py:save_api_key_to_rc - API key line format: {api_key_line[:30]}...")
# Determine the correct pattern to use for replacement
if os.name == "nt": # Windows
if is_powershell():
pattern = POWERSHELL_RC_EXPORT_PATTERN
logger.debug("shell_utils.py:save_api_key_to_rc - Using PowerShell pattern")
else:
pattern = CMD_RC_EXPORT_PATTERN
logger.debug("shell_utils.py:save_api_key_to_rc - Using CMD pattern")
else: # Unix-like
pattern = UNIX_RC_EXPORT_PATTERN
logger.debug("shell_utils.py:save_api_key_to_rc - Using Unix pattern")
try:
# Create directory if it doesn't exist (ignore errors - file operation will fail if needed)
# Directory creation failed, but we'll still try to open the file
# The file operation itself will raise the appropriate exception if there are permission issues
with contextlib.suppress(OSError, PermissionError):
shell_rc_path.parent.mkdir(parents=True, exist_ok=True)
# Convert Path to string using as_posix() for cross-platform path compatibility
shell_rc_path_str = shell_rc_path.as_posix() if isinstance(shell_rc_path, Path) else str(shell_rc_path)
# Try to open in r+ mode (read and write in single operation)
# Handle FileNotFoundError if file doesn't exist (r+ requires file to exist)
try:
with open(shell_rc_path_str, "r+", encoding="utf8") as shell_file: # noqa: PTH123
shell_contents = shell_file.read()
logger.debug(f"shell_utils.py:save_api_key_to_rc - Read existing file, length: {len(shell_contents)}")
# Initialize empty file with header for batch files if needed
if not shell_contents:
logger.debug("shell_utils.py:save_api_key_to_rc - File is empty, initializing")
if os.name == "nt" and not is_powershell():
shell_contents = "@echo off"
logger.debug("shell_utils.py:save_api_key_to_rc - Added @echo off header for batch file")
# Check if API key already exists in the current file
matches = pattern.findall(shell_contents)
existing_in_file = bool(matches)
logger.debug(f"shell_utils.py:save_api_key_to_rc - Existing key in file: {existing_in_file}")
if existing_in_file:
# Replace the existing API key line in this file
updated_shell_contents = re.sub(pattern, api_key_line, shell_contents)
action = "Updated CODEFLASH_API_KEY in"
logger.debug("shell_utils.py:save_api_key_to_rc - Replaced existing API key")
else:
# Append the new API key line
if shell_contents and not shell_contents.endswith(LF):
updated_shell_contents = shell_contents + LF + api_key_line + LF
else:
updated_shell_contents = shell_contents.rstrip() + f"{LF}{api_key_line}{LF}"
action = "Added CODEFLASH_API_KEY to"
logger.debug("shell_utils.py:save_api_key_to_rc - Appended new API key")
# Write the updated contents
shell_file.seek(0)
shell_file.write(updated_shell_contents)
shell_file.truncate()
except FileNotFoundError:
# File doesn't exist, create it first with initial content
logger.debug("shell_utils.py:save_api_key_to_rc - File does not exist, creating new")
shell_contents = ""
# Initialize with header for batch files if needed
if os.name == "nt" and not is_powershell():
shell_contents = "@echo off"
logger.debug("shell_utils.py:save_api_key_to_rc - Added @echo off header for batch file")
# Create the file by opening in write mode
with open(shell_rc_path_str, "w", encoding="utf8") as shell_file: # noqa: PTH123
shell_file.write(shell_contents)
# Re-open in r+ mode to add the API key (r+ allows both read and write)
with open(shell_rc_path_str, "r+", encoding="utf8") as shell_file: # noqa: PTH123
# Append the new API key line
updated_shell_contents = shell_contents.rstrip() + f"{LF}{api_key_line}{LF}"
action = "Added CODEFLASH_API_KEY to"
logger.debug("shell_utils.py:save_api_key_to_rc - Appended new API key to new file")
# Write the updated contents
shell_file.seek(0)
shell_file.write(updated_shell_contents)
shell_file.truncate()
logger.debug(f"shell_utils.py:save_api_key_to_rc - Successfully wrote to {shell_rc_path}")
shell_file.seek(0)
shell_file.write(updated_shell_contents)
shell_file.truncate()
return Success(f"{action} {shell_rc_path}")
except PermissionError:
except PermissionError as e:
logger.debug(f"shell_utils.py:save_api_key_to_rc - Permission error: {e}")
return Failure(
f"💡 I tried adding your Codeflash API key to {shell_rc_path} - but seems like I don't have permissions to do so.{LF}"
f"You'll need to open it yourself and add the following line:{LF}{LF}{api_key_line}{LF}"
)
except FileNotFoundError:
except Exception as e:
logger.debug(f"shell_utils.py:save_api_key_to_rc - Error: {e}")
return Failure(
f"💡 I went to save your Codeflash API key to {shell_rc_path}, but noticed that it doesn't exist.{LF}"
f"💡 I went to save your Codeflash API key to {shell_rc_path}, but encountered an error: {e}{LF}"
f"To ensure your Codeflash API key is automatically loaded into your environment at startup, you can create {shell_rc_path} and add the following line:{LF}"
f"{LF}{api_key_line}{LF}"
)

View file

@ -15,7 +15,7 @@ class TestShellUtils(unittest.TestCase):
api_key = "cf-12345"
result = save_api_key_to_rc(api_key)
self.assertTrue(isinstance(result, Success))
mock_file.assert_called_with("/fake/path/.bashrc", encoding="utf8")
mock_file.assert_called_with("/fake/path/.bashrc", "r+", encoding="utf8")
handle = mock_file()
handle.write.assert_called_once()
handle.truncate.assert_called_once()