mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
Shell util fixes and add tests
This commit is contained in:
parent
0b23670bf2
commit
1360397448
2 changed files with 157 additions and 10 deletions
|
|
@ -1,5 +1,6 @@
|
|||
import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from returns.result import Failure, Result, Success
|
||||
|
|
@ -10,24 +11,27 @@ if os.name == "nt": # Windows
|
|||
SHELL_RC_EXPORT_PATTERN = re.compile(r"^set CODEFLASH_API_KEY=(.*)$", re.M)
|
||||
SHELL_RC_EXPORT_PREFIX = f"set CODEFLASH_API_KEY="
|
||||
else:
|
||||
SHELL_RC_EXPORT_PATTERN = re.compile(r'^export CODEFLASH_API_KEY="?(.*)"?$', re.M)
|
||||
SHELL_RC_EXPORT_PATTERN = re.compile(r'^export CODEFLASH_API_KEY="?([^"]*)"?$', re.M)
|
||||
SHELL_RC_EXPORT_PREFIX = f"export CODEFLASH_API_KEY="
|
||||
|
||||
|
||||
def read_api_key_from_shell_config() -> Optional[str]:
|
||||
shell_rc_path = get_shell_rc_path()
|
||||
with open(shell_rc_path, "r", encoding="utf8") as shell_rc:
|
||||
shell_contents = shell_rc.read()
|
||||
match = SHELL_RC_EXPORT_PATTERN.search(shell_contents)
|
||||
return match.group(1) if match else None
|
||||
try:
|
||||
shell_rc_path = get_shell_rc_path()
|
||||
with open(shell_rc_path, "r", encoding="utf8") as shell_rc:
|
||||
shell_contents = shell_rc.read()
|
||||
match = SHELL_RC_EXPORT_PATTERN.search(shell_contents)
|
||||
return match.group(1) if match else None
|
||||
except FileNotFoundError:
|
||||
return None
|
||||
|
||||
|
||||
def get_shell_rc_path() -> str:
|
||||
"""Get the path to the user's shell configuration file."""
|
||||
if os.name == "nt": # on Windows, we use a batch file in the user's home directory
|
||||
return os.path.expanduser("~\\codeflash_env.bat")
|
||||
return str(Path.home() / "codeflash_env.bat")
|
||||
else:
|
||||
shell = os.environ["SHELL"].split("/")[-1] if "SHELL" in os.environ else "/bin/bash"
|
||||
shell = os.environ.get("SHELL", "/bin/bash").split("/")[-1]
|
||||
shell_rc_filename = {
|
||||
"zsh": ".zshrc",
|
||||
"ksh": ".kshrc",
|
||||
|
|
@ -37,7 +41,7 @@ def get_shell_rc_path() -> str:
|
|||
}.get(
|
||||
shell, ".bashrc"
|
||||
) # map each shell to its config file and default to .bashrc
|
||||
return os.path.expanduser(f"~/{shell_rc_filename}")
|
||||
return str(Path.home() / shell_rc_filename)
|
||||
|
||||
|
||||
def save_api_key_to_rc(api_key) -> Result[str, str]:
|
||||
|
|
@ -66,8 +70,13 @@ def save_api_key_to_rc(api_key) -> Result[str, str]:
|
|||
shell_file.write(updated_shell_contents)
|
||||
shell_file.truncate()
|
||||
return Success(f"✅ {action} {shell_rc_path}.")
|
||||
except IOError as e:
|
||||
except PermissionError:
|
||||
return Failure(
|
||||
f"💡 I tried adding your CodeFlash API key to {shell_rc_path} - but seems like I don't have permissions to do so.{LF}"
|
||||
f"You'll need to open it yourself and add the following line:{LF}{LF}{api_key_line}{LF}"
|
||||
)
|
||||
except FileNotFoundError:
|
||||
return Failure(
|
||||
f"💡 I couldn't find your shell configuration file at {shell_rc_path}.{LF}"
|
||||
f"Please create it and add the following line:{LF}{LF}{api_key_line}{LF}"
|
||||
)
|
||||
|
|
|
|||
138
tests/test_shell_utils.py
Normal file
138
tests/test_shell_utils.py
Normal file
|
|
@ -0,0 +1,138 @@
|
|||
import os
|
||||
import unittest
|
||||
from unittest.mock import mock_open, patch
|
||||
|
||||
from returns.result import Success, Failure
|
||||
|
||||
from codeflash.code_utils.shell_utils import save_api_key_to_rc, read_api_key_from_shell_config
|
||||
|
||||
|
||||
class TestShellUtils(unittest.TestCase):
|
||||
|
||||
@patch(
|
||||
"codeflash.code_utils.shell_utils.open",
|
||||
new_callable=mock_open,
|
||||
read_data="existing content",
|
||||
)
|
||||
@patch("codeflash.code_utils.shell_utils.get_shell_rc_path")
|
||||
def test_save_api_key_to_rc_success(self, mock_get_shell_rc_path, mock_file):
|
||||
mock_get_shell_rc_path.return_value = "/fake/path/.bashrc"
|
||||
api_key = "cf-12345"
|
||||
result = save_api_key_to_rc(api_key)
|
||||
self.assertTrue(isinstance(result, Success))
|
||||
mock_file.assert_called_once_with("/fake/path/.bashrc", "r+", encoding="utf8")
|
||||
handle = mock_file()
|
||||
handle.write.assert_called_once()
|
||||
handle.truncate.assert_called_once()
|
||||
|
||||
@patch(
|
||||
"codeflash.code_utils.shell_utils.open",
|
||||
new_callable=mock_open,
|
||||
read_data="existing content",
|
||||
)
|
||||
@patch("codeflash.code_utils.shell_utils.get_shell_rc_path")
|
||||
def test_save_api_key_to_rc_failure(self, mock_get_shell_rc_path, mock_file):
|
||||
mock_get_shell_rc_path.return_value = "/fake/path/.bashrc"
|
||||
mock_file.side_effect = IOError("Permission denied")
|
||||
api_key = "cf-12345"
|
||||
result = save_api_key_to_rc(api_key)
|
||||
self.assertTrue(isinstance(result, Failure))
|
||||
mock_file.assert_called_once_with("/fake/path/.bashrc", "r+", encoding="utf8")
|
||||
|
||||
|
||||
# unit tests
|
||||
class TestReadApiKeyFromShellConfig(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
"""Setup a temporary shell configuration file for testing."""
|
||||
self.test_rc_path = "test_shell_rc"
|
||||
self.api_key = "cf-1234567890abcdef"
|
||||
os.environ["SHELL"] = "/bin/bash" # Set a default shell for testing
|
||||
|
||||
def tearDown(self):
|
||||
"""Cleanup the temporary shell configuration file after testing."""
|
||||
if os.path.exists(self.test_rc_path):
|
||||
os.remove(self.test_rc_path)
|
||||
del os.environ["SHELL"] # Remove the SHELL environment variable
|
||||
|
||||
def test_valid_api_key(self):
|
||||
"""Test with a valid API key export."""
|
||||
with open(self.test_rc_path, "w", encoding="utf8") as f:
|
||||
f.write(f'export CODEFLASH_API_KEY="{self.api_key}"\n')
|
||||
with patch(
|
||||
"builtins.open", mock_open(read_data=f'export CODEFLASH_API_KEY="{self.api_key}"\n')
|
||||
):
|
||||
self.assertEqual(read_api_key_from_shell_config(), self.api_key)
|
||||
|
||||
@patch("codeflash.code_utils.shell_utils.get_shell_rc_path")
|
||||
def test_no_api_key(self, mock_get_shell_rc_path):
|
||||
"""Test with no API key export."""
|
||||
mock_get_shell_rc_path.return_value = self.test_rc_path
|
||||
with patch("builtins.open", mock_open(read_data="# No API key here\n")):
|
||||
self.assertIsNone(read_api_key_from_shell_config())
|
||||
|
||||
@patch("codeflash.code_utils.shell_utils.get_shell_rc_path")
|
||||
def test_malformed_api_key_export(self, mock_get_shell_rc_path):
|
||||
"""Test with a malformed API key export."""
|
||||
mock_get_shell_rc_path.return_value = self.test_rc_path
|
||||
with patch(
|
||||
"builtins.open", mock_open(read_data=f"export CODEFLASH_API_KEY={self.api_key}\n")
|
||||
):
|
||||
self.assertIsNone(read_api_key_from_shell_config())
|
||||
|
||||
@patch("codeflash.code_utils.shell_utils.get_shell_rc_path")
|
||||
def test_multiple_api_key_exports(self, mock_get_shell_rc_path):
|
||||
"""Test with multiple API key exports."""
|
||||
mock_get_shell_rc_path.return_value = self.test_rc_path
|
||||
with patch(
|
||||
"builtins.open",
|
||||
mock_open(
|
||||
read_data=f'export CODEFLASH_API_KEY="firstkey"\nexport CODEFLASH_API_KEY="{self.api_key}"\n'
|
||||
),
|
||||
):
|
||||
self.assertEqual(read_api_key_from_shell_config(), self.api_key)
|
||||
|
||||
@patch("codeflash.code_utils.shell_utils.get_shell_rc_path")
|
||||
def test_api_key_export_with_extra_text(self, mock_get_shell_rc_path):
|
||||
"""Test with extra text around API key export."""
|
||||
mock_get_shell_rc_path.return_value = self.test_rc_path
|
||||
with patch(
|
||||
"builtins.open",
|
||||
mock_open(
|
||||
read_data=f'# Setting API Key\nexport CODEFLASH_API_KEY="{self.api_key}"\n# Done\n'
|
||||
),
|
||||
):
|
||||
self.assertEqual(read_api_key_from_shell_config(), self.api_key)
|
||||
|
||||
@patch("codeflash.code_utils.shell_utils.get_shell_rc_path")
|
||||
def test_api_key_in_comment(self, mock_get_shell_rc_path):
|
||||
"""Test with API key export in a comment."""
|
||||
mock_get_shell_rc_path.return_value = self.test_rc_path
|
||||
with patch(
|
||||
"builtins.open", mock_open(read_data=f'# export CODEFLASH_API_KEY="{self.api_key}"\n')
|
||||
):
|
||||
self.assertIsNone(read_api_key_from_shell_config())
|
||||
|
||||
@patch("codeflash.code_utils.shell_utils.get_shell_rc_path")
|
||||
def test_file_does_not_exist(self, mock_get_shell_rc_path):
|
||||
"""Test when the shell configuration file does not exist."""
|
||||
mock_get_shell_rc_path.return_value = self.test_rc_path
|
||||
with patch("builtins.open", mock_open(read_data=""), side_effect=FileNotFoundError):
|
||||
self.assertIsNone(read_api_key_from_shell_config())
|
||||
|
||||
@patch("codeflash.code_utils.shell_utils.get_shell_rc_path")
|
||||
def test_file_not_readable(self, mock_get_shell_rc_path):
|
||||
"""Test when the shell configuration file is not readable."""
|
||||
mock_get_shell_rc_path.return_value = self.test_rc_path
|
||||
with patch(
|
||||
"builtins.open",
|
||||
mock_open(read_data=f'export CODEFLASH_API_KEY="{self.api_key}"\n'),
|
||||
side_effect=PermissionError,
|
||||
):
|
||||
with self.assertRaises(PermissionError):
|
||||
read_api_key_from_shell_config()
|
||||
self.assertEqual(read_api_key_from_shell_config(), self.api_key)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Loading…
Reference in a new issue