Refactor git repo checking to a saner implementation and add tests

This commit is contained in:
afik.cohen 2024-06-03 17:35:19 -07:00
parent 68af58b6a7
commit 19016136b9
6 changed files with 75 additions and 59 deletions

View file

@ -32,7 +32,7 @@
<excludeFolder url="file://$MODULE_DIR$/js/cf-webapp/node_modules" />
<excludeFolder url="file://$MODULE_DIR$/js/common/node_modules" />
</content>
<orderEntry type="inheritedJdk" />
<orderEntry type="jdk" jdkName="$USER_HOME$/miniforge3/envs/codeflash311" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" />
<orderEntry type="module" module-name="langchain" />
</component>

View file

@ -11,6 +11,8 @@ from codeflash.cli_cmds.cmd_init import init_codeflash, install_github_actions
from codeflash.code_utils import env_utils
from codeflash.code_utils.config_parser import parse_config_file
from codeflash.code_utils.git_utils import (
check_running_in_git_repo,
confirm_proceeding_with_no_git_repo,
get_repo_owner_and_name,
)
from codeflash.code_utils.github_utils import get_github_secrets_page_url, require_github_app_or_exit
@ -19,12 +21,12 @@ from codeflash.version import __version__ as version
def parse_args() -> Namespace:
parser = ArgumentParser()
subparsers = parser.add_subparsers(dest='command', help='Sub-commands')
subparsers = parser.add_subparsers(dest="command", help="Sub-commands")
init_parser = subparsers.add_parser('init', help='Initialize Codeflash for a Python project.')
init_parser = subparsers.add_parser("init", help="Initialize Codeflash for a Python project.")
init_parser.set_defaults(func=init_codeflash)
init_actions_parser = subparsers.add_parser('init-actions', help='Initialize GitHub Actions workflow')
init_actions_parser = subparsers.add_parser("init-actions", help="Initialize GitHub Actions workflow")
init_actions_parser.set_defaults(func=install_github_actions)
parser.add_argument("--file", help="Try to optimize only this file")
parser.add_argument(
@ -94,6 +96,11 @@ def process_cmd_args(args: Namespace) -> Namespace:
if args.command:
args.func()
sys.exit(1)
if not check_running_in_git_repo(module_root=args.module_root):
if not confirm_proceeding_with_no_git_repo():
logging.critical("No git repository detected and user aborted run. Exiting...")
sys.exit(1)
args.no_pr = True
if args.function and not args.file:
logging.error("If you specify a --function, you must specify the --file it is in")
sys.exit(1)

View file

@ -1,12 +1,8 @@
import logging
import os
import sys
from functools import lru_cache
from typing import Optional
import click
import git
from codeflash.code_utils.shell_utils import read_api_key_from_shell_config
@ -15,15 +11,14 @@ def get_codeflash_api_key() -> Optional[str]:
api_key = os.environ.get("CODEFLASH_API_KEY") or read_api_key_from_shell_config()
if not api_key:
raise OSError(
"I didn't find a Codeflash API key in your environment.\n"
+ "You can generate one at https://app.codeflash.ai/app/apikeys,\n"
+ "then set it as a CODEFLASH_API_KEY environment variable.",
"I didn't find a Codeflash API key in your environment.\nYou can generate one at "
"https://app.codeflash.ai/app/apikeys,\nthen set it as a CODEFLASH_API_KEY environment variable.",
)
if not api_key.startswith("cf-"):
raise OSError(
f"Your Codeflash API key seems to be invalid. It should start with a 'cf-' prefix; I found '{api_key}' instead.\n"
+ "You can generate one at https://app.codeflash.ai/app/apikeys,\n"
+ "then set it as a CODEFLASH_API_KEY environment variable.",
f"Your Codeflash API key seems to be invalid. It should start with a 'cf-' prefix; I found '{api_key}' "
f"instead.\nYou can generate one at https://app.codeflash.ai/app/apikeys,\nthen set it as a "
f"CODEFLASH_API_KEY environment variable.",
)
return api_key
@ -33,9 +28,8 @@ def ensure_codeflash_api_key() -> bool:
get_codeflash_api_key()
except OSError:
logging.exception(
"Codeflash API key not found in your environment.\n"
+ "You can generate one at https://app.codeflash.ai/app/apikeys,\n"
+ "then set it as a CODEFLASH_API_KEY environment variable.",
"Codeflash API key not found in your environment.\nYou can generate one at "
"https://app.codeflash.ai/app/apikeys,\nthen set it as a CODEFLASH_API_KEY environment variable.",
)
return False
return True
@ -52,36 +46,13 @@ def get_pr_number() -> Optional[int]:
pr_number = os.environ.get("CODEFLASH_PR_NUMBER")
if not pr_number:
return None
else:
return int(pr_number)
return int(pr_number)
def ensure_pr_number() -> bool:
    """Verify that a pull-request number is available via get_pr_number().

    Returns:
        True when get_pr_number() yields a truthy value.

    Raises:
        OSError: if get_pr_number() returns a falsy value.
    """
    if not get_pr_number():
        # Pass the message exactly once: with two positional arguments OSError
        # treats the first as errno and renders the error as "[Errno ...] ...".
        raise OSError(
            "CODEFLASH_PR_NUMBER not found in environment variables; make sure the Github Action is setting this so "
            "Codeflash can comment on the right PR",
        )
    return True
def ensure_git_repo(module_root: str) -> tuple[bool, bool]:
    """Check that module_root is inside a git repository, prompting if it is not.

    Args:
        module_root: Path handed to git.Repo; parent directories are searched too.

    Returns:
        A (should_continue, disable_PR_creation) tuple:
        (True, False) when a repository is found, or when none is found but
        stdin is not an interactive terminal (so unattended runs such as
        GitHub Actions keep going);
        (True, True) when the user answers "yes" to the prompt;
        (False, True) when the user answers anything else.
    """
    try:
        _ = git.Repo(module_root, search_parent_directories=True).git_dir
    except git.exc.InvalidGitRepositoryError:
        pass
    else:
        return True, False

    # No repository was found. Only prompt when attached to an interactive
    # terminal; sys.__stdin__ can be None when no console is attached.
    original_stdin = sys.__stdin__
    if original_stdin is None or not original_stdin.isatty():
        # continue running, important for GitHub actions
        return True, False

    response = click.prompt(
        "I did not find a git repository for the code. If you run codeflash, it might overwrite the"
        " code and you might irreversibly lose your current code. Proceed?",
        type=click.Choice(["yes", "no"], case_sensitive=False),
        show_choices=True,
    )
    # Always return a tuple: proceed only on an explicit "yes", and disable PR
    # creation either way because there is no repository to open a PR against.
    return response == "yes", True

View file

@ -1,9 +1,11 @@
import logging
import os
import sys
from io import StringIO
from typing import Optional
import git
import inquirer
from git import Repo
from unidiff import PatchSet
@ -87,3 +89,22 @@ def get_repo_owner_and_name(repo: Optional[Repo] = None) -> tuple[str, str]:
def git_root_dir(repo: Optional[Repo] = None) -> str:
    """Return the working directory of a git repository.

    Args:
        repo: Repository to inspect. When omitted (or falsy), a repository is
            opened with git.Repo(search_parent_directories=True) and no
            explicit path.

    Returns:
        The repository's ``working_dir`` attribute.
    """
    if repo:
        return repo.working_dir
    discovered: Repo = git.Repo(search_parent_directories=True)
    return discovered.working_dir
def check_running_in_git_repo(module_root: str) -> bool:
    """Report whether it is acceptable to run against module_root.

    Args:
        module_root: Path handed to git.Repo; parent directories are searched too.

    Returns:
        True when a git repository is found. When git raises
        InvalidGitRepositoryError, the result of
        confirm_proceeding_with_no_git_repo() is returned instead.

    NOTE(review): a caller that itself invokes
    confirm_proceeding_with_no_git_repo() after this returns False may ask the
    user the same question twice - confirm that is intended.
    """
    try:
        enclosing_repo = git.Repo(module_root, search_parent_directories=True)
        # Reading git_dir mirrors the original probe; the value itself is unused.
        _git_dir = enclosing_repo.git_dir
    except git.exc.InvalidGitRepositoryError:
        user_choice = confirm_proceeding_with_no_git_repo()
        return user_choice
    else:
        return True
def confirm_proceeding_with_no_git_repo() -> bool:
    """Ask whether to continue when no git repository was found.

    Returns:
        The user's answer to an inquirer confirmation (default "no") when the
        process's original stdin is an interactive terminal; True otherwise,
        so non-interactive runs such as GitHub Actions keep going.
    """
    original_stdin = sys.__stdin__
    # sys.__stdin__ can be None (for example when no console is attached);
    # treat that the same as any other non-interactive environment instead of
    # failing with AttributeError on .isatty().
    if original_stdin is None or not original_stdin.isatty():
        # continue running on non-interactive environments, important for GitHub actions
        return True
    return inquirer.confirm(
        "WARNING: I did not find a git repository for your code. If you proceed in running codeflash, optimized code will"
        " be written over your current code and you could irreversibly lose your current code. Proceed?",
        default=False,
    )

View file

@ -4,7 +4,6 @@ import concurrent.futures
import logging
import os
import pathlib
import sys
import uuid
from argparse import Namespace
from collections import defaultdict
@ -101,12 +100,6 @@ class Optimizer:
logging.info("Running optimizer.")
if not env_utils.ensure_codeflash_api_key():
return
continue_execution, disable_pr = env_utils.ensure_git_repo(module_root=self.args.module_root)
if not continue_execution:
logging.critical("No git repository detected and user aborted run. Exiting...")
sys.exit(1)
if disable_pr:
self.args.no_pr = True
file_to_funcs_to_optimize: dict[str, list[FunctionToOptimize]]
num_optimizable_functions: int

View file

@ -1,7 +1,8 @@
import unittest
from unittest.mock import patch
from codeflash.code_utils.git_utils import get_repo_owner_and_name
import git
from codeflash.code_utils.git_utils import check_running_in_git_repo, get_repo_owner_and_name
class TestGitUtils(unittest.TestCase):
@ -10,26 +11,49 @@ class TestGitUtils(unittest.TestCase):
# Test with a standard GitHub HTTPS URL
mock_get_remote_url.return_value = "https://github.com/owner/repo.git"
owner, repo_name = get_repo_owner_and_name()
self.assertEqual(owner, "owner")
self.assertEqual(repo_name, "repo")
assert owner == "owner"
assert repo_name == "repo"
# Test with a GitHub SSH URL
mock_get_remote_url.return_value = "git@github.com:owner/repo.git"
owner, repo_name = get_repo_owner_and_name()
self.assertEqual(owner, "owner")
self.assertEqual(repo_name, "repo")
assert owner == "owner"
assert repo_name == "repo"
# Test with another GitHub SSH URL
mock_get_remote_url.return_value = "git@github.com:codeflash-ai/posthog.git"
owner, repo_name = get_repo_owner_and_name()
self.assertEqual(owner, "codeflash-ai")
self.assertEqual(repo_name, "posthog")
assert owner == "codeflash-ai"
assert repo_name == "posthog"
# Test with a URL without the .git suffix
mock_get_remote_url.return_value = "https://github.com/owner/repo"
owner, repo_name = get_repo_owner_and_name()
self.assertEqual(owner, "owner")
self.assertEqual(repo_name, "repo")
assert owner == "owner"
assert repo_name == "repo"
@patch("codeflash.code_utils.git_utils.git.Repo")
def test_check_running_in_git_repo_in_git_repo(self, mock_repo):
    """A resolvable repository makes check_running_in_git_repo return truthy."""
    # Patch through git_utils, the module check_running_in_git_repo is imported
    # from, so the target resolves regardless of what env_utils imports.
    mock_repo.return_value.git_dir = "/path/to/repo/.git"
    assert check_running_in_git_repo("/path/to/repo")
@patch("codeflash.code_utils.git_utils.git.Repo")
@patch("codeflash.code_utils.git_utils.sys.__stdin__.isatty", return_value=True)
@patch("codeflash.code_utils.git_utils.confirm_proceeding_with_no_git_repo", return_value=True)
def test_check_running_in_git_repo_not_in_git_repo_interactive(
    self,
    mock_confirm,
    mock_isatty,
    mock_repo,
):
    """Outside a repository, the (patched) user confirmation decides the result."""
    # confirm_proceeding_with_no_git_repo must be patched in git_utils: that is
    # the namespace check_running_in_git_repo looks the name up in, so patching
    # it anywhere else would leave the real prompt in place.
    mock_repo.side_effect = git.exc.InvalidGitRepositoryError
    assert check_running_in_git_repo("/path/to/non-repo")
@patch("codeflash.code_utils.git_utils.git.Repo")
@patch("codeflash.code_utils.git_utils.sys.__stdin__.isatty", return_value=False)
def test_check_running_in_git_repo_not_in_git_repo_non_interactive(self, mock_isatty, mock_repo):
    """Outside a repository with a non-TTY stdin, the run proceeds without prompting."""
    # Patch targets go through git_utils, where the function under test and its
    # sys/git imports live.
    mock_repo.side_effect = git.exc.InvalidGitRepositoryError
    assert check_running_in_git_repo("/path/to/non-repo")
if __name__ == "__main__":