mirror of
https://github.com/codeflash-ai/codeflash-internal.git
synced 2026-05-04 18:25:18 +00:00
Refactor git repo checking to more sane implementation and add tests
This commit is contained in:
parent
68af58b6a7
commit
19016136b9
6 changed files with 75 additions and 59 deletions
|
|
@ -32,7 +32,7 @@
|
|||
<excludeFolder url="file://$MODULE_DIR$/js/cf-webapp/node_modules" />
|
||||
<excludeFolder url="file://$MODULE_DIR$/js/common/node_modules" />
|
||||
</content>
|
||||
<orderEntry type="inheritedJdk" />
|
||||
<orderEntry type="jdk" jdkName="$USER_HOME$/miniforge3/envs/codeflash311" jdkType="Python SDK" />
|
||||
<orderEntry type="sourceFolder" forTests="false" />
|
||||
<orderEntry type="module" module-name="langchain" />
|
||||
</component>
|
||||
|
|
|
|||
|
|
@ -11,6 +11,8 @@ from codeflash.cli_cmds.cmd_init import init_codeflash, install_github_actions
|
|||
from codeflash.code_utils import env_utils
|
||||
from codeflash.code_utils.config_parser import parse_config_file
|
||||
from codeflash.code_utils.git_utils import (
|
||||
check_running_in_git_repo,
|
||||
confirm_proceeding_with_no_git_repo,
|
||||
get_repo_owner_and_name,
|
||||
)
|
||||
from codeflash.code_utils.github_utils import get_github_secrets_page_url, require_github_app_or_exit
|
||||
|
|
@ -19,12 +21,12 @@ from codeflash.version import __version__ as version
|
|||
|
||||
def parse_args() -> Namespace:
|
||||
parser = ArgumentParser()
|
||||
subparsers = parser.add_subparsers(dest='command', help='Sub-commands')
|
||||
subparsers = parser.add_subparsers(dest="command", help="Sub-commands")
|
||||
|
||||
init_parser = subparsers.add_parser('init', help='Initialize Codeflash for a Python project.')
|
||||
init_parser = subparsers.add_parser("init", help="Initialize Codeflash for a Python project.")
|
||||
init_parser.set_defaults(func=init_codeflash)
|
||||
|
||||
init_actions_parser = subparsers.add_parser('init-actions', help='Initialize GitHub Actions workflow')
|
||||
init_actions_parser = subparsers.add_parser("init-actions", help="Initialize GitHub Actions workflow")
|
||||
init_actions_parser.set_defaults(func=install_github_actions)
|
||||
parser.add_argument("--file", help="Try to optimize only this file")
|
||||
parser.add_argument(
|
||||
|
|
@ -94,6 +96,11 @@ def process_cmd_args(args: Namespace) -> Namespace:
|
|||
if args.command:
|
||||
args.func()
|
||||
sys.exit(1)
|
||||
if not check_running_in_git_repo(module_root=args.module_root):
|
||||
if not confirm_proceeding_with_no_git_repo():
|
||||
logging.critical("No git repository detected and user aborted run. Exiting...")
|
||||
sys.exit(1)
|
||||
args.no_pr = True
|
||||
if args.function and not args.file:
|
||||
logging.error("If you specify a --function, you must specify the --file it is in")
|
||||
sys.exit(1)
|
||||
|
|
|
|||
|
|
@ -1,12 +1,8 @@
|
|||
import logging
|
||||
import os
|
||||
import sys
|
||||
from functools import lru_cache
|
||||
from typing import Optional
|
||||
|
||||
import click
|
||||
import git
|
||||
|
||||
from codeflash.code_utils.shell_utils import read_api_key_from_shell_config
|
||||
|
||||
|
||||
|
|
@ -15,15 +11,14 @@ def get_codeflash_api_key() -> Optional[str]:
|
|||
api_key = os.environ.get("CODEFLASH_API_KEY") or read_api_key_from_shell_config()
|
||||
if not api_key:
|
||||
raise OSError(
|
||||
"I didn't find a Codeflash API key in your environment.\n"
|
||||
+ "You can generate one at https://app.codeflash.ai/app/apikeys,\n"
|
||||
+ "then set it as a CODEFLASH_API_KEY environment variable.",
|
||||
"I didn't find a Codeflash API key in your environment.\nYou can generate one at "
|
||||
"https://app.codeflash.ai/app/apikeys,\nthen set it as a CODEFLASH_API_KEY environment variable.",
|
||||
)
|
||||
if not api_key.startswith("cf-"):
|
||||
raise OSError(
|
||||
f"Your Codeflash API key seems to be invalid. It should start with a 'cf-' prefix; I found '{api_key}' instead.\n"
|
||||
+ "You can generate one at https://app.codeflash.ai/app/apikeys,\n"
|
||||
+ "then set it as a CODEFLASH_API_KEY environment variable.",
|
||||
f"Your Codeflash API key seems to be invalid. It should start with a 'cf-' prefix; I found '{api_key}' "
|
||||
f"instead.\nYou can generate one at https://app.codeflash.ai/app/apikeys,\nthen set it as a "
|
||||
f"CODEFLASH_API_KEY environment variable.",
|
||||
)
|
||||
return api_key
|
||||
|
||||
|
|
@ -33,9 +28,8 @@ def ensure_codeflash_api_key() -> bool:
|
|||
get_codeflash_api_key()
|
||||
except OSError:
|
||||
logging.exception(
|
||||
"Codeflash API key not found in your environment.\n"
|
||||
+ "You can generate one at https://app.codeflash.ai/app/apikeys,\n"
|
||||
+ "then set it as a CODEFLASH_API_KEY environment variable.",
|
||||
"Codeflash API key not found in your environment.\nYou can generate one at "
|
||||
"https://app.codeflash.ai/app/apikeys,\nthen set it as a CODEFLASH_API_KEY environment variable.",
|
||||
)
|
||||
return False
|
||||
return True
|
||||
|
|
@ -52,36 +46,13 @@ def get_pr_number() -> Optional[int]:
|
|||
pr_number = os.environ.get("CODEFLASH_PR_NUMBER")
|
||||
if not pr_number:
|
||||
return None
|
||||
else:
|
||||
return int(pr_number)
|
||||
return int(pr_number)
|
||||
|
||||
|
||||
def ensure_pr_number() -> bool:
|
||||
if not get_pr_number():
|
||||
raise OSError(
|
||||
"CODEFLASH_PR_NUMBER not found in environment variables; make sure the Github Action is setting this so Codeflash can comment on the right PR",
|
||||
"CODEFLASH_PR_NUMBER not found in environment variables; make sure the Github Action is setting this so "
|
||||
"Codeflash can comment on the right PR",
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
def ensure_git_repo(module_root: str) -> tuple[bool, bool]:
|
||||
# return type is (should_continue, disable_PR_creation)
|
||||
try:
|
||||
_ = git.Repo(module_root, search_parent_directories=True).git_dir
|
||||
return True, False
|
||||
except git.exc.InvalidGitRepositoryError:
|
||||
# Only ask for the prompt if running in non-interactive mode
|
||||
if sys.__stdin__.isatty():
|
||||
response = click.prompt(
|
||||
"I did not find a git repository for the code. If you run codeflash, it might overwrite the"
|
||||
" code and you might irreversibly lose your current code. Proceed?",
|
||||
type=click.Choice(["yes", "no"], case_sensitive=False),
|
||||
show_choices=True,
|
||||
)
|
||||
if response == "no":
|
||||
return False, True
|
||||
if response == "yes":
|
||||
return True, True
|
||||
else:
|
||||
# continue running, important for GitHub actions
|
||||
return True, False
|
||||
|
|
|
|||
|
|
@ -1,9 +1,11 @@
|
|||
import logging
|
||||
import os
|
||||
import sys
|
||||
from io import StringIO
|
||||
from typing import Optional
|
||||
|
||||
import git
|
||||
import inquirer
|
||||
from git import Repo
|
||||
from unidiff import PatchSet
|
||||
|
||||
|
|
@ -87,3 +89,22 @@ def get_repo_owner_and_name(repo: Optional[Repo] = None) -> tuple[str, str]:
|
|||
def git_root_dir(repo: Optional[Repo] = None) -> str:
|
||||
repository: Repo = repo if repo else git.Repo(search_parent_directories=True)
|
||||
return repository.working_dir
|
||||
|
||||
|
||||
def check_running_in_git_repo(module_root: str) -> bool:
|
||||
try:
|
||||
_ = git.Repo(module_root, search_parent_directories=True).git_dir
|
||||
return True
|
||||
except git.exc.InvalidGitRepositoryError:
|
||||
return confirm_proceeding_with_no_git_repo()
|
||||
|
||||
|
||||
def confirm_proceeding_with_no_git_repo() -> bool:
|
||||
if sys.__stdin__.isatty():
|
||||
return inquirer.confirm(
|
||||
"WARNING: I did not find a git repository for your code. If you proceed in running codeflash, optimized code will"
|
||||
" be written over your current code and you could irreversibly lose your current code. Proceed?",
|
||||
default=False,
|
||||
)
|
||||
# continue running on non-interactive environments, important for GitHub actions
|
||||
return True
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@ import concurrent.futures
|
|||
import logging
|
||||
import os
|
||||
import pathlib
|
||||
import sys
|
||||
import uuid
|
||||
from argparse import Namespace
|
||||
from collections import defaultdict
|
||||
|
|
@ -101,12 +100,6 @@ class Optimizer:
|
|||
logging.info("Running optimizer.")
|
||||
if not env_utils.ensure_codeflash_api_key():
|
||||
return
|
||||
continue_execution, disable_pr = env_utils.ensure_git_repo(module_root=self.args.module_root)
|
||||
if not continue_execution:
|
||||
logging.critical("No git repository detected and user aborted run. Exiting...")
|
||||
sys.exit(1)
|
||||
if disable_pr:
|
||||
self.args.no_pr = True
|
||||
|
||||
file_to_funcs_to_optimize: dict[str, list[FunctionToOptimize]]
|
||||
num_optimizable_functions: int
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
from codeflash.code_utils.git_utils import get_repo_owner_and_name
|
||||
import git
|
||||
from codeflash.code_utils.git_utils import check_running_in_git_repo, get_repo_owner_and_name
|
||||
|
||||
|
||||
class TestGitUtils(unittest.TestCase):
|
||||
|
|
@ -10,26 +11,49 @@ class TestGitUtils(unittest.TestCase):
|
|||
# Test with a standard GitHub HTTPS URL
|
||||
mock_get_remote_url.return_value = "https://github.com/owner/repo.git"
|
||||
owner, repo_name = get_repo_owner_and_name()
|
||||
self.assertEqual(owner, "owner")
|
||||
self.assertEqual(repo_name, "repo")
|
||||
assert owner == "owner"
|
||||
assert repo_name == "repo"
|
||||
|
||||
# Test with a GitHub SSH URL
|
||||
mock_get_remote_url.return_value = "git@github.com:owner/repo.git"
|
||||
owner, repo_name = get_repo_owner_and_name()
|
||||
self.assertEqual(owner, "owner")
|
||||
self.assertEqual(repo_name, "repo")
|
||||
assert owner == "owner"
|
||||
assert repo_name == "repo"
|
||||
|
||||
# Test with another GitHub SSH URL
|
||||
mock_get_remote_url.return_value = "git@github.com:codeflash-ai/posthog.git"
|
||||
owner, repo_name = get_repo_owner_and_name()
|
||||
self.assertEqual(owner, "codeflash-ai")
|
||||
self.assertEqual(repo_name, "posthog")
|
||||
assert owner == "codeflash-ai"
|
||||
assert repo_name == "posthog"
|
||||
|
||||
# Test with a URL without the .git suffix
|
||||
mock_get_remote_url.return_value = "https://github.com/owner/repo"
|
||||
owner, repo_name = get_repo_owner_and_name()
|
||||
self.assertEqual(owner, "owner")
|
||||
self.assertEqual(repo_name, "repo")
|
||||
assert owner == "owner"
|
||||
assert repo_name == "repo"
|
||||
|
||||
@patch("codeflash.code_utils.env_utils.git.Repo")
|
||||
def test_check_running_in_git_repo_in_git_repo(self, mock_repo):
|
||||
mock_repo.return_value.git_dir = "/path/to/repo/.git"
|
||||
assert check_running_in_git_repo("/path/to/repo")
|
||||
|
||||
@patch("codeflash.code_utils.env_utils.git.Repo")
|
||||
@patch("codeflash.code_utils.env_utils.sys.__stdin__.isatty", return_value=True)
|
||||
@patch("codeflash.code_utils.env_utils.confirm_proceeding_with_no_git_repo", return_value=True)
|
||||
def test_check_running_in_git_repo_not_in_git_repo_interactive(
|
||||
self,
|
||||
mock_confirm,
|
||||
mock_isatty,
|
||||
mock_repo,
|
||||
):
|
||||
mock_repo.side_effect = git.exc.InvalidGitRepositoryError
|
||||
assert check_running_in_git_repo("/path/to/non-repo")
|
||||
|
||||
@patch("codeflash.code_utils.env_utils.git.Repo")
|
||||
@patch("codeflash.code_utils.env_utils.sys.__stdin__.isatty", return_value=False)
|
||||
def test_check_running_in_git_repo_not_in_git_repo_non_interactive(self, mock_isatty, mock_repo):
|
||||
mock_repo.side_effect = git.exc.InvalidGitRepositoryError
|
||||
assert check_running_in_git_repo("/path/to/non-repo")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
Loading…
Reference in a new issue