Handle case where multiple keys may be defined in the rc

This commit is contained in:
afik.cohen 2024-03-08 16:41:53 -08:00
parent 5c8131ae26
commit af87d96903
2 changed files with 22 additions and 19 deletions

View file

@ -8,10 +8,10 @@ from returns.result import Failure, Result, Success
from codeflash.code_utils.compat import LF
if os.name == "nt": # Windows
SHELL_RC_EXPORT_PATTERN = re.compile(r"^set CODEFLASH_API_KEY=(.*)$", re.M)
SHELL_RC_EXPORT_PATTERN = re.compile(r"^set CODEFLASH_API_KEY=(cf-.*)$", re.M)
SHELL_RC_EXPORT_PREFIX = f"set CODEFLASH_API_KEY="
else:
SHELL_RC_EXPORT_PATTERN = re.compile(r'^export CODEFLASH_API_KEY="?([^"]*)"?$', re.M)
SHELL_RC_EXPORT_PATTERN = re.compile(r'^export CODEFLASH_API_KEY="?(cf-[^"]*)"?$', re.M)
SHELL_RC_EXPORT_PREFIX = f"export CODEFLASH_API_KEY="
@ -20,8 +20,8 @@ def read_api_key_from_shell_config() -> Optional[str]:
shell_rc_path = get_shell_rc_path()
with open(shell_rc_path, "r", encoding="utf8") as shell_rc:
shell_contents = shell_rc.read()
match = SHELL_RC_EXPORT_PATTERN.search(shell_contents)
return match.group(1) if match else None
matches = SHELL_RC_EXPORT_PATTERN.findall(shell_contents)
return matches[-1] if matches else None
except FileNotFoundError:
return None

View file

@ -20,7 +20,7 @@ class TestShellUtils(unittest.TestCase):
api_key = "cf-12345"
result = save_api_key_to_rc(api_key)
self.assertTrue(isinstance(result, Success))
mock_file.assert_called_with("/fake/path/.bashrc", "r+", encoding="utf8")
mock_file.assert_called_with("/fake/path/.bashrc", "r", encoding="utf8")
handle = mock_file()
handle.write.assert_called_once()
handle.truncate.assert_called_once()
@ -33,11 +33,11 @@ class TestShellUtils(unittest.TestCase):
@patch("codeflash.code_utils.shell_utils.get_shell_rc_path")
def test_save_api_key_to_rc_failure(self, mock_get_shell_rc_path, mock_file):
mock_get_shell_rc_path.return_value = "/fake/path/.bashrc"
mock_file.side_effect = IOError("Permission denied")
mock_file.side_effect = PermissionError
api_key = "cf-12345"
result = save_api_key_to_rc(api_key)
self.assertTrue(isinstance(result, Failure))
mock_file.assert_called_once_with("/fake/path/.bashrc", "r+", encoding="utf8")
mock_file.assert_called_with("/fake/path/.bashrc", "r+", encoding="utf8")
# unit tests
@ -76,11 +76,17 @@ class TestReadApiKeyFromShellConfig(unittest.TestCase):
def test_malformed_api_key_export(self, mock_get_shell_rc_path):
"""Test with a malformed API key export."""
mock_get_shell_rc_path.return_value = self.test_rc_path
with patch("builtins.open", mock_open(read_data=f"export API_KEY={self.api_key}\n")):
result = read_api_key_from_shell_config()
self.assertIsNone(result)
with patch("builtins.open", mock_open(read_data=f"CODEFLASH_API_KEY={self.api_key}\n")):
result = read_api_key_from_shell_config()
self.assertIsNone(result)
with patch(
"builtins.open", mock_open(read_data=f"export CODEFLASH_API_KEY={self.api_key}\n")
) as mock_file:
self.assertIsNone(read_api_key_from_shell_config())
mock_file.assert_called_once_with(self.test_rc_path, "r", encoding="utf8")
"builtins.open", mock_open(read_data=f"export CODEFLASH_API_KEY=sk-{self.api_key}\n")
):
result = read_api_key_from_shell_config()
self.assertIsNone(result)
@patch("codeflash.code_utils.shell_utils.get_shell_rc_path")
def test_multiple_api_key_exports(self, mock_get_shell_rc_path):
@ -89,11 +95,10 @@ class TestReadApiKeyFromShellConfig(unittest.TestCase):
with patch(
"builtins.open",
mock_open(
read_data=f'export CODEFLASH_API_KEY="firstkey"\nexport CODEFLASH_API_KEY="{self.api_key}"\n'
read_data=f'export CODEFLASH_API_KEY="cf-firstkey"\nexport CODEFLASH_API_KEY="{self.api_key}"\n'
),
) as mock_file:
):
self.assertEqual(read_api_key_from_shell_config(), self.api_key)
mock_file.assert_called_once_with(self.test_rc_path, "r", encoding="utf8")
@patch("codeflash.code_utils.shell_utils.get_shell_rc_path")
def test_api_key_export_with_extra_text(self, mock_get_shell_rc_path):
@ -127,11 +132,9 @@ class TestReadApiKeyFromShellConfig(unittest.TestCase):
def test_file_not_readable(self, mock_get_shell_rc_path):
"""Test when the shell configuration file is not readable."""
mock_get_shell_rc_path.return_value = self.test_rc_path
with patch("builtins.open", side_effect=PermissionError):
with self.assertRaises(PermissionError):
read_api_key_from_shell_config()
self.assertEqual(read_api_key_from_shell_config(), self.api_key)
with patch("builtins.open", mock_open(read_data="")):
mock_open.side_effect = PermissionError
self.assertIsNone(read_api_key_from_shell_config())
if __name__ == "__main__":