mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
Make cli and django all RUFF and TUFF
This commit is contained in:
parent
87bbc8c238
commit
cf88e2b7d0
30 changed files with 243 additions and 257 deletions
|
|
@ -19,10 +19,9 @@ else:
|
|||
|
||||
|
||||
def make_cfapi_request(
|
||||
endpoint: str, method: str, payload: Optional[Dict[str, Any]] = None
|
||||
endpoint: str, method: str, payload: Optional[Dict[str, Any]] = None,
|
||||
) -> requests.Response:
|
||||
"""
|
||||
Make an HTTP request using the specified method, URL, headers, and JSON payload.
|
||||
"""Make an HTTP request using the specified method, URL, headers, and JSON payload.
|
||||
:param endpoint: The endpoint URL to send the request to.
|
||||
:param method: The HTTP method to use ('GET', 'POST', etc.).
|
||||
:param payload: Optional JSON payload to include in the POST request body.
|
||||
|
|
@ -41,8 +40,7 @@ def make_cfapi_request(
|
|||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_user_id() -> Optional[str]:
|
||||
"""
|
||||
Retrieve the user's userid by making a request to the /cfapi/cli-get-user endpoint.
|
||||
"""Retrieve the user's userid by making a request to the /cfapi/cli-get-user endpoint.
|
||||
:return: The userid or None if the request fails.
|
||||
"""
|
||||
response = make_cfapi_request(endpoint="/cli-get-user", method="GET")
|
||||
|
|
@ -50,7 +48,7 @@ def get_user_id() -> Optional[str]:
|
|||
return response.text
|
||||
else:
|
||||
logging.error(
|
||||
f"Failed to look up your userid; is your CF API key valid? ({response.reason})"
|
||||
f"Failed to look up your userid; is your CF API key valid? ({response.reason})",
|
||||
)
|
||||
return None
|
||||
|
||||
|
|
@ -64,8 +62,7 @@ def suggest_changes(
|
|||
existing_tests: str,
|
||||
generated_tests: str,
|
||||
) -> Response:
|
||||
"""
|
||||
Suggest changes to a pull request.
|
||||
"""Suggest changes to a pull request.
|
||||
Will make a review suggestion when possible;
|
||||
or create a new dependent pull request with the suggested changes.
|
||||
:param owner: The owner of the repository.
|
||||
|
|
@ -98,8 +95,7 @@ def create_pr(
|
|||
existing_tests: str,
|
||||
generated_tests: str,
|
||||
) -> Response:
|
||||
"""
|
||||
Create a pull request, targeting the specified branch. (usually 'main')
|
||||
"""Create a pull request, targeting the specified branch. (usually 'main')
|
||||
:param owner: The owner of the repository.
|
||||
:param repo: The name of the repository.
|
||||
:param base_branch: The base branch to target.
|
||||
|
|
@ -122,8 +118,7 @@ def create_pr(
|
|||
|
||||
|
||||
def check_github_app_installed_on_repo(owner: str, repo: str) -> Response:
|
||||
"""
|
||||
Check if the Codeflash GitHub App is installed on the specified repository.
|
||||
"""Check if the Codeflash GitHub App is installed on the specified repository.
|
||||
:param owner: The owner of the repository.
|
||||
:param repo: The name of the repository.
|
||||
:return: The response object.
|
||||
|
|
|
|||
|
|
@ -1,18 +1,18 @@
|
|||
import logging
|
||||
import os
|
||||
from argparse import Namespace, ArgumentParser, SUPPRESS
|
||||
from argparse import SUPPRESS, ArgumentParser, Namespace
|
||||
|
||||
import git
|
||||
|
||||
from codeflash.api.cfapi import check_github_app_installed_on_repo
|
||||
from codeflash.cli_cmds import logging_config
|
||||
from codeflash.cli_cmds.cmd_init import init_codeflash, apologize_and_exit
|
||||
from codeflash.cli_cmds.cmd_init import apologize_and_exit, init_codeflash
|
||||
from codeflash.code_utils import env_utils
|
||||
from codeflash.code_utils.compat import LF
|
||||
from codeflash.code_utils.config_parser import parse_config_file
|
||||
from codeflash.code_utils.git_utils import (
|
||||
get_repo_owner_and_name,
|
||||
get_github_secrets_page_url,
|
||||
get_repo_owner_and_name,
|
||||
)
|
||||
from codeflash.version import __version__ as version
|
||||
|
||||
|
|
@ -87,7 +87,7 @@ def process_cmd_args(args: Namespace) -> Namespace:
|
|||
try:
|
||||
pyproject_config, pyproject_file_path = parse_config_file(args.config_file)
|
||||
except ValueError as e:
|
||||
logging.error(e.args[0])
|
||||
logging.exception(e.args[0])
|
||||
exit(1)
|
||||
supported_keys = [
|
||||
"module_root",
|
||||
|
|
@ -108,10 +108,10 @@ def process_cmd_args(args: Namespace) -> Namespace:
|
|||
) or not hasattr(args, key.replace("-", "_")):
|
||||
setattr(args, key.replace("-", "_"), pyproject_config[key])
|
||||
assert args.module_root is not None and os.path.isdir(
|
||||
args.module_root
|
||||
args.module_root,
|
||||
), f"--module-root {args.module_root} must be a valid directory"
|
||||
assert args.tests_root is not None and os.path.isdir(
|
||||
args.tests_root
|
||||
args.tests_root,
|
||||
), f"--tests-root {args.tests_root} must be a valid directory"
|
||||
assert not (
|
||||
env_utils.get_pr_number() is not None and not env_utils.ensure_codeflash_api_key()
|
||||
|
|
@ -127,7 +127,7 @@ def process_cmd_args(args: Namespace) -> Namespace:
|
|||
normalized_ignore_paths = []
|
||||
for path in args.ignore_paths:
|
||||
assert os.path.exists(
|
||||
path
|
||||
path,
|
||||
), f"ignore-paths config must be a valid path. Path {path} does not exist"
|
||||
normalized_ignore_paths.append(os.path.realpath(path))
|
||||
args.ignore_paths = normalized_ignore_paths
|
||||
|
|
@ -150,9 +150,9 @@ def handle_optimize_all_arg_parsing(args: Namespace) -> Namespace:
|
|||
try:
|
||||
git_repo = git.Repo(search_parent_directories=True)
|
||||
except git.exc.InvalidGitRepositoryError:
|
||||
logging.error(
|
||||
logging.exception(
|
||||
"I couldn't find a git repository in the current directory. "
|
||||
"I need a git repository to run --all and open PRs for optimizations. Exiting..."
|
||||
"I need a git repository to run --all and open PRs for optimizations. Exiting...",
|
||||
)
|
||||
apologize_and_exit()
|
||||
owner, repo = get_repo_owner_and_name(git_repo)
|
||||
|
|
@ -161,16 +161,16 @@ def handle_optimize_all_arg_parsing(args: Namespace) -> Namespace:
|
|||
if not response.ok or response.text != "true":
|
||||
logging.error(f"Error: {response.text}")
|
||||
raise Exception
|
||||
except Exception as e:
|
||||
logging.error(
|
||||
except Exception:
|
||||
logging.exception(
|
||||
f"Could not find the Codeflash GitHub App installed on the repository {owner}/{repo} or the GitHub"
|
||||
f" account linked to your CODEFLASH_API_KEY does not have access to the repository {owner}/{repo}.{LF}"
|
||||
"Please install the Codeflash GitHub App on your repository to use --all. You can install it by going to "
|
||||
f"https://github.com/settings/installations/{LF}"
|
||||
f"https://github.com/settings/installations/{LF}",
|
||||
)
|
||||
apologize_and_exit()
|
||||
if not hasattr(args, "all"):
|
||||
setattr(args, "all", None)
|
||||
args.all = None
|
||||
elif args.all == "":
|
||||
# The default behavior of --all is to optimize everything in args.module_root
|
||||
args.all = args.module_root
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ import re
|
|||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from typing import Optional, NoReturn
|
||||
from typing import NoReturn, Optional
|
||||
|
||||
import click
|
||||
import inquirer
|
||||
|
|
@ -20,8 +20,8 @@ from codeflash.code_utils.env_utils import (
|
|||
)
|
||||
from codeflash.code_utils.git_utils import get_github_secrets_page_url
|
||||
from codeflash.code_utils.shell_utils import (
|
||||
save_api_key_to_rc,
|
||||
get_shell_rc_path,
|
||||
save_api_key_to_rc,
|
||||
)
|
||||
from codeflash.telemetry.posthog import ph
|
||||
from codeflash.version import __version__ as version
|
||||
|
|
@ -69,11 +69,11 @@ def init_codeflash() -> None:
|
|||
f" codeflash --all to optimize all functions in all files in the module you selected ({setup_info.module_root}){LF}"
|
||||
# f" codeflash --pr <pr-number> to optimize a PR{LF}"
|
||||
f"-or-{LF}"
|
||||
f" codeflash --help to see all options{LF}"
|
||||
f" codeflash --help to see all options{LF}",
|
||||
)
|
||||
if did_add_new_key:
|
||||
click.echo(
|
||||
"🐚 Don't forget to restart your shell to load the CODEFLASH_API_KEY environment variable!"
|
||||
"🐚 Don't forget to restart your shell to load the CODEFLASH_API_KEY environment variable!",
|
||||
)
|
||||
click.echo("Or run the following command to reload:")
|
||||
if os.name == "nt":
|
||||
|
|
@ -101,7 +101,7 @@ def collect_setup_info() -> SetupInfo:
|
|||
# Check if the cwd is writable
|
||||
if not os.access(curdir, os.W_OK):
|
||||
click.echo(
|
||||
f"❌ The current directory isn't writable, please check your folder permissions and try again.{LF}"
|
||||
f"❌ The current directory isn't writable, please check your folder permissions and try again.{LF}",
|
||||
)
|
||||
click.echo("It's likely you don't have write permissions for this folder.")
|
||||
sys.exit(1)
|
||||
|
|
@ -132,7 +132,7 @@ def collect_setup_info() -> SetupInfo:
|
|||
module_subdir_options = valid_module_subdirs + [curdir_option]
|
||||
|
||||
module_root_answer = inquirer.list_input(
|
||||
message=f"Which Python module do you want me to optimize going forward? (Usually the top-most directory with all of your Python source code)",
|
||||
message="Which Python module do you want me to optimize going forward? (Usually the top-most directory with all of your Python source code)",
|
||||
choices=module_subdir_options,
|
||||
default=(
|
||||
project_name if project_name in module_subdir_options else module_subdir_options[0]
|
||||
|
|
@ -172,7 +172,7 @@ def collect_setup_info() -> SetupInfo:
|
|||
exists=True,
|
||||
normalize_to_absolute_path=True,
|
||||
),
|
||||
]
|
||||
],
|
||||
)
|
||||
tests_root = (
|
||||
custom_tests_root_answer["path"] if custom_tests_root_answer else apologize_and_exit()
|
||||
|
|
@ -224,7 +224,7 @@ def detect_test_framework(curdir, tests_root) -> Optional[str]:
|
|||
for pytest_file in pytest_files:
|
||||
file_path = os.path.join(curdir, pytest_file)
|
||||
if os.path.exists(file_path):
|
||||
with open(file_path, "r", encoding="utf8") as file:
|
||||
with open(file_path, encoding="utf8") as file:
|
||||
contents = file.read()
|
||||
if pytest_config_patterns[pytest_file] in contents:
|
||||
test_framework = "pytest"
|
||||
|
|
@ -234,7 +234,7 @@ def detect_test_framework(curdir, tests_root) -> Optional[str]:
|
|||
# Check if any python files contain a class that inherits from unittest.TestCase
|
||||
for filename in os.listdir(tests_root):
|
||||
if filename.endswith(".py"):
|
||||
with open(os.path.join(tests_root, filename), "r", encoding="utf8") as file:
|
||||
with open(os.path.join(tests_root, filename), encoding="utf8") as file:
|
||||
contents = file.read()
|
||||
try:
|
||||
node = ast.parse(contents)
|
||||
|
|
@ -265,38 +265,38 @@ def check_for_toml_or_setup_file() -> Optional[str]:
|
|||
project_name = None
|
||||
if os.path.exists(pyproject_toml_path):
|
||||
try:
|
||||
with open(pyproject_toml_path, "r", encoding="utf8") as f:
|
||||
with open(pyproject_toml_path, encoding="utf8") as f:
|
||||
pyproject_toml_content = f.read()
|
||||
project_name = tomlkit.parse(pyproject_toml_content)["tool"]["poetry"]["name"]
|
||||
click.echo(f"✅ I found a pyproject.toml for your project {project_name}.")
|
||||
ph("cli-pyproject-toml-found-name")
|
||||
except Exception as e:
|
||||
click.echo(f"✅ I found a pyproject.toml for your project.")
|
||||
except Exception:
|
||||
click.echo("✅ I found a pyproject.toml for your project.")
|
||||
ph("cli-pyproject-toml-found")
|
||||
else:
|
||||
if os.path.exists(setup_py_path):
|
||||
with open(setup_py_path, "r", encoding="utf8") as f:
|
||||
with open(setup_py_path, encoding="utf8") as f:
|
||||
setup_py_content = f.read()
|
||||
project_name_match = re.search(
|
||||
r"setup\s*\([^)]*?name\s*=\s*['\"](.*?)['\"]", setup_py_content, re.DOTALL
|
||||
r"setup\s*\([^)]*?name\s*=\s*['\"](.*?)['\"]", setup_py_content, re.DOTALL,
|
||||
)
|
||||
if project_name_match:
|
||||
project_name = project_name_match.group(1)
|
||||
click.echo(f"✅ Found setup.py for your project {project_name}")
|
||||
ph("cli-setup-py-found-name")
|
||||
else:
|
||||
click.echo(f"✅ Found setup.py.")
|
||||
click.echo("✅ Found setup.py.")
|
||||
ph("cli-setup-py-found")
|
||||
click.echo(
|
||||
f"💡 I couldn't find a pyproject.toml in the current directory ({curdir}).{LF}"
|
||||
f"(make sure you're running `codeflash init` from your project's root directory!){LF}"
|
||||
f"I need this file to store my configuration settings."
|
||||
f"I need this file to store my configuration settings.",
|
||||
)
|
||||
ph("cli-no-pyproject-toml-or-setup-py")
|
||||
|
||||
# Create a pyproject.toml file because it doesn't exist
|
||||
create_toml = inquirer.confirm(
|
||||
f"Do you want me to create a pyproject.toml file in the current directory?",
|
||||
"Do you want me to create a pyproject.toml file in the current directory?",
|
||||
default=True,
|
||||
show_default=False,
|
||||
)
|
||||
|
|
@ -314,9 +314,9 @@ def check_for_toml_or_setup_file() -> Optional[str]:
|
|||
click.echo(f"✅ Created a pyproject.toml file at {pyproject_toml_path}")
|
||||
click.pause()
|
||||
ph("cli-created-pyproject-toml")
|
||||
except IOError as e:
|
||||
except OSError:
|
||||
click.echo(
|
||||
"❌ Failed to create pyproject.toml. Please check your disk permissions and available space."
|
||||
"❌ Failed to create pyproject.toml. Please check your disk permissions and available space.",
|
||||
)
|
||||
apologize_and_exit()
|
||||
else:
|
||||
|
|
@ -328,7 +328,7 @@ def check_for_toml_or_setup_file() -> Optional[str]:
|
|||
|
||||
def apologize_and_exit() -> NoReturn:
|
||||
click.echo(
|
||||
"💡 If you're having trouble, see https://app.codeflash.ai/app/getting-started for further help getting started with Codeflash!"
|
||||
"💡 If you're having trouble, see https://app.codeflash.ai/app/getting-started for further help getting started with Codeflash!",
|
||||
)
|
||||
click.echo("👋 Exiting...")
|
||||
sys.exit(1)
|
||||
|
|
@ -363,10 +363,10 @@ def prompt_github_action(setup_info: SetupInfo) -> None:
|
|||
python_version_string = f" {py_version.major}.{py_version.minor}"
|
||||
|
||||
optimize_yml_content = read_text(
|
||||
"codeflash.cli_cmds.workflows", "codeflash-optimize.yaml"
|
||||
"codeflash.cli_cmds.workflows", "codeflash-optimize.yaml",
|
||||
)
|
||||
optimize_yml_content = optimize_yml_content.replace(
|
||||
" {{ python_version }}", python_version_string
|
||||
" {{ python_version }}", python_version_string,
|
||||
)
|
||||
with open(optimize_yaml_path, "w", encoding="utf8") as optimize_yml_file:
|
||||
optimize_yml_file.write(optimize_yml_content)
|
||||
|
|
@ -400,12 +400,12 @@ def prompt_github_action(setup_info: SetupInfo) -> None:
|
|||
click.launch(optimize_yaml_path)
|
||||
click.echo(
|
||||
"📝 I opened the workflow file in your editor! You'll need to edit the steps that install the right Python "
|
||||
+ f"version and any project dependencies. See the comments in the file for more details.{LF}"
|
||||
+ f"version and any project dependencies. See the comments in the file for more details.{LF}",
|
||||
)
|
||||
click.pause()
|
||||
click.echo()
|
||||
click.echo(
|
||||
f"🚀 Codeflash is now configured to automatically optimize new Github PRs!{LF}"
|
||||
f"🚀 Codeflash is now configured to automatically optimize new Github PRs!{LF}",
|
||||
)
|
||||
ph("cli-github-workflow-created")
|
||||
else:
|
||||
|
|
@ -417,18 +417,18 @@ def prompt_github_action(setup_info: SetupInfo) -> None:
|
|||
def configure_pyproject_toml(setup_info: SetupInfo) -> None:
|
||||
toml_path = os.path.join(os.getcwd(), "pyproject.toml")
|
||||
try:
|
||||
with open(toml_path, "r", encoding="utf8") as pyproject_file:
|
||||
with open(toml_path, encoding="utf8") as pyproject_file:
|
||||
pyproject_data = tomlkit.parse(pyproject_file.read())
|
||||
except FileNotFoundError:
|
||||
click.echo(
|
||||
f"I couldn't find a pyproject.toml in the current directory.{LF}"
|
||||
f"Please create a new empty pyproject.toml file here, OR if you use poetry then run `poetry init`, OR run `codeflash init` again from a directory with an existing pyproject.toml file."
|
||||
f"Please create a new empty pyproject.toml file here, OR if you use poetry then run `poetry init`, OR run `codeflash init` again from a directory with an existing pyproject.toml file.",
|
||||
)
|
||||
apologize_and_exit()
|
||||
|
||||
codeflash_section = tomlkit.table()
|
||||
codeflash_section.add(
|
||||
tomlkit.comment("All paths are relative to this pyproject.toml's directory.")
|
||||
tomlkit.comment("All paths are relative to this pyproject.toml's directory."),
|
||||
)
|
||||
codeflash_section["module-root"] = setup_info.module_root
|
||||
codeflash_section["tests-root"] = setup_info.tests_root
|
||||
|
|
@ -440,7 +440,7 @@ def configure_pyproject_toml(setup_info: SetupInfo) -> None:
|
|||
tool_section["codeflash"] = codeflash_section
|
||||
pyproject_data["tool"] = tool_section
|
||||
|
||||
click.echo(f"Writing Codeflash configuration ...\r", nl=False)
|
||||
click.echo("Writing Codeflash configuration ...\r", nl=False)
|
||||
with open(toml_path, "w", encoding="utf8") as pyproject_file:
|
||||
pyproject_file.write(tomlkit.dumps(pyproject_data))
|
||||
click.echo(f"✅ Added Codeflash configuration to {toml_path}")
|
||||
|
|
@ -466,7 +466,7 @@ class CFAPIKeyType(click.ParamType):
|
|||
def prompt_api_key() -> bool:
|
||||
try:
|
||||
existing_api_key = get_codeflash_api_key()
|
||||
except EnvironmentError:
|
||||
except OSError:
|
||||
existing_api_key = None
|
||||
if existing_api_key:
|
||||
display_key = f"{existing_api_key[:3]}****{existing_api_key[-4:]}"
|
||||
|
|
@ -499,14 +499,13 @@ def enter_api_key_and_save_to_rc() -> None:
|
|||
).strip()
|
||||
if api_key:
|
||||
break
|
||||
else:
|
||||
if not browser_launched:
|
||||
click.echo(
|
||||
f"Opening your Codeflash API key page. Grab a key from there!{LF}"
|
||||
"You can also open this link manually: https://app.codeflash.ai/app/apikeys"
|
||||
)
|
||||
click.launch("https://app.codeflash.ai/app/apikeys")
|
||||
browser_launched = True # This does not work on remote consoles
|
||||
elif not browser_launched:
|
||||
click.echo(
|
||||
f"Opening your Codeflash API key page. Grab a key from there!{LF}"
|
||||
"You can also open this link manually: https://app.codeflash.ai/app/apikeys",
|
||||
)
|
||||
click.launch("https://app.codeflash.ai/app/apikeys")
|
||||
browser_launched = True # This does not work on remote consoles
|
||||
shell_rc_path = get_shell_rc_path()
|
||||
if not shell_rc_path.exists() and os.name == "nt":
|
||||
# On Windows, create a batch file in the user's home directory (not auto-run, just used to store api key)
|
||||
|
|
@ -581,5 +580,5 @@ def run_end_to_end_test(setup_info: SetupInfo) -> None:
|
|||
click.echo(f"{LF}✅ End-to-end test passed. Codeflash has been correctly set up!")
|
||||
else:
|
||||
click.echo(
|
||||
f"{LF}❌ End-to-end test failed. Please check the logs above, and take a look at https://app.codeflash.ai/app/getting-started for help and troubleshooting."
|
||||
f"{LF}❌ End-to-end test failed. Please check the logs above, and take a look at https://app.codeflash.ai/app/getting-started for help and troubleshooting.",
|
||||
)
|
||||
|
|
|
|||
|
|
@ -11,25 +11,25 @@ def find_pyproject_toml(config_file=None):
|
|||
if config_file is not None:
|
||||
if not config_file.lower().endswith(".toml"):
|
||||
raise ValueError(
|
||||
f"Config file {config_file} is not a valid toml file. Please recheck the path to pyproject.toml"
|
||||
f"Config file {config_file} is not a valid toml file. Please recheck the path to pyproject.toml",
|
||||
)
|
||||
if not os.path.exists(config_file):
|
||||
raise ValueError(
|
||||
f"Config file {config_file} does not exist. Please recheck the path to pyproject.toml"
|
||||
f"Config file {config_file} does not exist. Please recheck the path to pyproject.toml",
|
||||
)
|
||||
return config_file
|
||||
|
||||
else:
|
||||
dir_path = os.getcwd()
|
||||
|
||||
while not os.path.dirname(dir_path) == dir_path:
|
||||
while os.path.dirname(dir_path) != dir_path:
|
||||
config_file = os.path.join(dir_path, "pyproject.toml")
|
||||
if os.path.exists(config_file):
|
||||
return config_file
|
||||
# Search for pyproject.toml in the parent directories
|
||||
dir_path = os.path.dirname(dir_path)
|
||||
raise ValueError(
|
||||
f"Could not find pyproject.toml in the current directory {os.getcwd()} or any of the parent directories. Please create it by running `poetry init`, or pass the path to pyproject.toml with the --config-file argument."
|
||||
f"Could not find pyproject.toml in the current directory {os.getcwd()} or any of the parent directories. Please create it by running `poetry init`, or pass the path to pyproject.toml with the --config-file argument.",
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -40,7 +40,7 @@ def parse_config_file(config_file_path=None):
|
|||
data = tomlkit.parse(f.read())
|
||||
except tomlkit.exceptions.ParseError as e:
|
||||
raise ValueError(
|
||||
f"Error while parsing the config file {config_file_path}. Please recheck the file for syntax errors. Error: {e}"
|
||||
f"Error while parsing the config file {config_file_path}. Please recheck the file for syntax errors. Error: {e}",
|
||||
)
|
||||
|
||||
try:
|
||||
|
|
@ -50,7 +50,7 @@ def parse_config_file(config_file_path=None):
|
|||
except tomlkit.exceptions.NonExistentKey:
|
||||
raise ValueError(
|
||||
f"Could not find the 'codeflash' block in the config file {config_file_path}. "
|
||||
f"Please run 'codeflash init' to create the config file."
|
||||
f"Please run 'codeflash init' to create the config file.",
|
||||
)
|
||||
assert isinstance(config, dict)
|
||||
|
||||
|
|
@ -88,7 +88,7 @@ def parse_config_file(config_file_path=None):
|
|||
for key in path_keys:
|
||||
if key in config:
|
||||
config[key] = os.path.realpath(
|
||||
os.path.join(os.path.dirname(config_file_path), config[key])
|
||||
os.path.join(os.path.dirname(config_file_path), config[key]),
|
||||
)
|
||||
|
||||
for key in path_list_keys:
|
||||
|
|
|
|||
|
|
@ -14,16 +14,16 @@ from codeflash.code_utils.shell_utils import read_api_key_from_shell_config
|
|||
def get_codeflash_api_key() -> Optional[str]:
|
||||
api_key = os.environ.get("CODEFLASH_API_KEY") or read_api_key_from_shell_config()
|
||||
if not api_key:
|
||||
raise EnvironmentError(
|
||||
raise OSError(
|
||||
"I didn't find a Codeflash API key in your environment.\n"
|
||||
+ "You can generate one at https://app.codeflash.ai/app/apikeys,\n"
|
||||
+ "then set it as a CODEFLASH_API_KEY environment variable."
|
||||
+ "then set it as a CODEFLASH_API_KEY environment variable.",
|
||||
)
|
||||
if not api_key.startswith("cf-"):
|
||||
raise EnvironmentError(
|
||||
raise OSError(
|
||||
f"Your Codeflash API key seems to be invalid. It should start with a 'cf-' prefix; I found '{api_key}' instead.\n"
|
||||
+ "You can generate one at https://app.codeflash.ai/app/apikeys,\n"
|
||||
+ "then set it as a CODEFLASH_API_KEY environment variable."
|
||||
+ "then set it as a CODEFLASH_API_KEY environment variable.",
|
||||
)
|
||||
return api_key
|
||||
|
||||
|
|
@ -31,11 +31,11 @@ def get_codeflash_api_key() -> Optional[str]:
|
|||
def ensure_codeflash_api_key() -> bool:
|
||||
try:
|
||||
get_codeflash_api_key()
|
||||
except EnvironmentError:
|
||||
logging.error(
|
||||
except OSError:
|
||||
logging.exception(
|
||||
"Codeflash API key not found in your environment.\n"
|
||||
+ "You can generate one at https://app.codeflash.ai/app/apikeys,\n"
|
||||
+ "then set it as a CODEFLASH_API_KEY environment variable."
|
||||
+ "then set it as a CODEFLASH_API_KEY environment variable.",
|
||||
)
|
||||
return False
|
||||
return True
|
||||
|
|
@ -58,8 +58,8 @@ def get_pr_number() -> Optional[int]:
|
|||
|
||||
def ensure_pr_number() -> bool:
|
||||
if not get_pr_number():
|
||||
raise EnvironmentError(
|
||||
"CODEFLASH_PR_NUMBER not found in environment variables; make sure the Github Action is setting this so Codeflash can comment on the right PR"
|
||||
raise OSError(
|
||||
"CODEFLASH_PR_NUMBER not found in environment variables; make sure the Github Action is setting this so Codeflash can comment on the right PR",
|
||||
)
|
||||
return True
|
||||
|
||||
|
|
|
|||
|
|
@ -1,18 +1,19 @@
|
|||
import logging
|
||||
import os.path
|
||||
import subprocess
|
||||
|
||||
import isort
|
||||
|
||||
|
||||
def format_code(
|
||||
formatter_cmd: str, imports_sort_cmd: str, should_sort_imports: bool, path: str
|
||||
formatter_cmd: str, imports_sort_cmd: str, should_sort_imports: bool, path: str,
|
||||
) -> str:
|
||||
# TODO: Only allow a particular whitelist of formatters here to prevent arbitrary code execution
|
||||
if formatter_cmd.lower() == "disabled":
|
||||
if should_sort_imports:
|
||||
return sort_imports(imports_sort_cmd, should_sort_imports, path)
|
||||
|
||||
with open(path, "r", encoding="utf8") as f:
|
||||
with open(path, encoding="utf8") as f:
|
||||
new_code = f.read()
|
||||
return new_code
|
||||
|
||||
|
|
@ -22,7 +23,7 @@ def format_code(
|
|||
# the main problem is custom config parsing https://github.com/psf/black/issues/779
|
||||
assert os.path.exists(path), f"File {path} does not exist. Cannot format the file. Exiting..."
|
||||
result = subprocess.run(
|
||||
formatter_cmd_list + [path], stdout=subprocess.PIPE, stderr=subprocess.PIPE
|
||||
formatter_cmd_list + [path], stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=False,
|
||||
)
|
||||
if result.returncode == 0:
|
||||
logging.info("OK")
|
||||
|
|
@ -32,7 +33,7 @@ def format_code(
|
|||
if should_sort_imports:
|
||||
return sort_imports(imports_sort_cmd, should_sort_imports, path)
|
||||
|
||||
with open(path, "r", encoding="utf8") as f:
|
||||
with open(path, encoding="utf8") as f:
|
||||
new_code = f.read()
|
||||
|
||||
return new_code
|
||||
|
|
@ -40,7 +41,7 @@ def format_code(
|
|||
|
||||
def sort_imports(imports_sort_cmd: str, should_sort_imports: bool, path: str) -> str:
|
||||
if imports_sort_cmd.lower() == "disabled":
|
||||
with open(path, "r", encoding="utf8") as f:
|
||||
with open(path, encoding="utf8") as f:
|
||||
code = f.read()
|
||||
return code
|
||||
|
||||
|
|
@ -48,11 +49,11 @@ def sort_imports(imports_sort_cmd: str, should_sort_imports: bool, path: str) ->
|
|||
# Deduplicate and sort imports
|
||||
isort.file(path)
|
||||
|
||||
with open(path, "r", encoding="utf8") as f:
|
||||
with open(path, encoding="utf8") as f:
|
||||
new_code = f.read()
|
||||
return new_code
|
||||
else:
|
||||
# Return original code
|
||||
with open(path, "r", encoding="utf8") as f:
|
||||
with open(path, encoding="utf8") as f:
|
||||
code = f.read()
|
||||
return code
|
||||
|
|
|
|||
|
|
@ -9,13 +9,13 @@ from unidiff import PatchSet
|
|||
|
||||
|
||||
def get_git_diff(
|
||||
repo_directory: str = os.getcwd(), uncommitted_changes: bool = False
|
||||
repo_directory: str = os.getcwd(), uncommitted_changes: bool = False,
|
||||
) -> dict[str, list[int]]:
|
||||
repository = git.Repo(repo_directory, search_parent_directories=True)
|
||||
commit = repository.head.commit
|
||||
if uncommitted_changes:
|
||||
uni_diff_text = repository.git.diff(
|
||||
None, "HEAD", ignore_blank_lines=True, ignore_space_at_eol=True
|
||||
None, "HEAD", ignore_blank_lines=True, ignore_space_at_eol=True,
|
||||
)
|
||||
else:
|
||||
uni_diff_text = repository.git.diff(
|
||||
|
|
@ -54,8 +54,7 @@ def get_git_diff(
|
|||
|
||||
|
||||
def get_current_branch(repo: Optional[Repo] = None) -> str:
|
||||
"""
|
||||
Returns the name of the current branch in the given repository.
|
||||
"""Returns the name of the current branch in the given repository.
|
||||
|
||||
:param repo: An optional Repo object. If not provided, the function will
|
||||
search for a repository in the current and parent directories.
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import ast
|
|||
from _ast import ClassDef
|
||||
from typing import Any, Optional, Tuple
|
||||
|
||||
from codeflash.code_utils.code_utils import module_name_from_file_path, get_run_tmp_file
|
||||
from codeflash.code_utils.code_utils import get_run_tmp_file, module_name_from_file_path
|
||||
|
||||
|
||||
class ReplaceCallNodeWithName(ast.NodeTransformer):
|
||||
|
|
@ -26,7 +26,7 @@ class InjectPerfOnly(ast.NodeTransformer):
|
|||
self.module_path = module_path
|
||||
|
||||
def update_line_node(
|
||||
self, test_node, node_name, index: str, test_class_name: Optional[str] = None
|
||||
self, test_node, node_name, index: str, test_class_name: Optional[str] = None,
|
||||
):
|
||||
call_node = None
|
||||
for node in ast.walk(test_node):
|
||||
|
|
@ -122,17 +122,17 @@ class InjectPerfOnly(ast.NodeTransformer):
|
|||
ast.JoinedStr(
|
||||
values=[
|
||||
ast.Constant(
|
||||
value=f"{get_run_tmp_file('test_return_values_')}"
|
||||
value=f"{get_run_tmp_file('test_return_values_')}",
|
||||
),
|
||||
ast.FormattedValue(
|
||||
value=ast.Name(
|
||||
id="codeflash_iteration", ctx=ast.Load()
|
||||
id="codeflash_iteration", ctx=ast.Load(),
|
||||
),
|
||||
conversion=-1,
|
||||
),
|
||||
ast.Constant(value=".sqlite"),
|
||||
]
|
||||
)
|
||||
],
|
||||
),
|
||||
],
|
||||
keywords=[],
|
||||
),
|
||||
|
|
@ -164,8 +164,8 @@ class InjectPerfOnly(ast.NodeTransformer):
|
|||
ast.Constant(
|
||||
value="CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT,"
|
||||
" test_class_name TEXT, test_function_name TEXT, function_getting_tested TEXT,"
|
||||
" iteration_id TEXT, runtime INTEGER, return_value BLOB)"
|
||||
)
|
||||
" iteration_id TEXT, runtime INTEGER, return_value BLOB)",
|
||||
),
|
||||
],
|
||||
keywords=[],
|
||||
),
|
||||
|
|
@ -184,8 +184,8 @@ class InjectPerfOnly(ast.NodeTransformer):
|
|||
),
|
||||
args=[],
|
||||
keywords=[],
|
||||
)
|
||||
)
|
||||
),
|
||||
),
|
||||
]
|
||||
)
|
||||
i = len(node.body) - 1
|
||||
|
|
@ -200,15 +200,14 @@ class InjectPerfOnly(ast.NodeTransformer):
|
|||
for internal_node in ast.walk(compound_line_node):
|
||||
if self.is_target_function_line(internal_node):
|
||||
line_node.body[j : j + 1] = self.update_line_node(
|
||||
internal_node, node.name, str(i) + "_" + str(j), test_class_name
|
||||
internal_node, node.name, str(i) + "_" + str(j), test_class_name,
|
||||
)
|
||||
break
|
||||
j -= 1
|
||||
else:
|
||||
if self.is_target_function_line(line_node):
|
||||
node.body[i : i + 1] = self.update_line_node(
|
||||
line_node, node.name, str(i), test_class_name
|
||||
)
|
||||
elif self.is_target_function_line(line_node):
|
||||
node.body[i : i + 1] = self.update_line_node(
|
||||
line_node, node.name, str(i), test_class_name,
|
||||
)
|
||||
i -= 1
|
||||
return node
|
||||
|
||||
|
|
@ -216,7 +215,8 @@ class InjectPerfOnly(ast.NodeTransformer):
|
|||
class FunctionImportedAsVisitor(ast.NodeVisitor):
|
||||
"""This checks if a function has been imported as an alias. We only care about the alias then.
|
||||
from numpy import array as np_array
|
||||
np_array is what we want"""
|
||||
np_array is what we want
|
||||
"""
|
||||
|
||||
def __init__(self, original_function_name):
|
||||
self.original_function_name = original_function_name
|
||||
|
|
@ -226,12 +226,12 @@ class FunctionImportedAsVisitor(ast.NodeVisitor):
|
|||
def visit_ImportFrom(self, node: ast.ImportFrom):
|
||||
for alias in node.names:
|
||||
if alias.name == self.original_function_name:
|
||||
if hasattr(alias, "asname") and not alias.asname is None:
|
||||
if hasattr(alias, "asname") and alias.asname is not None:
|
||||
self.imported_as_function_name = alias.asname
|
||||
|
||||
|
||||
def inject_profiling_into_existing_test(test_path, function_name, root_path) -> Tuple[bool, str]:
|
||||
with open(test_path, "r", encoding="utf8") as f:
|
||||
with open(test_path, encoding="utf8") as f:
|
||||
test_code = f.read()
|
||||
try:
|
||||
tree = ast.parse(test_code)
|
||||
|
|
@ -284,21 +284,21 @@ def create_wrapper_function(function_name, module_path):
|
|||
value=ast.JoinedStr(
|
||||
values=[
|
||||
ast.FormattedValue(
|
||||
value=ast.Name(id="test_module_name", ctx=ast.Load()), conversion=-1
|
||||
value=ast.Name(id="test_module_name", ctx=ast.Load()), conversion=-1,
|
||||
),
|
||||
ast.Constant(value=":"),
|
||||
ast.FormattedValue(
|
||||
value=ast.Name(id="test_class_name", ctx=ast.Load()), conversion=-1
|
||||
value=ast.Name(id="test_class_name", ctx=ast.Load()), conversion=-1,
|
||||
),
|
||||
ast.Constant(value=":"),
|
||||
ast.FormattedValue(
|
||||
value=ast.Name(id="test_name", ctx=ast.Load()), conversion=-1
|
||||
value=ast.Name(id="test_name", ctx=ast.Load()), conversion=-1,
|
||||
),
|
||||
ast.Constant(value=":"),
|
||||
ast.FormattedValue(
|
||||
value=ast.Name(id="line_id", ctx=ast.Load()), conversion=-1
|
||||
value=ast.Name(id="line_id", ctx=ast.Load()), conversion=-1,
|
||||
),
|
||||
]
|
||||
],
|
||||
),
|
||||
lineno=lineno + 1,
|
||||
),
|
||||
|
|
@ -321,11 +321,11 @@ def create_wrapper_function(function_name, module_path):
|
|||
value=ast.Name(id="codeflash_wrap", ctx=ast.Load()),
|
||||
attr="index",
|
||||
ctx=ast.Store(),
|
||||
)
|
||||
),
|
||||
],
|
||||
value=ast.Dict(keys=[], values=[]),
|
||||
lineno=lineno + 3,
|
||||
)
|
||||
),
|
||||
],
|
||||
orelse=[],
|
||||
lineno=lineno + 2,
|
||||
|
|
@ -339,7 +339,7 @@ def create_wrapper_function(function_name, module_path):
|
|||
value=ast.Name(id="codeflash_wrap", ctx=ast.Load()),
|
||||
attr="index",
|
||||
ctx=ast.Load(),
|
||||
)
|
||||
),
|
||||
],
|
||||
),
|
||||
body=[
|
||||
|
|
@ -356,7 +356,7 @@ def create_wrapper_function(function_name, module_path):
|
|||
op=ast.Add(),
|
||||
value=ast.Constant(value=1),
|
||||
lineno=lineno + 5,
|
||||
)
|
||||
),
|
||||
],
|
||||
orelse=[
|
||||
ast.Assign(
|
||||
|
|
@ -369,11 +369,11 @@ def create_wrapper_function(function_name, module_path):
|
|||
),
|
||||
slice=ast.Name(id="test_id", ctx=ast.Load()),
|
||||
ctx=ast.Store(),
|
||||
)
|
||||
),
|
||||
],
|
||||
value=ast.Constant(value=0),
|
||||
lineno=lineno + 6,
|
||||
)
|
||||
),
|
||||
],
|
||||
lineno=lineno + 4,
|
||||
),
|
||||
|
|
@ -397,13 +397,13 @@ def create_wrapper_function(function_name, module_path):
|
|||
value=ast.JoinedStr(
|
||||
values=[
|
||||
ast.FormattedValue(
|
||||
value=ast.Name(id="line_id", ctx=ast.Load()), conversion=-1
|
||||
value=ast.Name(id="line_id", ctx=ast.Load()), conversion=-1,
|
||||
),
|
||||
ast.Constant(value="_"),
|
||||
ast.FormattedValue(
|
||||
value=ast.Name(id="codeflash_test_index", ctx=ast.Load()), conversion=-1
|
||||
value=ast.Name(id="codeflash_test_index", ctx=ast.Load()), conversion=-1,
|
||||
),
|
||||
]
|
||||
],
|
||||
),
|
||||
lineno=lineno + 8,
|
||||
),
|
||||
|
|
|
|||
|
|
@ -9,16 +9,16 @@ from codeflash.code_utils.compat import LF
|
|||
|
||||
if os.name == "nt": # Windows
|
||||
SHELL_RC_EXPORT_PATTERN = re.compile(r"^set CODEFLASH_API_KEY=(cf-.*)$", re.M)
|
||||
SHELL_RC_EXPORT_PREFIX = f"set CODEFLASH_API_KEY="
|
||||
SHELL_RC_EXPORT_PREFIX = "set CODEFLASH_API_KEY="
|
||||
else:
|
||||
SHELL_RC_EXPORT_PATTERN = re.compile(r'^export CODEFLASH_API_KEY="?(cf-[^\s"]+)"?$', re.M)
|
||||
SHELL_RC_EXPORT_PREFIX = f"export CODEFLASH_API_KEY="
|
||||
SHELL_RC_EXPORT_PREFIX = "export CODEFLASH_API_KEY="
|
||||
|
||||
|
||||
def read_api_key_from_shell_config() -> Optional[str]:
|
||||
try:
|
||||
shell_rc_path = get_shell_rc_path()
|
||||
with open(shell_rc_path, "r", encoding="utf8") as shell_rc:
|
||||
with open(shell_rc_path, encoding="utf8") as shell_rc:
|
||||
shell_contents = shell_rc.read()
|
||||
matches = SHELL_RC_EXPORT_PATTERN.findall(shell_contents)
|
||||
return matches[-1] if matches else None
|
||||
|
|
@ -39,7 +39,7 @@ def get_shell_rc_path() -> Path:
|
|||
"tcsh": ".cshrc",
|
||||
"dash": ".profile",
|
||||
}.get(
|
||||
shell, ".bashrc"
|
||||
shell, ".bashrc",
|
||||
) # map each shell to its config file and default to .bashrc
|
||||
return Path.home() / shell_rc_filename
|
||||
|
||||
|
|
@ -62,7 +62,7 @@ def save_api_key_to_rc(api_key) -> Result[str, str]:
|
|||
if existing_api_key:
|
||||
# Replace the existing API key line
|
||||
updated_shell_contents = re.sub(
|
||||
SHELL_RC_EXPORT_PATTERN, api_key_line, shell_contents
|
||||
SHELL_RC_EXPORT_PATTERN, api_key_line, shell_contents,
|
||||
)
|
||||
action = "Updated CODEFLASH_API_KEY in"
|
||||
else:
|
||||
|
|
@ -77,7 +77,7 @@ def save_api_key_to_rc(api_key) -> Result[str, str]:
|
|||
except PermissionError:
|
||||
return Failure(
|
||||
f"💡 I tried adding your Codeflash API key to {shell_rc_path} - but seems like I don't have permissions to do so.{LF}"
|
||||
f"You'll need to open it yourself and add the following line:{LF}{LF}{api_key_line}{LF}"
|
||||
f"You'll need to open it yourself and add the following line:{LF}{LF}{api_key_line}{LF}",
|
||||
)
|
||||
except FileNotFoundError:
|
||||
return Failure(
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ class TestsInFile:
|
|||
def from_pytest_stdout_line_co(cls, module: str, function: str, directory: str):
|
||||
absolute_test_path = os.path.join(directory, module)
|
||||
assert os.path.exists(
|
||||
absolute_test_path
|
||||
absolute_test_path,
|
||||
), f"Test discovery failed - Test file does not exist {absolute_test_path}"
|
||||
return cls(
|
||||
test_file=absolute_test_path,
|
||||
|
|
@ -44,7 +44,7 @@ class TestsInFile:
|
|||
parts = line.split("::")
|
||||
absolute_test_path = os.path.join(pytest_rootdir, parts[0])
|
||||
assert os.path.exists(
|
||||
absolute_test_path
|
||||
absolute_test_path,
|
||||
), f"Test discovery failed - Test file does not exist {absolute_test_path}"
|
||||
if len(parts) == 3:
|
||||
return cls(
|
||||
|
|
@ -121,7 +121,7 @@ def discover_tests_pytest(cfg: TestConfig) -> Dict[str, List[TestsInFile]]:
|
|||
pytest_result = subprocess.run(
|
||||
pytest_cmd_list + [f"{tests_root}", "--co", "-q", "-m", "not skip"],
|
||||
stdout=subprocess.PIPE,
|
||||
cwd=project_root,
|
||||
cwd=project_root, check=False,
|
||||
)
|
||||
|
||||
pytest_stdout = pytest_result.stdout.decode("utf-8")
|
||||
|
|
@ -185,7 +185,7 @@ def discover_tests_unittest(cfg: TestConfig) -> Dict[str, List[TestsInFile]]:
|
|||
{
|
||||
"test_function": details.test_function,
|
||||
"test_suite_name": details.test_suite_name,
|
||||
}
|
||||
},
|
||||
)
|
||||
else:
|
||||
details = get_test_details(test)
|
||||
|
|
@ -194,7 +194,7 @@ def discover_tests_unittest(cfg: TestConfig) -> Dict[str, List[TestsInFile]]:
|
|||
{
|
||||
"test_function": details.test_function,
|
||||
"test_suite_name": details.test_suite_name,
|
||||
}
|
||||
},
|
||||
)
|
||||
return process_test_files(file_to_test_map, cfg)
|
||||
|
||||
|
|
@ -208,7 +208,7 @@ def discover_parameters_unittest(function_name: str):
|
|||
|
||||
|
||||
def process_test_files(
|
||||
file_to_test_map: Dict[str, List[Dict[str, str]]], cfg: TestConfig
|
||||
file_to_test_map: Dict[str, List[Dict[str, str]]], cfg: TestConfig,
|
||||
) -> Dict[str, List[TestsInFile]]:
|
||||
project_root_path = cfg.project_root_path
|
||||
test_framework = cfg.test_framework
|
||||
|
|
@ -231,10 +231,9 @@ def process_test_files(
|
|||
parameters = re.split(r"\[|\]", function)[1]
|
||||
if name.name == function_name and name.type == "function":
|
||||
test_functions.add(TestFunction(name.name, None, parameters))
|
||||
else:
|
||||
if name.name == function and name.type == "function":
|
||||
test_functions.add(TestFunction(name.name, None, None))
|
||||
break
|
||||
elif name.name == function and name.type == "function":
|
||||
test_functions.add(TestFunction(name.name, None, None))
|
||||
break
|
||||
if test_framework == "unittest":
|
||||
functions_to_search = [elem["test_function"] for elem in functions]
|
||||
test_suites = [elem["test_suite_name"] for elem in functions]
|
||||
|
|
@ -254,7 +253,7 @@ def process_test_files(
|
|||
|
||||
if is_parameterized and new_function == def_name.name:
|
||||
test_functions.add(
|
||||
TestFunction(def_name.name, name.name, parameters)
|
||||
TestFunction(def_name.name, name.name, parameters),
|
||||
)
|
||||
elif function == def_name.name:
|
||||
test_functions.add(TestFunction(def_name.name, name.name, None))
|
||||
|
|
@ -282,7 +281,7 @@ def process_test_files(
|
|||
follow_builtin_imports=False,
|
||||
)
|
||||
except Exception as e:
|
||||
logging.error(str(e))
|
||||
logging.exception(str(e))
|
||||
continue
|
||||
if definition and definition[0].type == "function":
|
||||
definition_path = str(definition[0].module_path)
|
||||
|
|
@ -298,7 +297,7 @@ def process_test_files(
|
|||
scope_test_function += "_" + scope_parameters
|
||||
|
||||
function_to_test_map[definition[0].full_name].append(
|
||||
TestsInFile(test_file, None, scope_test_function, scope_test_suite)
|
||||
TestsInFile(test_file, None, scope_test_function, scope_test_suite),
|
||||
)
|
||||
deduped_function_to_test_map = {}
|
||||
for function, tests in function_to_test_map.items():
|
||||
|
|
@ -307,7 +306,7 @@ def process_test_files(
|
|||
|
||||
|
||||
def parse_pytest_stdout(
|
||||
pytest_stdout: str, pytest_rootdir: str, tests_root: str, parse_type: ParseType
|
||||
pytest_stdout: str, pytest_rootdir: str, tests_root: str, parse_type: ParseType,
|
||||
) -> List[TestsInFile]:
|
||||
test_results = []
|
||||
if parse_type == ParseType.Q:
|
||||
|
|
@ -384,7 +383,7 @@ def parse_pytest_stdout(
|
|||
function = function.group(1)
|
||||
try:
|
||||
test_result = TestsInFile.from_pytest_stdout_line_co(
|
||||
module, function, directory
|
||||
module, function, directory,
|
||||
)
|
||||
test_results.append(test_result)
|
||||
except ValueError as e:
|
||||
|
|
|
|||
|
|
@ -51,7 +51,7 @@ class FunctionVisitor(cst.CSTVisitor):
|
|||
parents=list(reversed(ast_parents)),
|
||||
starting_line=pos.start.line,
|
||||
ending_line=pos.end.line,
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -66,8 +66,8 @@ class FunctionWithReturnStatement(ast.NodeVisitor):
|
|||
if function_has_return_statement(node):
|
||||
self.functions.append(
|
||||
FunctionToOptimize(
|
||||
function_name=node.name, file_path=self.file_path, parents=self.ast_path[:]
|
||||
)
|
||||
function_name=node.name, file_path=self.file_path, parents=self.ast_path[:],
|
||||
),
|
||||
)
|
||||
# Continue visiting the body of the function to find nested functions
|
||||
self.generic_visit(node)
|
||||
|
|
@ -145,14 +145,14 @@ def get_functions_to_optimize_by_file(
|
|||
if found_function is None:
|
||||
raise ValueError(
|
||||
f"Function {only_function_name} not found in file {file} or"
|
||||
f" the function does not have a 'return' statement."
|
||||
f" the function does not have a 'return' statement.",
|
||||
)
|
||||
functions[file] = [found_function]
|
||||
else:
|
||||
logging.info("Finding all functions modified in the current git diff ...")
|
||||
functions = get_functions_within_git_diff()
|
||||
filtered_modified_functions, functions_count = filter_functions(
|
||||
functions, test_cfg.tests_root, ignore_paths, project_root, module_root
|
||||
functions, test_cfg.tests_root, ignore_paths, project_root, module_root,
|
||||
)
|
||||
logging.info("Found %d functions to optimize", functions_count)
|
||||
return filtered_modified_functions, functions_count
|
||||
|
|
@ -164,12 +164,12 @@ def get_functions_within_git_diff() -> Dict[str, List[FunctionToOptimize]]:
|
|||
for path in modified_lines:
|
||||
if not os.path.exists(path):
|
||||
continue
|
||||
with open(path, "r", encoding="utf8") as f:
|
||||
with open(path, encoding="utf8") as f:
|
||||
file_content = f.read()
|
||||
try:
|
||||
wrapper = cst.metadata.MetadataWrapper(cst.parse_module(file_content))
|
||||
except Exception as e:
|
||||
logging.error(e)
|
||||
logging.exception(e)
|
||||
continue
|
||||
function_lines = FunctionVisitor(file_path=path)
|
||||
wrapper.visit(function_lines)
|
||||
|
|
@ -203,11 +203,11 @@ def get_all_files_and_functions(module_root_path: str) -> Dict[str, List[Functio
|
|||
|
||||
def find_all_functions_in_file(file_path: str) -> Dict[str, List[FunctionToOptimize]]:
|
||||
functions: Dict[str, List[FunctionToOptimize]] = {}
|
||||
with open(file_path, "r", encoding="utf8") as f:
|
||||
with open(file_path, encoding="utf8") as f:
|
||||
try:
|
||||
ast_module = ast.parse(f.read())
|
||||
except Exception as e:
|
||||
logging.error(e)
|
||||
logging.exception(e)
|
||||
return functions
|
||||
function_name_visitor = FunctionWithReturnStatement(file_path)
|
||||
function_name_visitor.visit(ast_module)
|
||||
|
|
|
|||
|
|
@ -1,9 +1,10 @@
|
|||
from typing import Union
|
||||
|
||||
from codeflash.verification.test_results import TestResults
|
||||
from pydantic import BaseModel
|
||||
from pydantic.dataclasses import dataclass
|
||||
|
||||
from codeflash.verification.test_results import TestResults
|
||||
|
||||
|
||||
@dataclass(frozen=True, config={"arbitrary_types_allowed": True})
|
||||
class PrComment:
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
import logging
|
||||
from typing import Dict, Any, Optional
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from posthog import Posthog
|
||||
|
||||
|
|
@ -10,8 +10,7 @@ _posthog = None
|
|||
|
||||
|
||||
def initialize_posthog(enabled: bool) -> None:
|
||||
"""
|
||||
Enable or disable PostHog.
|
||||
"""Enable or disable PostHog.
|
||||
:param enabled: Whether to enable PostHog.
|
||||
"""
|
||||
if not enabled:
|
||||
|
|
@ -26,8 +25,7 @@ def initialize_posthog(enabled: bool) -> None:
|
|||
|
||||
|
||||
def ph(event: str, properties: Optional[Dict[str, Any]] = None) -> None:
|
||||
"""
|
||||
Log an event to PostHog.
|
||||
"""Log an event to PostHog.
|
||||
:param event: The name of the event.
|
||||
:param properties: A dictionary of properties to attach to the event.
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from typing import Any, Dict, Generator, List, Tuple
|
|||
|
||||
|
||||
def get_next_arg_and_return(
|
||||
trace_file: str, function_name: str, num_to_get: int = 3
|
||||
trace_file: str, function_name: str, num_to_get: int = 3,
|
||||
) -> Generator[Tuple[Any, Any], None, None]:
|
||||
db = sqlite3.connect(trace_file)
|
||||
cur = db.cursor()
|
||||
|
|
@ -27,14 +27,14 @@ def get_next_arg_and_return(
|
|||
matched_arg_return[frame_address].append(val[7])
|
||||
if len(matched_arg_return[frame_address]) > 1:
|
||||
logging.warning(
|
||||
f"Pre-existing call to the function {function_name} with same frame address."
|
||||
f"Pre-existing call to the function {function_name} with same frame address.",
|
||||
)
|
||||
elif event_type == "return":
|
||||
matched_arg_return[frame_address].append(val[6])
|
||||
arg_return_length = len(matched_arg_return[frame_address])
|
||||
if arg_return_length > 2:
|
||||
logging.warning(
|
||||
f"Pre-existing return to the function {function_name} with same frame address."
|
||||
f"Pre-existing return to the function {function_name} with same frame address.",
|
||||
)
|
||||
elif arg_return_length == 1:
|
||||
logging.warning(f"No call before return for {function_name}!")
|
||||
|
|
@ -51,7 +51,7 @@ def get_function_alias(module: str, function_name: str) -> str:
|
|||
|
||||
|
||||
def create_trace_replay_test(
|
||||
trace_file: str, functions: List[Tuple[str, str]], test_framework: str = "pytest"
|
||||
trace_file: str, functions: List[Tuple[str, str]], test_framework: str = "pytest",
|
||||
) -> str:
|
||||
assert test_framework in ["pytest", "unittest"]
|
||||
|
||||
|
|
@ -84,7 +84,7 @@ def _create_unittest_trace_replay_test(trace_file: str, functions: List[Tuple[st
|
|||
return_val = pickle.loads(return_val_pkl)
|
||||
ret = {function_name}(**args)
|
||||
self.assertTrue(comparator(return_val, ret))
|
||||
"""
|
||||
""",
|
||||
)
|
||||
|
||||
test_template = "\nclass TestTracedFunctions(unittest.TestCase):\n"
|
||||
|
|
@ -111,7 +111,7 @@ def _create_pytest_trace_replay_test(trace_file: str, functions: List[Tuple[str,
|
|||
return_val = pickle.loads(return_val_pkl)
|
||||
ret = {function_name}(**args)
|
||||
assert comparator(return_val, ret)
|
||||
"""
|
||||
""",
|
||||
)
|
||||
|
||||
test_template = ""
|
||||
|
|
|
|||
|
|
@ -15,8 +15,7 @@ from codeflash.verification.verification_utils import get_test_file_path
|
|||
|
||||
|
||||
class Tracer:
|
||||
"""
|
||||
Use this class as a 'with' context manager to trace a function call,
|
||||
"""Use this class as a 'with' context manager to trace a function call,
|
||||
input arguments, and return value.
|
||||
"""
|
||||
|
||||
|
|
@ -57,7 +56,7 @@ class Tracer:
|
|||
# TODO: Check out if we need to export the function test name as well
|
||||
cur.execute(
|
||||
"CREATE TABLE events(type TEXT, function TEXT, filename TEXT, line_number INTEGER, "
|
||||
"last_frame_address INTEGER, time_ns INTEGER, arg BLOB, locals BLOB)"
|
||||
"last_frame_address INTEGER, time_ns INTEGER, arg BLOB, locals BLOB)",
|
||||
)
|
||||
sys.setprofile(self.trace_callback)
|
||||
|
||||
|
|
@ -80,23 +79,23 @@ class Tracer:
|
|||
)
|
||||
function_path = "_".join([func for _, func in module_function])
|
||||
test_file_path = get_test_file_path(
|
||||
self.config["tests_root"], function_path, test_type="replay"
|
||||
self.config["tests_root"], function_path, test_type="replay",
|
||||
)
|
||||
with open(test_file_path, "w", encoding="utf8") as file:
|
||||
file.write(replay_test)
|
||||
|
||||
logging.info(
|
||||
f"Codeflash: Function Traced successfully and replay test created! Path - {test_file_path}"
|
||||
f"Codeflash: Function Traced successfully and replay test created! Path - {test_file_path}",
|
||||
)
|
||||
|
||||
def trace_callback(self, frame: Any, event: str, arg: Any) -> None:
|
||||
if event not in ["call", "return"]:
|
||||
return None
|
||||
return
|
||||
|
||||
code = frame.f_code
|
||||
if self.functions:
|
||||
if code.co_name not in self.functions:
|
||||
return None
|
||||
return
|
||||
if self.function_count[code.co_name] >= self.max_function_count:
|
||||
return
|
||||
self.function_count[code.co_name] += 1
|
||||
|
|
@ -114,7 +113,7 @@ class Tracer:
|
|||
|
||||
project_root = os.path.realpath(os.path.join(self.config["module_root"], ".."))
|
||||
self.function_modules[code.co_name] = module_name_from_file_path(
|
||||
code.co_filename, project_root=project_root
|
||||
code.co_filename, project_root=project_root,
|
||||
)
|
||||
cur = self.con.cursor()
|
||||
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ def main():
|
|||
version_replacement = r"\g<1>" + major_minor_version + r".x"
|
||||
|
||||
# Read the LICENSE file
|
||||
with open("codeflash/LICENSE", "r", encoding="utf8") as file:
|
||||
with open("codeflash/LICENSE", encoding="utf8") as file:
|
||||
license_text = file.read()
|
||||
|
||||
# Replace the version in the LICENSE file
|
||||
|
|
|
|||
|
|
@ -114,7 +114,7 @@ def comparator(orig: Any, new: Any) -> bool:
|
|||
return orig.equals(new)
|
||||
|
||||
if HAS_PANDAS and isinstance(
|
||||
orig, (pandas.CategoricalDtype, pandas.Interval, pandas.Period)
|
||||
orig, (pandas.CategoricalDtype, pandas.Interval, pandas.Period),
|
||||
):
|
||||
return orig == new
|
||||
|
||||
|
|
@ -145,6 +145,6 @@ def comparator(orig: Any, new: Any) -> bool:
|
|||
logging.warning(f"Unknown comparator input type: {type(orig)}")
|
||||
return True
|
||||
except RecursionError as e:
|
||||
logging.error(f"RecursionError while comparing objects: {e}")
|
||||
logging.exception(f"RecursionError while comparing objects: {e}")
|
||||
sentry_sdk.capture_exception(e)
|
||||
return False
|
||||
|
|
|
|||
|
|
@ -11,21 +11,21 @@ import sentry_sdk
|
|||
from junitparser.xunit2 import JUnitXml
|
||||
|
||||
from codeflash.code_utils.code_utils import (
|
||||
module_name_from_file_path,
|
||||
get_run_tmp_file,
|
||||
module_name_from_file_path,
|
||||
)
|
||||
from codeflash.discovery.discover_unit_tests import discover_parameters_unittest
|
||||
from codeflash.verification.test_results import (
|
||||
TestResults,
|
||||
FunctionTestInvocation,
|
||||
TestType,
|
||||
InvocationId,
|
||||
TestResults,
|
||||
TestType,
|
||||
)
|
||||
from codeflash.verification.verification_utils import TestConfig
|
||||
|
||||
|
||||
def parse_test_return_values_bin(
|
||||
file_location: str, test_framework: str, test_type: TestType, test_file_path: str
|
||||
file_location: str, test_framework: str, test_type: TestType, test_file_path: str,
|
||||
) -> TestResults:
|
||||
test_results = TestResults()
|
||||
if not os.path.exists(file_location):
|
||||
|
|
@ -47,7 +47,7 @@ def parse_test_return_values_bin(
|
|||
try:
|
||||
test_pickle = pickle.loads(file.read(len_next))
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to load pickle file. Exception: {e}")
|
||||
logging.exception(f"Failed to load pickle file. Exception: {e}")
|
||||
return test_results
|
||||
len_next = file.read(4)
|
||||
len_next = int.from_bytes(len_next, byteorder="big")
|
||||
|
|
@ -65,7 +65,7 @@ def parse_test_return_values_bin(
|
|||
test_framework=test_framework,
|
||||
test_type=test_type,
|
||||
return_value=test_pickle,
|
||||
)
|
||||
),
|
||||
)
|
||||
# Hardcoding the test result to True because the test did execute and we are only interested in the return values,
|
||||
# the did_pass comes from the xml results file
|
||||
|
|
@ -104,7 +104,7 @@ def parse_sqlite_test_results(
|
|||
test_framework=test_config.test_framework,
|
||||
test_type=test_type,
|
||||
return_value=None,
|
||||
)
|
||||
),
|
||||
)
|
||||
# return_value is only None temporarily as this is only being used for the existing tests. This should generalize
|
||||
# to read the return_value from the sqlite file as well.
|
||||
|
|
@ -136,7 +136,7 @@ def parse_test_xml(
|
|||
for testcase in suite:
|
||||
class_name = testcase.classname
|
||||
file_name = suite._elem.attrib.get(
|
||||
"file"
|
||||
"file",
|
||||
) # file_path_from_module_name(generated_tests_path, test_config.project_root_path)
|
||||
if (
|
||||
file_name == f"unittest{os.sep}loader.py"
|
||||
|
|
@ -145,10 +145,10 @@ def parse_test_xml(
|
|||
and suite.tests == 1
|
||||
):
|
||||
# This means that the test failed to load, so we don't want to crash on it
|
||||
logging.info(f"Test failed to load, skipping it.")
|
||||
logging.info("Test failed to load, skipping it.")
|
||||
if run_result is not None:
|
||||
logging.info(
|
||||
f"Test log - STDOUT : {run_result.stdout.decode()} \n STDERR : {run_result.stderr.decode()}"
|
||||
f"Test log - STDOUT : {run_result.stdout.decode()} \n STDERR : {run_result.stderr.decode()}",
|
||||
)
|
||||
return test_results
|
||||
file_name = test_py_file_path
|
||||
|
|
@ -170,7 +170,7 @@ def parse_test_xml(
|
|||
xml_file_contents = open(test_xml_file_path).read()
|
||||
scope.set_extra("file", xml_file_contents)
|
||||
sentry_sdk.capture_message(
|
||||
f"testcase.name is None in parse_test_xml for testcase {repr(testcase)} in file {xml_file_contents}"
|
||||
f"testcase.name is None in parse_test_xml for testcase {testcase!r} in file {xml_file_contents}",
|
||||
)
|
||||
continue
|
||||
# Parse test timing
|
||||
|
|
@ -193,19 +193,19 @@ def parse_test_xml(
|
|||
did_pass=result,
|
||||
test_type=test_type,
|
||||
return_value=None,
|
||||
)
|
||||
),
|
||||
)
|
||||
if len(test_results) == 0:
|
||||
logging.info(f"Test '{test_py_file_path}' failed to run, skipping it")
|
||||
if run_result is not None:
|
||||
logging.info(
|
||||
f"Test log - STDOUT : {run_result.stdout.decode()} \n STDERR : {run_result.stderr.decode()}"
|
||||
f"Test log - STDOUT : {run_result.stdout.decode()} \n STDERR : {run_result.stderr.decode()}",
|
||||
)
|
||||
return test_results
|
||||
|
||||
|
||||
def merge_test_results(
|
||||
xml_test_results: TestResults, bin_test_results: TestResults, test_framework: str
|
||||
xml_test_results: TestResults, bin_test_results: TestResults, test_framework: str,
|
||||
) -> TestResults:
|
||||
merged_test_results = TestResults()
|
||||
|
||||
|
|
@ -225,7 +225,7 @@ def merge_test_results(
|
|||
if test_framework == "unittest":
|
||||
test_function_name = result.id.test_function_name
|
||||
is_parameterized, new_test_function_name, _ = discover_parameters_unittest(
|
||||
test_function_name
|
||||
test_function_name,
|
||||
)
|
||||
if is_parameterized: # handle parameterized test
|
||||
test_function_name = new_test_function_name
|
||||
|
|
@ -266,7 +266,7 @@ def merge_test_results(
|
|||
did_pass=xml_result.did_pass,
|
||||
test_type=xml_result.test_type,
|
||||
return_value=result_bin.return_value,
|
||||
)
|
||||
),
|
||||
)
|
||||
else:
|
||||
for i in range(len(bin_results.test_results)):
|
||||
|
|
@ -290,7 +290,7 @@ def merge_test_results(
|
|||
did_pass=bin_result.did_pass,
|
||||
test_type=bin_result.test_type,
|
||||
return_value=bin_result.return_value,
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
return merged_test_results
|
||||
|
|
@ -321,10 +321,10 @@ def parse_test_results(
|
|||
test_file_path=test_py_path,
|
||||
)
|
||||
except AttributeError as e:
|
||||
logging.error(e)
|
||||
logging.exception(e)
|
||||
test_results_bin_file = TestResults()
|
||||
pathlib.Path(
|
||||
get_run_tmp_file(f"test_return_values_{optimization_iteration}.bin")
|
||||
get_run_tmp_file(f"test_return_values_{optimization_iteration}.bin"),
|
||||
).unlink(missing_ok=True)
|
||||
elif test_type == TestType.EXISTING_UNIT_TEST:
|
||||
try:
|
||||
|
|
@ -335,7 +335,7 @@ def parse_test_results(
|
|||
test_config=test_config,
|
||||
)
|
||||
except AttributeError as e:
|
||||
logging.error(e)
|
||||
logging.exception(e)
|
||||
test_results_bin_file = TestResults()
|
||||
else:
|
||||
raise ValueError(f"Invalid test type: {test_type}")
|
||||
|
|
@ -343,13 +343,13 @@ def parse_test_results(
|
|||
# We Probably want to remove deleting this file here later, because we want to preserve the reference to the
|
||||
# pickle blob in the test_results
|
||||
pathlib.Path(get_run_tmp_file(f"test_return_values_{optimization_iteration}.bin")).unlink(
|
||||
missing_ok=True
|
||||
missing_ok=True,
|
||||
)
|
||||
pathlib.Path(get_run_tmp_file(f"test_return_values_{optimization_iteration}.sqlite")).unlink(
|
||||
missing_ok=True
|
||||
missing_ok=True,
|
||||
)
|
||||
|
||||
merged_results = merge_test_results(
|
||||
test_results_xml, test_results_bin_file, test_config.test_framework
|
||||
test_results_xml, test_results_bin_file, test_config.test_framework,
|
||||
)
|
||||
return merged_results
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
import subprocess
|
||||
from typing import Tuple, Optional
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from codeflash.code_utils.code_utils import get_run_tmp_file
|
||||
|
||||
|
|
@ -28,7 +28,7 @@ def run_tests(
|
|||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
cwd=cwd,
|
||||
env=test_env,
|
||||
env=test_env, check=False,
|
||||
)
|
||||
elif test_framework == "unittest":
|
||||
result_file_path = get_run_tmp_file("unittest_results.xml")
|
||||
|
|
@ -40,7 +40,7 @@ def run_tests(
|
|||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
cwd=cwd,
|
||||
env=test_env,
|
||||
env=test_env, check=False,
|
||||
)
|
||||
else:
|
||||
raise ValueError("Invalid test framework -- I only support Pytest and Unittest currently.")
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from pydantic.dataclasses import dataclass
|
|||
|
||||
|
||||
def get_test_file_path(
|
||||
test_dir: str, function_name: str, iteration: int = 0, test_type: str = "unit"
|
||||
test_dir: str, function_name: str, iteration: int = 0, test_type: str = "unit",
|
||||
) -> str:
|
||||
assert test_type in ["unit", "inspired", "replay"]
|
||||
function_name = function_name.replace(".", "_")
|
||||
|
|
@ -39,11 +39,9 @@ class ModifyInspiredTests(ast.NodeTransformer):
|
|||
|
||||
def visit_Import(self, node: ast.Import):
|
||||
self.import_list.append(node)
|
||||
return None
|
||||
|
||||
def visit_ImportFrom(self, node: ast.ImportFrom):
|
||||
self.import_list.append(node)
|
||||
return None
|
||||
|
||||
def visit_ClassDef(self, node: ast.ClassDef):
|
||||
if self.test_framework != "unittest":
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@ def generate_tests(
|
|||
instrumented_test_source = module.CACHED_INSTRUMENTED_TESTS
|
||||
path = get_run_tmp_file("").replace("\\", "\\\\") # Escape backslash for windows paths
|
||||
instrumented_test_source = instrumented_test_source.replace(
|
||||
"{codeflash_run_tmp_dir_client_side}", path
|
||||
"{codeflash_run_tmp_dir_client_side}", path,
|
||||
)
|
||||
logging.info(f"Using cached tests from {module_path}.CACHED_TESTS")
|
||||
else:
|
||||
|
|
@ -56,7 +56,7 @@ def generate_tests(
|
|||
generated_test_source, instrumented_test_source = response
|
||||
path = get_run_tmp_file("").replace("\\", "\\\\") # Escape backslash for windows paths
|
||||
instrumented_test_source = instrumented_test_source.replace(
|
||||
"{codeflash_run_tmp_dir_client_side}", path
|
||||
"{codeflash_run_tmp_dir_client_side}", path,
|
||||
)
|
||||
else:
|
||||
logging.error(f"Failed to generate and instrument tests for {function_to_optimize.function_name}")
|
||||
|
|
@ -77,7 +77,7 @@ def merge_unit_tests(unit_test_source: str, inspired_unit_tests: str, test_frame
|
|||
inspired_unit_tests_ast = ast.parse(inspired_unit_tests)
|
||||
unit_test_source_ast = ast.parse(unit_test_source)
|
||||
except SyntaxError as e:
|
||||
logging.error(f"Syntax error in code: {e}")
|
||||
logging.exception(f"Syntax error in code: {e}")
|
||||
return unit_test_source
|
||||
import_list: list[ast.stmt] = list()
|
||||
modified_ast = ModifyInspiredTests(import_list, test_framework).visit(inspired_unit_tests_ast)
|
||||
|
|
|
|||
|
|
@ -60,7 +60,7 @@ def main():
|
|||
stdout,
|
||||
)
|
||||
num_unit_tests = int(unit_test_search.group(1))
|
||||
assert num_unit_tests > 0, f"Could not find existing unit tests"
|
||||
assert num_unit_tests > 0, "Could not find existing unit tests"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ import datetime
|
|||
import decimal
|
||||
|
||||
import pytest
|
||||
|
||||
from codeflash.verification.comparator import comparator
|
||||
from codeflash.verification.equivalence import compare_results
|
||||
from codeflash.verification.test_results import (
|
||||
|
|
|
|||
|
|
@ -1,11 +1,11 @@
|
|||
from codeflash.code_utils.formatter import format_code, sort_imports
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
from codeflash.code_utils.formatter import format_code, sort_imports
|
||||
|
||||
|
||||
def test_remove_duplicate_imports():
|
||||
"""
|
||||
Test that duplicate imports are removed when should_sort_imports is True.
|
||||
"""Test that duplicate imports are removed when should_sort_imports is True.
|
||||
"""
|
||||
with tempfile.NamedTemporaryFile(delete=False) as tmp:
|
||||
tmp.write(b"import os\nimport os\n")
|
||||
|
|
@ -17,8 +17,7 @@ def test_remove_duplicate_imports():
|
|||
|
||||
|
||||
def test_remove_multiple_duplicate_imports():
|
||||
"""
|
||||
Test that multiple duplicate imports are removed when should_sort_imports is True.
|
||||
"""Test that multiple duplicate imports are removed when should_sort_imports is True.
|
||||
"""
|
||||
with tempfile.NamedTemporaryFile(delete=False) as tmp:
|
||||
tmp.write(b"import sys\nimport os\nimport sys\n")
|
||||
|
|
@ -30,8 +29,7 @@ def test_remove_multiple_duplicate_imports():
|
|||
|
||||
|
||||
def test_sorting_imports():
|
||||
"""
|
||||
Test that imports are sorted when should_sort_imports is True.
|
||||
"""Test that imports are sorted when should_sort_imports is True.
|
||||
"""
|
||||
with tempfile.NamedTemporaryFile(delete=False) as tmp:
|
||||
tmp.write(b"import sys\nimport unittest\nimport os\n")
|
||||
|
|
@ -43,8 +41,7 @@ def test_sorting_imports():
|
|||
|
||||
|
||||
def test_no_sorting_imports():
|
||||
"""
|
||||
Test that imports are not sorted when should_sort_imports is False.
|
||||
"""Test that imports are not sorted when should_sort_imports is False.
|
||||
"""
|
||||
with tempfile.NamedTemporaryFile(delete=False) as tmp:
|
||||
tmp.write(b"import sys\nimport unittest\nimport os\n")
|
||||
|
|
@ -56,15 +53,14 @@ def test_no_sorting_imports():
|
|||
|
||||
|
||||
def test_sort_imports_without_formatting():
|
||||
"""
|
||||
Test that imports are sorted when formatting is disabled and should_sort_imports is True.
|
||||
"""Test that imports are sorted when formatting is disabled and should_sort_imports is True.
|
||||
"""
|
||||
with tempfile.NamedTemporaryFile(delete=False) as tmp:
|
||||
tmp.write(b"import sys\nimport unittest\nimport os\n")
|
||||
tmp_path = tmp.name
|
||||
|
||||
new_code = format_code(
|
||||
formatter_cmd="disabled", imports_sort_cmd="isort", should_sort_imports=True, path=tmp_path
|
||||
formatter_cmd="disabled", imports_sort_cmd="isort", should_sort_imports=True, path=tmp_path,
|
||||
)
|
||||
os.remove(tmp_path)
|
||||
assert new_code == "import os\nimport sys\nimport unittest\n"
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import pathlib
|
||||
from dataclasses import dataclass
|
||||
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize, FunctionParent
|
||||
from codeflash.discovery.functions_to_optimize import FunctionParent, FunctionToOptimize
|
||||
from codeflash.optimization.function_context import get_function_variables_definitions
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ def test_function_eligible_for_optimization() -> None:
|
|||
f.write(function)
|
||||
f.flush()
|
||||
functions_found = find_all_functions_in_file(f.name)
|
||||
assert "test_function_eligible_for_optimization" == functions_found[f.name][0].function_name
|
||||
assert functions_found[f.name][0].function_name == "test_function_eligible_for_optimization"
|
||||
|
||||
# Has no return statement
|
||||
function = """def test_function_not_eligible_for_optimization():
|
||||
|
|
|
|||
|
|
@ -1,18 +1,18 @@
|
|||
import ast
|
||||
import os.path
|
||||
import pathlib
|
||||
import pytest
|
||||
import sys
|
||||
import tempfile
|
||||
|
||||
import pytest
|
||||
|
||||
from codeflash.code_utils.code_utils import get_run_tmp_file
|
||||
from codeflash.code_utils.config_consts import INDIVIDUAL_TEST_TIMEOUT
|
||||
from codeflash.code_utils.instrument_existing_tests import inject_profiling_into_existing_test
|
||||
from codeflash.code_utils.instrument_existing_tests import InjectPerfOnly, inject_profiling_into_existing_test
|
||||
from codeflash.verification.parse_test_output import parse_test_results
|
||||
from codeflash.verification.test_results import TestType
|
||||
from codeflash.verification.test_runner import run_tests
|
||||
from codeflash.verification.verification_utils import TestConfig
|
||||
from codeflash.code_utils.instrument_existing_tests import InjectPerfOnly
|
||||
|
||||
|
||||
def test_perfinjector_bubble_sort():
|
||||
|
|
@ -86,7 +86,7 @@ class TestPigLatin(unittest.TestCase):
|
|||
f.write(code)
|
||||
f.flush()
|
||||
success, new_test = inject_profiling_into_existing_test(
|
||||
f.name, "sorter", os.path.dirname(f.name)
|
||||
f.name, "sorter", os.path.dirname(f.name),
|
||||
)
|
||||
assert success
|
||||
assert new_test == expected.format(
|
||||
|
|
@ -161,7 +161,7 @@ def test_prepare_image_for_yolo():
|
|||
f.flush()
|
||||
|
||||
success, new_test = inject_profiling_into_existing_test(
|
||||
f.name, "prepare_image_for_yolo", os.path.dirname(f.name)
|
||||
f.name, "prepare_image_for_yolo", os.path.dirname(f.name),
|
||||
)
|
||||
assert success
|
||||
assert new_test == expected.format(
|
||||
|
|
|
|||
|
|
@ -2,13 +2,11 @@ import os
|
|||
import unittest
|
||||
from unittest.mock import mock_open, patch
|
||||
|
||||
from returns.result import Success, Failure
|
||||
|
||||
from codeflash.code_utils.shell_utils import save_api_key_to_rc, read_api_key_from_shell_config
|
||||
from codeflash.code_utils.shell_utils import read_api_key_from_shell_config, save_api_key_to_rc
|
||||
from returns.result import Failure, Success
|
||||
|
||||
|
||||
class TestShellUtils(unittest.TestCase):
|
||||
|
||||
@patch(
|
||||
"codeflash.code_utils.shell_utils.open",
|
||||
new_callable=mock_open,
|
||||
|
|
@ -20,7 +18,7 @@ class TestShellUtils(unittest.TestCase):
|
|||
api_key = "cf-12345"
|
||||
result = save_api_key_to_rc(api_key)
|
||||
self.assertTrue(isinstance(result, Success))
|
||||
mock_file.assert_called_with("/fake/path/.bashrc", "r", encoding="utf8")
|
||||
mock_file.assert_called_with("/fake/path/.bashrc", encoding="utf8")
|
||||
handle = mock_file()
|
||||
handle.write.assert_called_once()
|
||||
handle.truncate.assert_called_once()
|
||||
|
|
@ -42,7 +40,6 @@ class TestShellUtils(unittest.TestCase):
|
|||
|
||||
# unit tests
|
||||
class TestReadApiKeyFromShellConfig(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
"""Setup a temporary shell configuration file for testing."""
|
||||
self.test_rc_path = "test_shell_rc"
|
||||
|
|
@ -59,10 +56,11 @@ class TestReadApiKeyFromShellConfig(unittest.TestCase):
|
|||
with patch("codeflash.code_utils.shell_utils.get_shell_rc_path") as mock_get_shell_rc_path:
|
||||
mock_get_shell_rc_path.return_value = self.test_rc_path
|
||||
with patch(
|
||||
"builtins.open", mock_open(read_data=f'export CODEFLASH_API_KEY="{self.api_key}"\n')
|
||||
"builtins.open",
|
||||
mock_open(read_data=f'export CODEFLASH_API_KEY="{self.api_key}"\n'),
|
||||
) as mock_file:
|
||||
self.assertEqual(read_api_key_from_shell_config(), self.api_key)
|
||||
mock_file.assert_called_once_with(self.test_rc_path, "r", encoding="utf8")
|
||||
mock_file.assert_called_once_with(self.test_rc_path, encoding="utf8")
|
||||
|
||||
@patch("codeflash.code_utils.shell_utils.get_shell_rc_path")
|
||||
def test_no_api_key(self, mock_get_shell_rc_path):
|
||||
|
|
@ -70,7 +68,7 @@ class TestReadApiKeyFromShellConfig(unittest.TestCase):
|
|||
mock_get_shell_rc_path.return_value = self.test_rc_path
|
||||
with patch("builtins.open", mock_open(read_data="# No API key here\n")) as mock_file:
|
||||
self.assertIsNone(read_api_key_from_shell_config())
|
||||
mock_file.assert_called_once_with(self.test_rc_path, "r", encoding="utf8")
|
||||
mock_file.assert_called_once_with(self.test_rc_path, encoding="utf8")
|
||||
|
||||
@patch("codeflash.code_utils.shell_utils.get_shell_rc_path")
|
||||
def test_malformed_api_key_export(self, mock_get_shell_rc_path):
|
||||
|
|
@ -83,7 +81,8 @@ class TestReadApiKeyFromShellConfig(unittest.TestCase):
|
|||
result = read_api_key_from_shell_config()
|
||||
self.assertIsNone(result)
|
||||
with patch(
|
||||
"builtins.open", mock_open(read_data=f"export CODEFLASH_API_KEY=sk-{self.api_key}\n")
|
||||
"builtins.open",
|
||||
mock_open(read_data=f"export CODEFLASH_API_KEY=sk-{self.api_key}\n"),
|
||||
):
|
||||
result = read_api_key_from_shell_config()
|
||||
self.assertIsNone(result)
|
||||
|
|
@ -95,7 +94,7 @@ class TestReadApiKeyFromShellConfig(unittest.TestCase):
|
|||
with patch(
|
||||
"builtins.open",
|
||||
mock_open(
|
||||
read_data=f'export CODEFLASH_API_KEY="cf-firstkey"\nexport CODEFLASH_API_KEY="{self.api_key}"\n'
|
||||
read_data=f'export CODEFLASH_API_KEY="cf-firstkey"\nexport CODEFLASH_API_KEY="{self.api_key}"\n',
|
||||
),
|
||||
):
|
||||
self.assertEqual(read_api_key_from_shell_config(), self.api_key)
|
||||
|
|
@ -107,7 +106,7 @@ class TestReadApiKeyFromShellConfig(unittest.TestCase):
|
|||
with patch(
|
||||
"builtins.open",
|
||||
mock_open(
|
||||
read_data=f'# Setting API Key\nexport CODEFLASH_API_KEY="{self.api_key}"\n# Done\n'
|
||||
read_data=f'# Setting API Key\nexport CODEFLASH_API_KEY="{self.api_key}"\n# Done\n',
|
||||
),
|
||||
):
|
||||
self.assertEqual(read_api_key_from_shell_config(), self.api_key)
|
||||
|
|
@ -117,7 +116,8 @@ class TestReadApiKeyFromShellConfig(unittest.TestCase):
|
|||
"""Test with API key export in a comment."""
|
||||
mock_get_shell_rc_path.return_value = self.test_rc_path
|
||||
with patch(
|
||||
"builtins.open", mock_open(read_data=f'# export CODEFLASH_API_KEY="{self.api_key}"\n')
|
||||
"builtins.open",
|
||||
mock_open(read_data=f'# export CODEFLASH_API_KEY="{self.api_key}"\n'),
|
||||
):
|
||||
self.assertIsNone(read_api_key_from_shell_config())
|
||||
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@ class TestUnittestRunnerSorter(unittest.TestCase):
|
|||
"""
|
||||
cur_dir_path = os.path.dirname(os.path.abspath(__file__))
|
||||
config = TestConfig(
|
||||
tests_root=cur_dir_path, project_root_path=cur_dir_path, test_framework="unittest"
|
||||
tests_root=cur_dir_path, project_root_path=cur_dir_path, test_framework="unittest",
|
||||
)
|
||||
|
||||
with tempfile.NamedTemporaryFile(prefix="test_xx", suffix=".py", dir=cur_dir_path) as fp:
|
||||
|
|
@ -56,7 +56,7 @@ def test_sort():
|
|||
"""
|
||||
cur_dir_path = os.path.dirname(os.path.abspath(__file__))
|
||||
config = TestConfig(
|
||||
tests_root=cur_dir_path, project_root_path=cur_dir_path, test_framework="pytest"
|
||||
tests_root=cur_dir_path, project_root_path=cur_dir_path, test_framework="pytest",
|
||||
)
|
||||
with tempfile.NamedTemporaryFile(prefix="test_xx", suffix=".py", dir=cur_dir_path) as fp:
|
||||
fp.write(code.encode("utf-8"))
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ def test_unit_test_discovery_pytest():
|
|||
project_path = pathlib.Path(__file__).parent.parent.resolve() / "code_to_optimize"
|
||||
tests_path = project_path / "tests" / "pytest"
|
||||
test_config = TestConfig(
|
||||
tests_root=str(tests_path), project_root_path=str(project_path), test_framework="pytest"
|
||||
tests_root=str(tests_path), project_root_path=str(project_path), test_framework="pytest",
|
||||
)
|
||||
tests = discover_unit_tests(test_config)
|
||||
assert len(tests) > 0
|
||||
|
|
@ -21,7 +21,7 @@ def test_unit_test_discovery_unittest():
|
|||
project_path = pathlib.Path(__file__).parent.parent.resolve() / "code_to_optimize"
|
||||
test_path = project_path / "tests" / "unittest"
|
||||
test_config = TestConfig(
|
||||
tests_root=str(project_path), project_root_path=str(project_path), test_framework="unittest"
|
||||
tests_root=str(project_path), project_root_path=str(project_path), test_framework="unittest",
|
||||
)
|
||||
os.chdir(str(project_path))
|
||||
tests = discover_unit_tests(test_config)
|
||||
|
|
@ -47,7 +47,7 @@ def test_discover_tests_pytest_with_temp_dir_root():
|
|||
|
||||
# Create a TestConfig with the temporary directory as the root
|
||||
test_config = TestConfig(
|
||||
tests_root=str(tmpdirname), project_root_path=str(tmpdirname), test_framework="pytest"
|
||||
tests_root=str(tmpdirname), project_root_path=str(tmpdirname), test_framework="pytest",
|
||||
)
|
||||
|
||||
# Discover tests
|
||||
|
|
@ -110,7 +110,7 @@ def test_discover_tests_pytest_with_multi_level_dirs():
|
|||
|
||||
# Create a TestConfig with the temporary directory as the root
|
||||
test_config = TestConfig(
|
||||
tests_root=str(tmpdirname), project_root_path=str(tmpdirname), test_framework="pytest"
|
||||
tests_root=str(tmpdirname), project_root_path=str(tmpdirname), test_framework="pytest",
|
||||
)
|
||||
|
||||
# Discover tests
|
||||
|
|
@ -120,10 +120,10 @@ def test_discover_tests_pytest_with_multi_level_dirs():
|
|||
assert len(discovered_tests) == 3
|
||||
assert discovered_tests["root_code.root_function"][0].test_file == str(root_test_file_path)
|
||||
assert discovered_tests["level1_code.level1_function"][0].test_file == str(
|
||||
level1_test_file_path
|
||||
level1_test_file_path,
|
||||
)
|
||||
assert discovered_tests["level2_code.level2_function"][0].test_file == str(
|
||||
level2_test_file_path
|
||||
level2_test_file_path,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -194,7 +194,7 @@ def test_discover_tests_pytest_dirs():
|
|||
|
||||
# Create a TestConfig with the temporary directory as the root
|
||||
test_config = TestConfig(
|
||||
tests_root=str(tmpdirname), project_root_path=str(tmpdirname), test_framework="pytest"
|
||||
tests_root=str(tmpdirname), project_root_path=str(tmpdirname), test_framework="pytest",
|
||||
)
|
||||
|
||||
# Discover tests
|
||||
|
|
@ -204,13 +204,13 @@ def test_discover_tests_pytest_dirs():
|
|||
assert len(discovered_tests) == 4
|
||||
assert discovered_tests["root_code.root_function"][0].test_file == str(root_test_file_path)
|
||||
assert discovered_tests["level1_code.level1_function"][0].test_file == str(
|
||||
level1_test_file_path
|
||||
level1_test_file_path,
|
||||
)
|
||||
assert discovered_tests["level2_code.level2_function"][0].test_file == str(
|
||||
level2_test_file_path
|
||||
level2_test_file_path,
|
||||
)
|
||||
assert discovered_tests["level3_code.level3_function"][0].test_file == str(
|
||||
level3_test_file_path
|
||||
level3_test_file_path,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -219,7 +219,7 @@ def test_discover_tests_pytest_with_class():
|
|||
# Create a code file with a class
|
||||
code_file_path = pathlib.Path(tmpdirname) / "some_class_code.py"
|
||||
code_file_content = (
|
||||
"class SomeClass:\n" " def some_method(self):\n" " return True\n"
|
||||
"class SomeClass:\n def some_method(self):\n return True\n"
|
||||
)
|
||||
code_file_path.write_text(code_file_content)
|
||||
|
||||
|
|
@ -235,7 +235,7 @@ def test_discover_tests_pytest_with_class():
|
|||
|
||||
# Create a TestConfig with the temporary directory as the root
|
||||
test_config = TestConfig(
|
||||
tests_root=str(tmpdirname), project_root_path=str(tmpdirname), test_framework="pytest"
|
||||
tests_root=str(tmpdirname), project_root_path=str(tmpdirname), test_framework="pytest",
|
||||
)
|
||||
|
||||
# Discover tests
|
||||
|
|
@ -244,7 +244,7 @@ def test_discover_tests_pytest_with_class():
|
|||
# Check if the test class and method are discovered
|
||||
assert len(discovered_tests) == 1
|
||||
assert discovered_tests["some_class_code.SomeClass.some_method"][0].test_file == str(
|
||||
test_file_path
|
||||
test_file_path,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -277,7 +277,7 @@ def test_discover_tests_with_code_in_dir_and_test_in_subdir():
|
|||
|
||||
# Create a TestConfig with the code directory as the root
|
||||
test_config = TestConfig(
|
||||
tests_root=str(test_subdir), project_root_path=str(tmpdirname), test_framework="pytest"
|
||||
tests_root=str(test_subdir), project_root_path=str(tmpdirname), test_framework="pytest",
|
||||
)
|
||||
|
||||
# Discover tests
|
||||
|
|
|
|||
Loading…
Reference in a new issue