mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
saving the API key correctly for windows PowerShell (#940)
* remove emoji * save cf api key correctly for powershell * fix test_shell_utils test * fix linting and tests * FIX ALL TESTS * revert tests/test_trace_benchmarks.py --------- Co-authored-by: HeshamHM28 <HeshamMohamedFathy@outlook.com> Co-authored-by: Kevin Turcios <106575910+KRRT7@users.noreply.github.com>
This commit is contained in:
parent
024ef1a680
commit
63b6e77b7f
4 changed files with 202 additions and 39 deletions
|
|
@ -32,7 +32,7 @@ from codeflash.code_utils.env_utils import check_formatter_installed, get_codefl
|
|||
from codeflash.code_utils.git_utils import get_git_remotes, get_repo_owner_and_name
|
||||
from codeflash.code_utils.github_utils import get_github_secrets_page_url
|
||||
from codeflash.code_utils.oauth_handler import perform_oauth_signin
|
||||
from codeflash.code_utils.shell_utils import get_shell_rc_path, save_api_key_to_rc
|
||||
from codeflash.code_utils.shell_utils import get_shell_rc_path, is_powershell, save_api_key_to_rc
|
||||
from codeflash.either import is_successful
|
||||
from codeflash.lsp.helpers import is_LSP_enabled
|
||||
from codeflash.telemetry.posthog_cf import ph
|
||||
|
|
@ -136,7 +136,10 @@ def init_codeflash() -> None:
|
|||
completion_message += (
|
||||
"\n\n🐚 Don't forget to restart your shell to load the CODEFLASH_API_KEY environment variable!"
|
||||
)
|
||||
reload_cmd = f"call {get_shell_rc_path()}" if os.name == "nt" else f"source {get_shell_rc_path()}"
|
||||
if os.name == "nt":
|
||||
reload_cmd = f". {get_shell_rc_path()}" if is_powershell() else f"call {get_shell_rc_path()}"
|
||||
else:
|
||||
reload_cmd = f"source {get_shell_rc_path()}"
|
||||
completion_message += f"\nOr run: {reload_cmd}"
|
||||
|
||||
completion_panel = Panel(
|
||||
|
|
@ -1087,7 +1090,7 @@ def configure_pyproject_toml(
|
|||
|
||||
with toml_path.open("w", encoding="utf8") as pyproject_file:
|
||||
pyproject_file.write(tomlkit.dumps(pyproject_data))
|
||||
click.echo(f"✅ Added Codeflash configuration to {toml_path}")
|
||||
click.echo(f"Added Codeflash configuration to {toml_path}")
|
||||
click.echo()
|
||||
return True
|
||||
|
||||
|
|
@ -1264,7 +1267,8 @@ def enter_api_key_and_save_to_rc() -> None:
|
|||
browser_launched = True # This does not work on remote consoles
|
||||
shell_rc_path = get_shell_rc_path()
|
||||
if not shell_rc_path.exists() and os.name == "nt":
|
||||
# On Windows, create a batch file in the user's home directory (not auto-run, just used to store api key)
|
||||
# On Windows, create the appropriate file (PowerShell .ps1 or CMD .bat) in the user's home directory
|
||||
shell_rc_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
shell_rc_path.touch()
|
||||
click.echo(f"✅ Created {shell_rc_path}")
|
||||
get_user_id(api_key=api_key) # Used to verify whether the API key is valid.
|
||||
|
|
|
|||
|
|
@ -59,17 +59,26 @@ def get_codeflash_api_key() -> str:
|
|||
# Check environment variable first
|
||||
env_api_key = os.environ.get("CODEFLASH_API_KEY")
|
||||
shell_api_key = read_api_key_from_shell_config()
|
||||
|
||||
logger.debug(
|
||||
f"env_utils.py:get_codeflash_api_key - env_api_key: {'***' + env_api_key[-4:] if env_api_key else None}, shell_api_key: {'***' + shell_api_key[-4:] if shell_api_key else None}"
|
||||
)
|
||||
# If we have an env var but it's not in shell config, save it for persistence
|
||||
if env_api_key and not shell_api_key:
|
||||
try:
|
||||
from codeflash.either import is_successful
|
||||
|
||||
logger.debug("env_utils.py:get_codeflash_api_key - Saving API key from environment to shell config")
|
||||
result = save_api_key_to_rc(env_api_key)
|
||||
if is_successful(result):
|
||||
logger.debug(f"Automatically saved API key from environment to shell config: {result.unwrap()}")
|
||||
logger.debug(
|
||||
f"env_utils.py:get_codeflash_api_key - Automatically saved API key from environment to shell config: {result.unwrap()}"
|
||||
)
|
||||
else:
|
||||
logger.debug(f"env_utils.py:get_codeflash_api_key - Failed to save API key: {result.failure()}")
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to automatically save API key to shell config: {e}")
|
||||
logger.debug(
|
||||
f"env_utils.py:get_codeflash_api_key - Failed to automatically save API key to shell config: {e}"
|
||||
)
|
||||
|
||||
# Prefer the shell configuration over environment variables for lsp,
|
||||
# as the API key may change in the RC file during lsp runtime. Since the LSP client (extension) can restart
|
||||
|
|
|
|||
|
|
@ -1,40 +1,107 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from codeflash.cli_cmds.console import logger
|
||||
from codeflash.code_utils.compat import LF
|
||||
from codeflash.either import Failure, Success
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from codeflash.either import Result
|
||||
|
||||
if os.name == "nt": # Windows
|
||||
SHELL_RC_EXPORT_PATTERN = re.compile(r"^set CODEFLASH_API_KEY=(cf-.*)$", re.MULTILINE)
|
||||
SHELL_RC_EXPORT_PREFIX = "set CODEFLASH_API_KEY="
|
||||
else:
|
||||
SHELL_RC_EXPORT_PATTERN = re.compile(
|
||||
r'^(?!#)export CODEFLASH_API_KEY=(?:"|\')?(cf-[^\s"\']+)(?:"|\')?$', re.MULTILINE
|
||||
)
|
||||
SHELL_RC_EXPORT_PREFIX = "export CODEFLASH_API_KEY="
|
||||
# PowerShell patterns and prefixes
|
||||
POWERSHELL_RC_EXPORT_PATTERN = re.compile(
|
||||
r'^\$env:CODEFLASH_API_KEY\s*=\s*(?:"|\')?(cf-[^\s"\']+)(?:"|\')?\s*$', re.MULTILINE
|
||||
)
|
||||
POWERSHELL_RC_EXPORT_PREFIX = "$env:CODEFLASH_API_KEY = "
|
||||
|
||||
# CMD/Batch patterns and prefixes
|
||||
CMD_RC_EXPORT_PATTERN = re.compile(r"^set CODEFLASH_API_KEY=(cf-.*)$", re.MULTILINE)
|
||||
CMD_RC_EXPORT_PREFIX = "set CODEFLASH_API_KEY="
|
||||
|
||||
# Unix shell patterns and prefixes
|
||||
UNIX_RC_EXPORT_PATTERN = re.compile(r'^(?!#)export CODEFLASH_API_KEY=(?:"|\')?(cf-[^\s"\']+)(?:"|\')?$', re.MULTILINE)
|
||||
UNIX_RC_EXPORT_PREFIX = "export CODEFLASH_API_KEY="
|
||||
|
||||
|
||||
def is_powershell() -> bool:
|
||||
"""Detect if we're running in PowerShell on Windows.
|
||||
|
||||
Uses multiple heuristics:
|
||||
1. PSModulePath environment variable (PowerShell always sets this)
|
||||
2. COMSPEC pointing to powershell.exe
|
||||
3. TERM_PROGRAM indicating Windows Terminal (often uses PowerShell)
|
||||
"""
|
||||
if os.name != "nt":
|
||||
return False
|
||||
|
||||
# Primary check: PSMODULEPATH is set by PowerShell
|
||||
# This is the most reliable indicator as PowerShell always sets this
|
||||
ps_module_path = os.environ.get("PSMODULEPATH")
|
||||
if ps_module_path:
|
||||
logger.debug("shell_utils.py:is_powershell - Detected PowerShell via PSModulePath")
|
||||
return True
|
||||
|
||||
# Secondary check: COMSPEC points to PowerShell
|
||||
comspec = os.environ.get("COMSPEC", "").lower()
|
||||
if "powershell" in comspec:
|
||||
logger.debug(f"shell_utils.py:is_powershell - Detected PowerShell via COMSPEC: {comspec}")
|
||||
return True
|
||||
|
||||
# Tertiary check: Windows Terminal often uses PowerShell by default
|
||||
# But we only use this if other indicators are ambiguous
|
||||
term_program = os.environ.get("TERM_PROGRAM", "").lower()
|
||||
# Check if we can find evidence of CMD (cmd.exe in COMSPEC)
|
||||
# If not, assume PowerShell for Windows Terminal
|
||||
if "windows" in term_program and "terminal" in term_program and "cmd.exe" not in comspec:
|
||||
logger.debug(f"shell_utils.py:is_powershell - Detected PowerShell via Windows Terminal (COMSPEC: {comspec})")
|
||||
return True
|
||||
|
||||
logger.debug(f"shell_utils.py:is_powershell - Not PowerShell (COMSPEC: {comspec}, TERM_PROGRAM: {term_program})")
|
||||
return False
|
||||
|
||||
|
||||
def read_api_key_from_shell_config() -> Optional[str]:
|
||||
"""Read API key from shell configuration file."""
|
||||
shell_rc_path = get_shell_rc_path()
|
||||
# Ensure shell_rc_path is a Path object for consistent handling
|
||||
if not isinstance(shell_rc_path, Path):
|
||||
shell_rc_path = Path(shell_rc_path)
|
||||
|
||||
# Determine the correct pattern to use based on the file extension and platform
|
||||
if os.name == "nt": # Windows
|
||||
pattern = POWERSHELL_RC_EXPORT_PATTERN if shell_rc_path.suffix == ".ps1" else CMD_RC_EXPORT_PATTERN
|
||||
else: # Unix-like
|
||||
pattern = UNIX_RC_EXPORT_PATTERN
|
||||
|
||||
try:
|
||||
shell_rc_path = get_shell_rc_path()
|
||||
with open(shell_rc_path, encoding="utf8") as shell_rc: # noqa: PTH123
|
||||
# Convert Path to string using as_posix() for cross-platform path compatibility
|
||||
shell_rc_path_str = shell_rc_path.as_posix() if isinstance(shell_rc_path, Path) else str(shell_rc_path)
|
||||
with open(shell_rc_path_str, encoding="utf8") as shell_rc: # noqa: PTH123
|
||||
shell_contents = shell_rc.read()
|
||||
matches = SHELL_RC_EXPORT_PATTERN.findall(shell_contents)
|
||||
return matches[-1] if matches else None
|
||||
matches = pattern.findall(shell_contents)
|
||||
if matches:
|
||||
logger.debug(f"shell_utils.py:read_api_key_from_shell_config - Found API key in file: {shell_rc_path}")
|
||||
return matches[-1]
|
||||
logger.debug(f"shell_utils.py:read_api_key_from_shell_config - No API key found in file: {shell_rc_path}")
|
||||
return None
|
||||
except FileNotFoundError:
|
||||
logger.debug(f"shell_utils.py:read_api_key_from_shell_config - File not found: {shell_rc_path}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.debug(f"shell_utils.py:read_api_key_from_shell_config - Error reading file: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def get_shell_rc_path() -> Path:
|
||||
"""Get the path to the user's shell configuration file."""
|
||||
if os.name == "nt": # on Windows, we use a batch file in the user's home directory
|
||||
if os.name == "nt": # Windows
|
||||
if is_powershell():
|
||||
return Path.home() / "codeflash_env.ps1"
|
||||
return Path.home() / "codeflash_env.bat"
|
||||
shell = os.environ.get("SHELL", "/bin/bash").split("/")[-1]
|
||||
shell_rc_filename = {"zsh": ".zshrc", "ksh": ".kshrc", "csh": ".cshrc", "tcsh": ".cshrc", "dash": ".profile"}.get(
|
||||
|
|
@ -44,40 +111,123 @@ def get_shell_rc_path() -> Path:
|
|||
|
||||
|
||||
def get_api_key_export_line(api_key: str) -> str:
|
||||
return f'{SHELL_RC_EXPORT_PREFIX}"{api_key}"'
|
||||
"""Get the appropriate export line based on the shell type."""
|
||||
if os.name == "nt": # Windows
|
||||
if is_powershell():
|
||||
return f'{POWERSHELL_RC_EXPORT_PREFIX}"{api_key}"'
|
||||
return f'{CMD_RC_EXPORT_PREFIX}"{api_key}"'
|
||||
# Unix-like
|
||||
return f'{UNIX_RC_EXPORT_PREFIX}"{api_key}"'
|
||||
|
||||
|
||||
def save_api_key_to_rc(api_key: str) -> Result[str, str]:
|
||||
"""Save API key to the appropriate shell configuration file."""
|
||||
shell_rc_path = get_shell_rc_path()
|
||||
# Ensure shell_rc_path is a Path object for consistent handling
|
||||
if not isinstance(shell_rc_path, Path):
|
||||
shell_rc_path = Path(shell_rc_path)
|
||||
api_key_line = get_api_key_export_line(api_key)
|
||||
try:
|
||||
with open(shell_rc_path, "r+", encoding="utf8") as shell_file: # noqa: PTH123
|
||||
shell_contents = shell_file.read()
|
||||
if os.name == "nt" and not shell_contents: # on windows we're writing to a batch file
|
||||
shell_contents = "@echo off"
|
||||
existing_api_key = read_api_key_from_shell_config()
|
||||
|
||||
if existing_api_key:
|
||||
# Replace the existing API key line
|
||||
updated_shell_contents = re.sub(SHELL_RC_EXPORT_PATTERN, api_key_line, shell_contents)
|
||||
action = "Updated CODEFLASH_API_KEY in"
|
||||
else:
|
||||
logger.debug(f"shell_utils.py:save_api_key_to_rc - Saving API key to: {shell_rc_path}")
|
||||
logger.debug(f"shell_utils.py:save_api_key_to_rc - API key line format: {api_key_line[:30]}...")
|
||||
|
||||
# Determine the correct pattern to use for replacement
|
||||
if os.name == "nt": # Windows
|
||||
if is_powershell():
|
||||
pattern = POWERSHELL_RC_EXPORT_PATTERN
|
||||
logger.debug("shell_utils.py:save_api_key_to_rc - Using PowerShell pattern")
|
||||
else:
|
||||
pattern = CMD_RC_EXPORT_PATTERN
|
||||
logger.debug("shell_utils.py:save_api_key_to_rc - Using CMD pattern")
|
||||
else: # Unix-like
|
||||
pattern = UNIX_RC_EXPORT_PATTERN
|
||||
logger.debug("shell_utils.py:save_api_key_to_rc - Using Unix pattern")
|
||||
|
||||
try:
|
||||
# Create directory if it doesn't exist (ignore errors - file operation will fail if needed)
|
||||
# Directory creation failed, but we'll still try to open the file
|
||||
# The file operation itself will raise the appropriate exception if there are permission issues
|
||||
with contextlib.suppress(OSError, PermissionError):
|
||||
shell_rc_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Convert Path to string using as_posix() for cross-platform path compatibility
|
||||
shell_rc_path_str = shell_rc_path.as_posix() if isinstance(shell_rc_path, Path) else str(shell_rc_path)
|
||||
|
||||
# Try to open in r+ mode (read and write in single operation)
|
||||
# Handle FileNotFoundError if file doesn't exist (r+ requires file to exist)
|
||||
try:
|
||||
with open(shell_rc_path_str, "r+", encoding="utf8") as shell_file: # noqa: PTH123
|
||||
shell_contents = shell_file.read()
|
||||
logger.debug(f"shell_utils.py:save_api_key_to_rc - Read existing file, length: {len(shell_contents)}")
|
||||
|
||||
# Initialize empty file with header for batch files if needed
|
||||
if not shell_contents:
|
||||
logger.debug("shell_utils.py:save_api_key_to_rc - File is empty, initializing")
|
||||
if os.name == "nt" and not is_powershell():
|
||||
shell_contents = "@echo off"
|
||||
logger.debug("shell_utils.py:save_api_key_to_rc - Added @echo off header for batch file")
|
||||
|
||||
# Check if API key already exists in the current file
|
||||
matches = pattern.findall(shell_contents)
|
||||
existing_in_file = bool(matches)
|
||||
logger.debug(f"shell_utils.py:save_api_key_to_rc - Existing key in file: {existing_in_file}")
|
||||
|
||||
if existing_in_file:
|
||||
# Replace the existing API key line in this file
|
||||
updated_shell_contents = re.sub(pattern, api_key_line, shell_contents)
|
||||
action = "Updated CODEFLASH_API_KEY in"
|
||||
logger.debug("shell_utils.py:save_api_key_to_rc - Replaced existing API key")
|
||||
else:
|
||||
# Append the new API key line
|
||||
if shell_contents and not shell_contents.endswith(LF):
|
||||
updated_shell_contents = shell_contents + LF + api_key_line + LF
|
||||
else:
|
||||
updated_shell_contents = shell_contents.rstrip() + f"{LF}{api_key_line}{LF}"
|
||||
action = "Added CODEFLASH_API_KEY to"
|
||||
logger.debug("shell_utils.py:save_api_key_to_rc - Appended new API key")
|
||||
|
||||
# Write the updated contents
|
||||
shell_file.seek(0)
|
||||
shell_file.write(updated_shell_contents)
|
||||
shell_file.truncate()
|
||||
except FileNotFoundError:
|
||||
# File doesn't exist, create it first with initial content
|
||||
logger.debug("shell_utils.py:save_api_key_to_rc - File does not exist, creating new")
|
||||
shell_contents = ""
|
||||
# Initialize with header for batch files if needed
|
||||
if os.name == "nt" and not is_powershell():
|
||||
shell_contents = "@echo off"
|
||||
logger.debug("shell_utils.py:save_api_key_to_rc - Added @echo off header for batch file")
|
||||
|
||||
# Create the file by opening in write mode
|
||||
with open(shell_rc_path_str, "w", encoding="utf8") as shell_file: # noqa: PTH123
|
||||
shell_file.write(shell_contents)
|
||||
|
||||
# Re-open in r+ mode to add the API key (r+ allows both read and write)
|
||||
with open(shell_rc_path_str, "r+", encoding="utf8") as shell_file: # noqa: PTH123
|
||||
# Append the new API key line
|
||||
updated_shell_contents = shell_contents.rstrip() + f"{LF}{api_key_line}{LF}"
|
||||
action = "Added CODEFLASH_API_KEY to"
|
||||
logger.debug("shell_utils.py:save_api_key_to_rc - Appended new API key to new file")
|
||||
|
||||
# Write the updated contents
|
||||
shell_file.seek(0)
|
||||
shell_file.write(updated_shell_contents)
|
||||
shell_file.truncate()
|
||||
|
||||
logger.debug(f"shell_utils.py:save_api_key_to_rc - Successfully wrote to {shell_rc_path}")
|
||||
|
||||
shell_file.seek(0)
|
||||
shell_file.write(updated_shell_contents)
|
||||
shell_file.truncate()
|
||||
return Success(f"✅ {action} {shell_rc_path}")
|
||||
except PermissionError:
|
||||
except PermissionError as e:
|
||||
logger.debug(f"shell_utils.py:save_api_key_to_rc - Permission error: {e}")
|
||||
return Failure(
|
||||
f"💡 I tried adding your Codeflash API key to {shell_rc_path} - but seems like I don't have permissions to do so.{LF}"
|
||||
f"You'll need to open it yourself and add the following line:{LF}{LF}{api_key_line}{LF}"
|
||||
)
|
||||
except FileNotFoundError:
|
||||
except Exception as e:
|
||||
logger.debug(f"shell_utils.py:save_api_key_to_rc - Error: {e}")
|
||||
return Failure(
|
||||
f"💡 I went to save your Codeflash API key to {shell_rc_path}, but noticed that it doesn't exist.{LF}"
|
||||
f"💡 I went to save your Codeflash API key to {shell_rc_path}, but encountered an error: {e}{LF}"
|
||||
f"To ensure your Codeflash API key is automatically loaded into your environment at startup, you can create {shell_rc_path} and add the following line:{LF}"
|
||||
f"{LF}{api_key_line}{LF}"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ class TestShellUtils(unittest.TestCase):
|
|||
api_key = "cf-12345"
|
||||
result = save_api_key_to_rc(api_key)
|
||||
self.assertTrue(isinstance(result, Success))
|
||||
mock_file.assert_called_with("/fake/path/.bashrc", encoding="utf8")
|
||||
mock_file.assert_called_with("/fake/path/.bashrc", "r+", encoding="utf8")
|
||||
handle = mock_file()
|
||||
handle.write.assert_called_once()
|
||||
handle.truncate.assert_called_once()
|
||||
|
|
|
|||
Loading…
Reference in a new issue