Make cli and django all RUFF and TUFF

This commit is contained in:
Saurabh Misra 2024-04-17 19:41:00 -07:00
parent 87bbc8c238
commit cf88e2b7d0
30 changed files with 243 additions and 257 deletions

View file

@ -19,10 +19,9 @@ else:
def make_cfapi_request(
endpoint: str, method: str, payload: Optional[Dict[str, Any]] = None
endpoint: str, method: str, payload: Optional[Dict[str, Any]] = None,
) -> requests.Response:
"""
Make an HTTP request using the specified method, URL, headers, and JSON payload.
"""Make an HTTP request using the specified method, URL, headers, and JSON payload.
:param endpoint: The endpoint URL to send the request to.
:param method: The HTTP method to use ('GET', 'POST', etc.).
:param payload: Optional JSON payload to include in the POST request body.
@ -41,8 +40,7 @@ def make_cfapi_request(
@lru_cache(maxsize=1)
def get_user_id() -> Optional[str]:
"""
Retrieve the user's userid by making a request to the /cfapi/cli-get-user endpoint.
"""Retrieve the user's userid by making a request to the /cfapi/cli-get-user endpoint.
:return: The userid or None if the request fails.
"""
response = make_cfapi_request(endpoint="/cli-get-user", method="GET")
@ -50,7 +48,7 @@ def get_user_id() -> Optional[str]:
return response.text
else:
logging.error(
f"Failed to look up your userid; is your CF API key valid? ({response.reason})"
f"Failed to look up your userid; is your CF API key valid? ({response.reason})",
)
return None
@ -64,8 +62,7 @@ def suggest_changes(
existing_tests: str,
generated_tests: str,
) -> Response:
"""
Suggest changes to a pull request.
"""Suggest changes to a pull request.
Will make a review suggestion when possible;
or create a new dependent pull request with the suggested changes.
:param owner: The owner of the repository.
@ -98,8 +95,7 @@ def create_pr(
existing_tests: str,
generated_tests: str,
) -> Response:
"""
Create a pull request, targeting the specified branch. (usually 'main')
"""Create a pull request, targeting the specified branch. (usually 'main')
:param owner: The owner of the repository.
:param repo: The name of the repository.
:param base_branch: The base branch to target.
@ -122,8 +118,7 @@ def create_pr(
def check_github_app_installed_on_repo(owner: str, repo: str) -> Response:
"""
Check if the Codeflash GitHub App is installed on the specified repository.
"""Check if the Codeflash GitHub App is installed on the specified repository.
:param owner: The owner of the repository.
:param repo: The name of the repository.
:return: The response object.

View file

@ -1,18 +1,18 @@
import logging
import os
from argparse import Namespace, ArgumentParser, SUPPRESS
from argparse import SUPPRESS, ArgumentParser, Namespace
import git
from codeflash.api.cfapi import check_github_app_installed_on_repo
from codeflash.cli_cmds import logging_config
from codeflash.cli_cmds.cmd_init import init_codeflash, apologize_and_exit
from codeflash.cli_cmds.cmd_init import apologize_and_exit, init_codeflash
from codeflash.code_utils import env_utils
from codeflash.code_utils.compat import LF
from codeflash.code_utils.config_parser import parse_config_file
from codeflash.code_utils.git_utils import (
get_repo_owner_and_name,
get_github_secrets_page_url,
get_repo_owner_and_name,
)
from codeflash.version import __version__ as version
@ -87,7 +87,7 @@ def process_cmd_args(args: Namespace) -> Namespace:
try:
pyproject_config, pyproject_file_path = parse_config_file(args.config_file)
except ValueError as e:
logging.error(e.args[0])
logging.exception(e.args[0])
exit(1)
supported_keys = [
"module_root",
@ -108,10 +108,10 @@ def process_cmd_args(args: Namespace) -> Namespace:
) or not hasattr(args, key.replace("-", "_")):
setattr(args, key.replace("-", "_"), pyproject_config[key])
assert args.module_root is not None and os.path.isdir(
args.module_root
args.module_root,
), f"--module-root {args.module_root} must be a valid directory"
assert args.tests_root is not None and os.path.isdir(
args.tests_root
args.tests_root,
), f"--tests-root {args.tests_root} must be a valid directory"
assert not (
env_utils.get_pr_number() is not None and not env_utils.ensure_codeflash_api_key()
@ -127,7 +127,7 @@ def process_cmd_args(args: Namespace) -> Namespace:
normalized_ignore_paths = []
for path in args.ignore_paths:
assert os.path.exists(
path
path,
), f"ignore-paths config must be a valid path. Path {path} does not exist"
normalized_ignore_paths.append(os.path.realpath(path))
args.ignore_paths = normalized_ignore_paths
@ -150,9 +150,9 @@ def handle_optimize_all_arg_parsing(args: Namespace) -> Namespace:
try:
git_repo = git.Repo(search_parent_directories=True)
except git.exc.InvalidGitRepositoryError:
logging.error(
logging.exception(
"I couldn't find a git repository in the current directory. "
"I need a git repository to run --all and open PRs for optimizations. Exiting..."
"I need a git repository to run --all and open PRs for optimizations. Exiting...",
)
apologize_and_exit()
owner, repo = get_repo_owner_and_name(git_repo)
@ -161,16 +161,16 @@ def handle_optimize_all_arg_parsing(args: Namespace) -> Namespace:
if not response.ok or response.text != "true":
logging.error(f"Error: {response.text}")
raise Exception
except Exception as e:
logging.error(
except Exception:
logging.exception(
f"Could not find the Codeflash GitHub App installed on the repository {owner}/{repo} or the GitHub"
f" account linked to your CODEFLASH_API_KEY does not have access to the repository {owner}/{repo}.{LF}"
"Please install the Codeflash GitHub App on your repository to use --all. You can install it by going to "
f"https://github.com/settings/installations/{LF}"
f"https://github.com/settings/installations/{LF}",
)
apologize_and_exit()
if not hasattr(args, "all"):
setattr(args, "all", None)
args.all = None
elif args.all == "":
# The default behavior of --all is to optimize everything in args.module_root
args.all = args.module_root

View file

@ -4,7 +4,7 @@ import re
import subprocess
import sys
import time
from typing import Optional, NoReturn
from typing import NoReturn, Optional
import click
import inquirer
@ -20,8 +20,8 @@ from codeflash.code_utils.env_utils import (
)
from codeflash.code_utils.git_utils import get_github_secrets_page_url
from codeflash.code_utils.shell_utils import (
save_api_key_to_rc,
get_shell_rc_path,
save_api_key_to_rc,
)
from codeflash.telemetry.posthog import ph
from codeflash.version import __version__ as version
@ -69,11 +69,11 @@ def init_codeflash() -> None:
f" codeflash --all to optimize all functions in all files in the module you selected ({setup_info.module_root}){LF}"
# f" codeflash --pr <pr-number> to optimize a PR{LF}"
f"-or-{LF}"
f" codeflash --help to see all options{LF}"
f" codeflash --help to see all options{LF}",
)
if did_add_new_key:
click.echo(
"🐚 Don't forget to restart your shell to load the CODEFLASH_API_KEY environment variable!"
"🐚 Don't forget to restart your shell to load the CODEFLASH_API_KEY environment variable!",
)
click.echo("Or run the following command to reload:")
if os.name == "nt":
@ -101,7 +101,7 @@ def collect_setup_info() -> SetupInfo:
# Check if the cwd is writable
if not os.access(curdir, os.W_OK):
click.echo(
f"❌ The current directory isn't writable, please check your folder permissions and try again.{LF}"
f"❌ The current directory isn't writable, please check your folder permissions and try again.{LF}",
)
click.echo("It's likely you don't have write permissions for this folder.")
sys.exit(1)
@ -132,7 +132,7 @@ def collect_setup_info() -> SetupInfo:
module_subdir_options = valid_module_subdirs + [curdir_option]
module_root_answer = inquirer.list_input(
message=f"Which Python module do you want me to optimize going forward? (Usually the top-most directory with all of your Python source code)",
message="Which Python module do you want me to optimize going forward? (Usually the top-most directory with all of your Python source code)",
choices=module_subdir_options,
default=(
project_name if project_name in module_subdir_options else module_subdir_options[0]
@ -172,7 +172,7 @@ def collect_setup_info() -> SetupInfo:
exists=True,
normalize_to_absolute_path=True,
),
]
],
)
tests_root = (
custom_tests_root_answer["path"] if custom_tests_root_answer else apologize_and_exit()
@ -224,7 +224,7 @@ def detect_test_framework(curdir, tests_root) -> Optional[str]:
for pytest_file in pytest_files:
file_path = os.path.join(curdir, pytest_file)
if os.path.exists(file_path):
with open(file_path, "r", encoding="utf8") as file:
with open(file_path, encoding="utf8") as file:
contents = file.read()
if pytest_config_patterns[pytest_file] in contents:
test_framework = "pytest"
@ -234,7 +234,7 @@ def detect_test_framework(curdir, tests_root) -> Optional[str]:
# Check if any python files contain a class that inherits from unittest.TestCase
for filename in os.listdir(tests_root):
if filename.endswith(".py"):
with open(os.path.join(tests_root, filename), "r", encoding="utf8") as file:
with open(os.path.join(tests_root, filename), encoding="utf8") as file:
contents = file.read()
try:
node = ast.parse(contents)
@ -265,38 +265,38 @@ def check_for_toml_or_setup_file() -> Optional[str]:
project_name = None
if os.path.exists(pyproject_toml_path):
try:
with open(pyproject_toml_path, "r", encoding="utf8") as f:
with open(pyproject_toml_path, encoding="utf8") as f:
pyproject_toml_content = f.read()
project_name = tomlkit.parse(pyproject_toml_content)["tool"]["poetry"]["name"]
click.echo(f"✅ I found a pyproject.toml for your project {project_name}.")
ph("cli-pyproject-toml-found-name")
except Exception as e:
click.echo(f"✅ I found a pyproject.toml for your project.")
except Exception:
click.echo("✅ I found a pyproject.toml for your project.")
ph("cli-pyproject-toml-found")
else:
if os.path.exists(setup_py_path):
with open(setup_py_path, "r", encoding="utf8") as f:
with open(setup_py_path, encoding="utf8") as f:
setup_py_content = f.read()
project_name_match = re.search(
r"setup\s*\([^)]*?name\s*=\s*['\"](.*?)['\"]", setup_py_content, re.DOTALL
r"setup\s*\([^)]*?name\s*=\s*['\"](.*?)['\"]", setup_py_content, re.DOTALL,
)
if project_name_match:
project_name = project_name_match.group(1)
click.echo(f"✅ Found setup.py for your project {project_name}")
ph("cli-setup-py-found-name")
else:
click.echo(f"✅ Found setup.py.")
click.echo("✅ Found setup.py.")
ph("cli-setup-py-found")
click.echo(
f"💡 I couldn't find a pyproject.toml in the current directory ({curdir}).{LF}"
f"(make sure you're running `codeflash init` from your project's root directory!){LF}"
f"I need this file to store my configuration settings."
f"I need this file to store my configuration settings.",
)
ph("cli-no-pyproject-toml-or-setup-py")
# Create a pyproject.toml file because it doesn't exist
create_toml = inquirer.confirm(
f"Do you want me to create a pyproject.toml file in the current directory?",
"Do you want me to create a pyproject.toml file in the current directory?",
default=True,
show_default=False,
)
@ -314,9 +314,9 @@ def check_for_toml_or_setup_file() -> Optional[str]:
click.echo(f"✅ Created a pyproject.toml file at {pyproject_toml_path}")
click.pause()
ph("cli-created-pyproject-toml")
except IOError as e:
except OSError:
click.echo(
"❌ Failed to create pyproject.toml. Please check your disk permissions and available space."
"❌ Failed to create pyproject.toml. Please check your disk permissions and available space.",
)
apologize_and_exit()
else:
@ -328,7 +328,7 @@ def check_for_toml_or_setup_file() -> Optional[str]:
def apologize_and_exit() -> NoReturn:
click.echo(
"💡 If you're having trouble, see https://app.codeflash.ai/app/getting-started for further help getting started with Codeflash!"
"💡 If you're having trouble, see https://app.codeflash.ai/app/getting-started for further help getting started with Codeflash!",
)
click.echo("👋 Exiting...")
sys.exit(1)
@ -363,10 +363,10 @@ def prompt_github_action(setup_info: SetupInfo) -> None:
python_version_string = f" {py_version.major}.{py_version.minor}"
optimize_yml_content = read_text(
"codeflash.cli_cmds.workflows", "codeflash-optimize.yaml"
"codeflash.cli_cmds.workflows", "codeflash-optimize.yaml",
)
optimize_yml_content = optimize_yml_content.replace(
" {{ python_version }}", python_version_string
" {{ python_version }}", python_version_string,
)
with open(optimize_yaml_path, "w", encoding="utf8") as optimize_yml_file:
optimize_yml_file.write(optimize_yml_content)
@ -400,12 +400,12 @@ def prompt_github_action(setup_info: SetupInfo) -> None:
click.launch(optimize_yaml_path)
click.echo(
"📝 I opened the workflow file in your editor! You'll need to edit the steps that install the right Python "
+ f"version and any project dependencies. See the comments in the file for more details.{LF}"
+ f"version and any project dependencies. See the comments in the file for more details.{LF}",
)
click.pause()
click.echo()
click.echo(
f"🚀 Codeflash is now configured to automatically optimize new Github PRs!{LF}"
f"🚀 Codeflash is now configured to automatically optimize new Github PRs!{LF}",
)
ph("cli-github-workflow-created")
else:
@ -417,18 +417,18 @@ def prompt_github_action(setup_info: SetupInfo) -> None:
def configure_pyproject_toml(setup_info: SetupInfo) -> None:
toml_path = os.path.join(os.getcwd(), "pyproject.toml")
try:
with open(toml_path, "r", encoding="utf8") as pyproject_file:
with open(toml_path, encoding="utf8") as pyproject_file:
pyproject_data = tomlkit.parse(pyproject_file.read())
except FileNotFoundError:
click.echo(
f"I couldn't find a pyproject.toml in the current directory.{LF}"
f"Please create a new empty pyproject.toml file here, OR if you use poetry then run `poetry init`, OR run `codeflash init` again from a directory with an existing pyproject.toml file."
f"Please create a new empty pyproject.toml file here, OR if you use poetry then run `poetry init`, OR run `codeflash init` again from a directory with an existing pyproject.toml file.",
)
apologize_and_exit()
codeflash_section = tomlkit.table()
codeflash_section.add(
tomlkit.comment("All paths are relative to this pyproject.toml's directory.")
tomlkit.comment("All paths are relative to this pyproject.toml's directory."),
)
codeflash_section["module-root"] = setup_info.module_root
codeflash_section["tests-root"] = setup_info.tests_root
@ -440,7 +440,7 @@ def configure_pyproject_toml(setup_info: SetupInfo) -> None:
tool_section["codeflash"] = codeflash_section
pyproject_data["tool"] = tool_section
click.echo(f"Writing Codeflash configuration ...\r", nl=False)
click.echo("Writing Codeflash configuration ...\r", nl=False)
with open(toml_path, "w", encoding="utf8") as pyproject_file:
pyproject_file.write(tomlkit.dumps(pyproject_data))
click.echo(f"✅ Added Codeflash configuration to {toml_path}")
@ -466,7 +466,7 @@ class CFAPIKeyType(click.ParamType):
def prompt_api_key() -> bool:
try:
existing_api_key = get_codeflash_api_key()
except EnvironmentError:
except OSError:
existing_api_key = None
if existing_api_key:
display_key = f"{existing_api_key[:3]}****{existing_api_key[-4:]}"
@ -499,14 +499,13 @@ def enter_api_key_and_save_to_rc() -> None:
).strip()
if api_key:
break
else:
if not browser_launched:
click.echo(
f"Opening your Codeflash API key page. Grab a key from there!{LF}"
"You can also open this link manually: https://app.codeflash.ai/app/apikeys"
)
click.launch("https://app.codeflash.ai/app/apikeys")
browser_launched = True # This does not work on remote consoles
elif not browser_launched:
click.echo(
f"Opening your Codeflash API key page. Grab a key from there!{LF}"
"You can also open this link manually: https://app.codeflash.ai/app/apikeys",
)
click.launch("https://app.codeflash.ai/app/apikeys")
browser_launched = True # This does not work on remote consoles
shell_rc_path = get_shell_rc_path()
if not shell_rc_path.exists() and os.name == "nt":
# On Windows, create a batch file in the user's home directory (not auto-run, just used to store api key)
@ -581,5 +580,5 @@ def run_end_to_end_test(setup_info: SetupInfo) -> None:
click.echo(f"{LF}✅ End-to-end test passed. Codeflash has been correctly set up!")
else:
click.echo(
f"{LF}❌ End-to-end test failed. Please check the logs above, and take a look at https://app.codeflash.ai/app/getting-started for help and troubleshooting."
f"{LF}❌ End-to-end test failed. Please check the logs above, and take a look at https://app.codeflash.ai/app/getting-started for help and troubleshooting.",
)

View file

@ -11,25 +11,25 @@ def find_pyproject_toml(config_file=None):
if config_file is not None:
if not config_file.lower().endswith(".toml"):
raise ValueError(
f"Config file {config_file} is not a valid toml file. Please recheck the path to pyproject.toml"
f"Config file {config_file} is not a valid toml file. Please recheck the path to pyproject.toml",
)
if not os.path.exists(config_file):
raise ValueError(
f"Config file {config_file} does not exist. Please recheck the path to pyproject.toml"
f"Config file {config_file} does not exist. Please recheck the path to pyproject.toml",
)
return config_file
else:
dir_path = os.getcwd()
while not os.path.dirname(dir_path) == dir_path:
while os.path.dirname(dir_path) != dir_path:
config_file = os.path.join(dir_path, "pyproject.toml")
if os.path.exists(config_file):
return config_file
# Search for pyproject.toml in the parent directories
dir_path = os.path.dirname(dir_path)
raise ValueError(
f"Could not find pyproject.toml in the current directory {os.getcwd()} or any of the parent directories. Please create it by running `poetry init`, or pass the path to pyproject.toml with the --config-file argument."
f"Could not find pyproject.toml in the current directory {os.getcwd()} or any of the parent directories. Please create it by running `poetry init`, or pass the path to pyproject.toml with the --config-file argument.",
)
@ -40,7 +40,7 @@ def parse_config_file(config_file_path=None):
data = tomlkit.parse(f.read())
except tomlkit.exceptions.ParseError as e:
raise ValueError(
f"Error while parsing the config file {config_file_path}. Please recheck the file for syntax errors. Error: {e}"
f"Error while parsing the config file {config_file_path}. Please recheck the file for syntax errors. Error: {e}",
)
try:
@ -50,7 +50,7 @@ def parse_config_file(config_file_path=None):
except tomlkit.exceptions.NonExistentKey:
raise ValueError(
f"Could not find the 'codeflash' block in the config file {config_file_path}. "
f"Please run 'codeflash init' to create the config file."
f"Please run 'codeflash init' to create the config file.",
)
assert isinstance(config, dict)
@ -88,7 +88,7 @@ def parse_config_file(config_file_path=None):
for key in path_keys:
if key in config:
config[key] = os.path.realpath(
os.path.join(os.path.dirname(config_file_path), config[key])
os.path.join(os.path.dirname(config_file_path), config[key]),
)
for key in path_list_keys:

View file

@ -14,16 +14,16 @@ from codeflash.code_utils.shell_utils import read_api_key_from_shell_config
def get_codeflash_api_key() -> Optional[str]:
api_key = os.environ.get("CODEFLASH_API_KEY") or read_api_key_from_shell_config()
if not api_key:
raise EnvironmentError(
raise OSError(
"I didn't find a Codeflash API key in your environment.\n"
+ "You can generate one at https://app.codeflash.ai/app/apikeys,\n"
+ "then set it as a CODEFLASH_API_KEY environment variable."
+ "then set it as a CODEFLASH_API_KEY environment variable.",
)
if not api_key.startswith("cf-"):
raise EnvironmentError(
raise OSError(
f"Your Codeflash API key seems to be invalid. It should start with a 'cf-' prefix; I found '{api_key}' instead.\n"
+ "You can generate one at https://app.codeflash.ai/app/apikeys,\n"
+ "then set it as a CODEFLASH_API_KEY environment variable."
+ "then set it as a CODEFLASH_API_KEY environment variable.",
)
return api_key
@ -31,11 +31,11 @@ def get_codeflash_api_key() -> Optional[str]:
def ensure_codeflash_api_key() -> bool:
try:
get_codeflash_api_key()
except EnvironmentError:
logging.error(
except OSError:
logging.exception(
"Codeflash API key not found in your environment.\n"
+ "You can generate one at https://app.codeflash.ai/app/apikeys,\n"
+ "then set it as a CODEFLASH_API_KEY environment variable."
+ "then set it as a CODEFLASH_API_KEY environment variable.",
)
return False
return True
@ -58,8 +58,8 @@ def get_pr_number() -> Optional[int]:
def ensure_pr_number() -> bool:
if not get_pr_number():
raise EnvironmentError(
"CODEFLASH_PR_NUMBER not found in environment variables; make sure the Github Action is setting this so Codeflash can comment on the right PR"
raise OSError(
"CODEFLASH_PR_NUMBER not found in environment variables; make sure the Github Action is setting this so Codeflash can comment on the right PR",
)
return True

View file

@ -1,18 +1,19 @@
import logging
import os.path
import subprocess
import isort
def format_code(
formatter_cmd: str, imports_sort_cmd: str, should_sort_imports: bool, path: str
formatter_cmd: str, imports_sort_cmd: str, should_sort_imports: bool, path: str,
) -> str:
# TODO: Only allow a particular whitelist of formatters here to prevent arbitrary code execution
if formatter_cmd.lower() == "disabled":
if should_sort_imports:
return sort_imports(imports_sort_cmd, should_sort_imports, path)
with open(path, "r", encoding="utf8") as f:
with open(path, encoding="utf8") as f:
new_code = f.read()
return new_code
@ -22,7 +23,7 @@ def format_code(
# the main problem is custom config parsing https://github.com/psf/black/issues/779
assert os.path.exists(path), f"File {path} does not exist. Cannot format the file. Exiting..."
result = subprocess.run(
formatter_cmd_list + [path], stdout=subprocess.PIPE, stderr=subprocess.PIPE
formatter_cmd_list + [path], stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=False,
)
if result.returncode == 0:
logging.info("OK")
@ -32,7 +33,7 @@ def format_code(
if should_sort_imports:
return sort_imports(imports_sort_cmd, should_sort_imports, path)
with open(path, "r", encoding="utf8") as f:
with open(path, encoding="utf8") as f:
new_code = f.read()
return new_code
@ -40,7 +41,7 @@ def format_code(
def sort_imports(imports_sort_cmd: str, should_sort_imports: bool, path: str) -> str:
if imports_sort_cmd.lower() == "disabled":
with open(path, "r", encoding="utf8") as f:
with open(path, encoding="utf8") as f:
code = f.read()
return code
@ -48,11 +49,11 @@ def sort_imports(imports_sort_cmd: str, should_sort_imports: bool, path: str) ->
# Deduplicate and sort imports
isort.file(path)
with open(path, "r", encoding="utf8") as f:
with open(path, encoding="utf8") as f:
new_code = f.read()
return new_code
else:
# Return original code
with open(path, "r", encoding="utf8") as f:
with open(path, encoding="utf8") as f:
code = f.read()
return code

View file

@ -9,13 +9,13 @@ from unidiff import PatchSet
def get_git_diff(
repo_directory: str = os.getcwd(), uncommitted_changes: bool = False
repo_directory: str = os.getcwd(), uncommitted_changes: bool = False,
) -> dict[str, list[int]]:
repository = git.Repo(repo_directory, search_parent_directories=True)
commit = repository.head.commit
if uncommitted_changes:
uni_diff_text = repository.git.diff(
None, "HEAD", ignore_blank_lines=True, ignore_space_at_eol=True
None, "HEAD", ignore_blank_lines=True, ignore_space_at_eol=True,
)
else:
uni_diff_text = repository.git.diff(
@ -54,8 +54,7 @@ def get_git_diff(
def get_current_branch(repo: Optional[Repo] = None) -> str:
"""
Returns the name of the current branch in the given repository.
"""Returns the name of the current branch in the given repository.
:param repo: An optional Repo object. If not provided, the function will
search for a repository in the current and parent directories.

View file

@ -2,7 +2,7 @@ import ast
from _ast import ClassDef
from typing import Any, Optional, Tuple
from codeflash.code_utils.code_utils import module_name_from_file_path, get_run_tmp_file
from codeflash.code_utils.code_utils import get_run_tmp_file, module_name_from_file_path
class ReplaceCallNodeWithName(ast.NodeTransformer):
@ -26,7 +26,7 @@ class InjectPerfOnly(ast.NodeTransformer):
self.module_path = module_path
def update_line_node(
self, test_node, node_name, index: str, test_class_name: Optional[str] = None
self, test_node, node_name, index: str, test_class_name: Optional[str] = None,
):
call_node = None
for node in ast.walk(test_node):
@ -122,17 +122,17 @@ class InjectPerfOnly(ast.NodeTransformer):
ast.JoinedStr(
values=[
ast.Constant(
value=f"{get_run_tmp_file('test_return_values_')}"
value=f"{get_run_tmp_file('test_return_values_')}",
),
ast.FormattedValue(
value=ast.Name(
id="codeflash_iteration", ctx=ast.Load()
id="codeflash_iteration", ctx=ast.Load(),
),
conversion=-1,
),
ast.Constant(value=".sqlite"),
]
)
],
),
],
keywords=[],
),
@ -164,8 +164,8 @@ class InjectPerfOnly(ast.NodeTransformer):
ast.Constant(
value="CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT,"
" test_class_name TEXT, test_function_name TEXT, function_getting_tested TEXT,"
" iteration_id TEXT, runtime INTEGER, return_value BLOB)"
)
" iteration_id TEXT, runtime INTEGER, return_value BLOB)",
),
],
keywords=[],
),
@ -184,8 +184,8 @@ class InjectPerfOnly(ast.NodeTransformer):
),
args=[],
keywords=[],
)
)
),
),
]
)
i = len(node.body) - 1
@ -200,15 +200,14 @@ class InjectPerfOnly(ast.NodeTransformer):
for internal_node in ast.walk(compound_line_node):
if self.is_target_function_line(internal_node):
line_node.body[j : j + 1] = self.update_line_node(
internal_node, node.name, str(i) + "_" + str(j), test_class_name
internal_node, node.name, str(i) + "_" + str(j), test_class_name,
)
break
j -= 1
else:
if self.is_target_function_line(line_node):
node.body[i : i + 1] = self.update_line_node(
line_node, node.name, str(i), test_class_name
)
elif self.is_target_function_line(line_node):
node.body[i : i + 1] = self.update_line_node(
line_node, node.name, str(i), test_class_name,
)
i -= 1
return node
@ -216,7 +215,8 @@ class InjectPerfOnly(ast.NodeTransformer):
class FunctionImportedAsVisitor(ast.NodeVisitor):
"""This checks if a function has been imported as an alias. We only care about the alias then.
from numpy import array as np_array
np_array is what we want"""
np_array is what we want
"""
def __init__(self, original_function_name):
self.original_function_name = original_function_name
@ -226,12 +226,12 @@ class FunctionImportedAsVisitor(ast.NodeVisitor):
def visit_ImportFrom(self, node: ast.ImportFrom):
for alias in node.names:
if alias.name == self.original_function_name:
if hasattr(alias, "asname") and not alias.asname is None:
if hasattr(alias, "asname") and alias.asname is not None:
self.imported_as_function_name = alias.asname
def inject_profiling_into_existing_test(test_path, function_name, root_path) -> Tuple[bool, str]:
with open(test_path, "r", encoding="utf8") as f:
with open(test_path, encoding="utf8") as f:
test_code = f.read()
try:
tree = ast.parse(test_code)
@ -284,21 +284,21 @@ def create_wrapper_function(function_name, module_path):
value=ast.JoinedStr(
values=[
ast.FormattedValue(
value=ast.Name(id="test_module_name", ctx=ast.Load()), conversion=-1
value=ast.Name(id="test_module_name", ctx=ast.Load()), conversion=-1,
),
ast.Constant(value=":"),
ast.FormattedValue(
value=ast.Name(id="test_class_name", ctx=ast.Load()), conversion=-1
value=ast.Name(id="test_class_name", ctx=ast.Load()), conversion=-1,
),
ast.Constant(value=":"),
ast.FormattedValue(
value=ast.Name(id="test_name", ctx=ast.Load()), conversion=-1
value=ast.Name(id="test_name", ctx=ast.Load()), conversion=-1,
),
ast.Constant(value=":"),
ast.FormattedValue(
value=ast.Name(id="line_id", ctx=ast.Load()), conversion=-1
value=ast.Name(id="line_id", ctx=ast.Load()), conversion=-1,
),
]
],
),
lineno=lineno + 1,
),
@ -321,11 +321,11 @@ def create_wrapper_function(function_name, module_path):
value=ast.Name(id="codeflash_wrap", ctx=ast.Load()),
attr="index",
ctx=ast.Store(),
)
),
],
value=ast.Dict(keys=[], values=[]),
lineno=lineno + 3,
)
),
],
orelse=[],
lineno=lineno + 2,
@ -339,7 +339,7 @@ def create_wrapper_function(function_name, module_path):
value=ast.Name(id="codeflash_wrap", ctx=ast.Load()),
attr="index",
ctx=ast.Load(),
)
),
],
),
body=[
@ -356,7 +356,7 @@ def create_wrapper_function(function_name, module_path):
op=ast.Add(),
value=ast.Constant(value=1),
lineno=lineno + 5,
)
),
],
orelse=[
ast.Assign(
@ -369,11 +369,11 @@ def create_wrapper_function(function_name, module_path):
),
slice=ast.Name(id="test_id", ctx=ast.Load()),
ctx=ast.Store(),
)
),
],
value=ast.Constant(value=0),
lineno=lineno + 6,
)
),
],
lineno=lineno + 4,
),
@ -397,13 +397,13 @@ def create_wrapper_function(function_name, module_path):
value=ast.JoinedStr(
values=[
ast.FormattedValue(
value=ast.Name(id="line_id", ctx=ast.Load()), conversion=-1
value=ast.Name(id="line_id", ctx=ast.Load()), conversion=-1,
),
ast.Constant(value="_"),
ast.FormattedValue(
value=ast.Name(id="codeflash_test_index", ctx=ast.Load()), conversion=-1
value=ast.Name(id="codeflash_test_index", ctx=ast.Load()), conversion=-1,
),
]
],
),
lineno=lineno + 8,
),

View file

@ -9,16 +9,16 @@ from codeflash.code_utils.compat import LF
if os.name == "nt": # Windows
SHELL_RC_EXPORT_PATTERN = re.compile(r"^set CODEFLASH_API_KEY=(cf-.*)$", re.M)
SHELL_RC_EXPORT_PREFIX = f"set CODEFLASH_API_KEY="
SHELL_RC_EXPORT_PREFIX = "set CODEFLASH_API_KEY="
else:
SHELL_RC_EXPORT_PATTERN = re.compile(r'^export CODEFLASH_API_KEY="?(cf-[^\s"]+)"?$', re.M)
SHELL_RC_EXPORT_PREFIX = f"export CODEFLASH_API_KEY="
SHELL_RC_EXPORT_PREFIX = "export CODEFLASH_API_KEY="
def read_api_key_from_shell_config() -> Optional[str]:
try:
shell_rc_path = get_shell_rc_path()
with open(shell_rc_path, "r", encoding="utf8") as shell_rc:
with open(shell_rc_path, encoding="utf8") as shell_rc:
shell_contents = shell_rc.read()
matches = SHELL_RC_EXPORT_PATTERN.findall(shell_contents)
return matches[-1] if matches else None
@ -39,7 +39,7 @@ def get_shell_rc_path() -> Path:
"tcsh": ".cshrc",
"dash": ".profile",
}.get(
shell, ".bashrc"
shell, ".bashrc",
) # map each shell to its config file and default to .bashrc
return Path.home() / shell_rc_filename
@ -62,7 +62,7 @@ def save_api_key_to_rc(api_key) -> Result[str, str]:
if existing_api_key:
# Replace the existing API key line
updated_shell_contents = re.sub(
SHELL_RC_EXPORT_PATTERN, api_key_line, shell_contents
SHELL_RC_EXPORT_PATTERN, api_key_line, shell_contents,
)
action = "Updated CODEFLASH_API_KEY in"
else:
@ -77,7 +77,7 @@ def save_api_key_to_rc(api_key) -> Result[str, str]:
except PermissionError:
return Failure(
f"💡 I tried adding your Codeflash API key to {shell_rc_path} - but seems like I don't have permissions to do so.{LF}"
f"You'll need to open it yourself and add the following line:{LF}{LF}{api_key_line}{LF}"
f"You'll need to open it yourself and add the following line:{LF}{LF}{api_key_line}{LF}",
)
except FileNotFoundError:
return Failure(

View file

@ -30,7 +30,7 @@ class TestsInFile:
def from_pytest_stdout_line_co(cls, module: str, function: str, directory: str):
absolute_test_path = os.path.join(directory, module)
assert os.path.exists(
absolute_test_path
absolute_test_path,
), f"Test discovery failed - Test file does not exist {absolute_test_path}"
return cls(
test_file=absolute_test_path,
@ -44,7 +44,7 @@ class TestsInFile:
parts = line.split("::")
absolute_test_path = os.path.join(pytest_rootdir, parts[0])
assert os.path.exists(
absolute_test_path
absolute_test_path,
), f"Test discovery failed - Test file does not exist {absolute_test_path}"
if len(parts) == 3:
return cls(
@ -121,7 +121,7 @@ def discover_tests_pytest(cfg: TestConfig) -> Dict[str, List[TestsInFile]]:
pytest_result = subprocess.run(
pytest_cmd_list + [f"{tests_root}", "--co", "-q", "-m", "not skip"],
stdout=subprocess.PIPE,
cwd=project_root,
cwd=project_root, check=False,
)
pytest_stdout = pytest_result.stdout.decode("utf-8")
@ -185,7 +185,7 @@ def discover_tests_unittest(cfg: TestConfig) -> Dict[str, List[TestsInFile]]:
{
"test_function": details.test_function,
"test_suite_name": details.test_suite_name,
}
},
)
else:
details = get_test_details(test)
@ -194,7 +194,7 @@ def discover_tests_unittest(cfg: TestConfig) -> Dict[str, List[TestsInFile]]:
{
"test_function": details.test_function,
"test_suite_name": details.test_suite_name,
}
},
)
return process_test_files(file_to_test_map, cfg)
@ -208,7 +208,7 @@ def discover_parameters_unittest(function_name: str):
def process_test_files(
file_to_test_map: Dict[str, List[Dict[str, str]]], cfg: TestConfig
file_to_test_map: Dict[str, List[Dict[str, str]]], cfg: TestConfig,
) -> Dict[str, List[TestsInFile]]:
project_root_path = cfg.project_root_path
test_framework = cfg.test_framework
@ -231,10 +231,9 @@ def process_test_files(
parameters = re.split(r"\[|\]", function)[1]
if name.name == function_name and name.type == "function":
test_functions.add(TestFunction(name.name, None, parameters))
else:
if name.name == function and name.type == "function":
test_functions.add(TestFunction(name.name, None, None))
break
elif name.name == function and name.type == "function":
test_functions.add(TestFunction(name.name, None, None))
break
if test_framework == "unittest":
functions_to_search = [elem["test_function"] for elem in functions]
test_suites = [elem["test_suite_name"] for elem in functions]
@ -254,7 +253,7 @@ def process_test_files(
if is_parameterized and new_function == def_name.name:
test_functions.add(
TestFunction(def_name.name, name.name, parameters)
TestFunction(def_name.name, name.name, parameters),
)
elif function == def_name.name:
test_functions.add(TestFunction(def_name.name, name.name, None))
@ -282,7 +281,7 @@ def process_test_files(
follow_builtin_imports=False,
)
except Exception as e:
logging.error(str(e))
logging.exception(str(e))
continue
if definition and definition[0].type == "function":
definition_path = str(definition[0].module_path)
@ -298,7 +297,7 @@ def process_test_files(
scope_test_function += "_" + scope_parameters
function_to_test_map[definition[0].full_name].append(
TestsInFile(test_file, None, scope_test_function, scope_test_suite)
TestsInFile(test_file, None, scope_test_function, scope_test_suite),
)
deduped_function_to_test_map = {}
for function, tests in function_to_test_map.items():
@ -307,7 +306,7 @@ def process_test_files(
def parse_pytest_stdout(
pytest_stdout: str, pytest_rootdir: str, tests_root: str, parse_type: ParseType
pytest_stdout: str, pytest_rootdir: str, tests_root: str, parse_type: ParseType,
) -> List[TestsInFile]:
test_results = []
if parse_type == ParseType.Q:
@ -384,7 +383,7 @@ def parse_pytest_stdout(
function = function.group(1)
try:
test_result = TestsInFile.from_pytest_stdout_line_co(
module, function, directory
module, function, directory,
)
test_results.append(test_result)
except ValueError as e:

View file

@ -51,7 +51,7 @@ class FunctionVisitor(cst.CSTVisitor):
parents=list(reversed(ast_parents)),
starting_line=pos.start.line,
ending_line=pos.end.line,
)
),
)
@ -66,8 +66,8 @@ class FunctionWithReturnStatement(ast.NodeVisitor):
if function_has_return_statement(node):
self.functions.append(
FunctionToOptimize(
function_name=node.name, file_path=self.file_path, parents=self.ast_path[:]
)
function_name=node.name, file_path=self.file_path, parents=self.ast_path[:],
),
)
# Continue visiting the body of the function to find nested functions
self.generic_visit(node)
@ -145,14 +145,14 @@ def get_functions_to_optimize_by_file(
if found_function is None:
raise ValueError(
f"Function {only_function_name} not found in file {file} or"
f" the function does not have a 'return' statement."
f" the function does not have a 'return' statement.",
)
functions[file] = [found_function]
else:
logging.info("Finding all functions modified in the current git diff ...")
functions = get_functions_within_git_diff()
filtered_modified_functions, functions_count = filter_functions(
functions, test_cfg.tests_root, ignore_paths, project_root, module_root
functions, test_cfg.tests_root, ignore_paths, project_root, module_root,
)
logging.info("Found %d functions to optimize", functions_count)
return filtered_modified_functions, functions_count
@ -164,12 +164,12 @@ def get_functions_within_git_diff() -> Dict[str, List[FunctionToOptimize]]:
for path in modified_lines:
if not os.path.exists(path):
continue
with open(path, "r", encoding="utf8") as f:
with open(path, encoding="utf8") as f:
file_content = f.read()
try:
wrapper = cst.metadata.MetadataWrapper(cst.parse_module(file_content))
except Exception as e:
logging.error(e)
logging.exception(e)
continue
function_lines = FunctionVisitor(file_path=path)
wrapper.visit(function_lines)
@ -203,11 +203,11 @@ def get_all_files_and_functions(module_root_path: str) -> Dict[str, List[Functio
def find_all_functions_in_file(file_path: str) -> Dict[str, List[FunctionToOptimize]]:
functions: Dict[str, List[FunctionToOptimize]] = {}
with open(file_path, "r", encoding="utf8") as f:
with open(file_path, encoding="utf8") as f:
try:
ast_module = ast.parse(f.read())
except Exception as e:
logging.error(e)
logging.exception(e)
return functions
function_name_visitor = FunctionWithReturnStatement(file_path)
function_name_visitor.visit(ast_module)

View file

@ -1,9 +1,10 @@
from typing import Union
from codeflash.verification.test_results import TestResults
from pydantic import BaseModel
from pydantic.dataclasses import dataclass
from codeflash.verification.test_results import TestResults
@dataclass(frozen=True, config={"arbitrary_types_allowed": True})
class PrComment:

View file

@ -1,5 +1,5 @@
import logging
from typing import Dict, Any, Optional
from typing import Any, Dict, Optional
from posthog import Posthog
@ -10,8 +10,7 @@ _posthog = None
def initialize_posthog(enabled: bool) -> None:
"""
Enable or disable PostHog.
"""Enable or disable PostHog.
:param enabled: Whether to enable PostHog.
"""
if not enabled:
@ -26,8 +25,7 @@ def initialize_posthog(enabled: bool) -> None:
def ph(event: str, properties: Optional[Dict[str, Any]] = None) -> None:
"""
Log an event to PostHog.
"""Log an event to PostHog.
:param event: The name of the event.
:param properties: A dictionary of properties to attach to the event.
"""

View file

@ -6,7 +6,7 @@ from typing import Any, Dict, Generator, List, Tuple
def get_next_arg_and_return(
trace_file: str, function_name: str, num_to_get: int = 3
trace_file: str, function_name: str, num_to_get: int = 3,
) -> Generator[Tuple[Any, Any], None, None]:
db = sqlite3.connect(trace_file)
cur = db.cursor()
@ -27,14 +27,14 @@ def get_next_arg_and_return(
matched_arg_return[frame_address].append(val[7])
if len(matched_arg_return[frame_address]) > 1:
logging.warning(
f"Pre-existing call to the function {function_name} with same frame address."
f"Pre-existing call to the function {function_name} with same frame address.",
)
elif event_type == "return":
matched_arg_return[frame_address].append(val[6])
arg_return_length = len(matched_arg_return[frame_address])
if arg_return_length > 2:
logging.warning(
f"Pre-existing return to the function {function_name} with same frame address."
f"Pre-existing return to the function {function_name} with same frame address.",
)
elif arg_return_length == 1:
logging.warning(f"No call before return for {function_name}!")
@ -51,7 +51,7 @@ def get_function_alias(module: str, function_name: str) -> str:
def create_trace_replay_test(
trace_file: str, functions: List[Tuple[str, str]], test_framework: str = "pytest"
trace_file: str, functions: List[Tuple[str, str]], test_framework: str = "pytest",
) -> str:
assert test_framework in ["pytest", "unittest"]
@ -84,7 +84,7 @@ def _create_unittest_trace_replay_test(trace_file: str, functions: List[Tuple[st
return_val = pickle.loads(return_val_pkl)
ret = {function_name}(**args)
self.assertTrue(comparator(return_val, ret))
"""
""",
)
test_template = "\nclass TestTracedFunctions(unittest.TestCase):\n"
@ -111,7 +111,7 @@ def _create_pytest_trace_replay_test(trace_file: str, functions: List[Tuple[str,
return_val = pickle.loads(return_val_pkl)
ret = {function_name}(**args)
assert comparator(return_val, ret)
"""
""",
)
test_template = ""

View file

@ -15,8 +15,7 @@ from codeflash.verification.verification_utils import get_test_file_path
class Tracer:
"""
Use this class as a 'with' context manager to trace a function call,
"""Use this class as a 'with' context manager to trace a function call,
input arguments, and return value.
"""
@ -57,7 +56,7 @@ class Tracer:
# TODO: Check out if we need to export the function test name as well
cur.execute(
"CREATE TABLE events(type TEXT, function TEXT, filename TEXT, line_number INTEGER, "
"last_frame_address INTEGER, time_ns INTEGER, arg BLOB, locals BLOB)"
"last_frame_address INTEGER, time_ns INTEGER, arg BLOB, locals BLOB)",
)
sys.setprofile(self.trace_callback)
@ -80,23 +79,23 @@ class Tracer:
)
function_path = "_".join([func for _, func in module_function])
test_file_path = get_test_file_path(
self.config["tests_root"], function_path, test_type="replay"
self.config["tests_root"], function_path, test_type="replay",
)
with open(test_file_path, "w", encoding="utf8") as file:
file.write(replay_test)
logging.info(
f"Codeflash: Function Traced successfully and replay test created! Path - {test_file_path}"
f"Codeflash: Function Traced successfully and replay test created! Path - {test_file_path}",
)
def trace_callback(self, frame: Any, event: str, arg: Any) -> None:
if event not in ["call", "return"]:
return None
return
code = frame.f_code
if self.functions:
if code.co_name not in self.functions:
return None
return
if self.function_count[code.co_name] >= self.max_function_count:
return
self.function_count[code.co_name] += 1
@ -114,7 +113,7 @@ class Tracer:
project_root = os.path.realpath(os.path.join(self.config["module_root"], ".."))
self.function_modules[code.co_name] = module_name_from_file_path(
code.co_filename, project_root=project_root
code.co_filename, project_root=project_root,
)
cur = self.con.cursor()

View file

@ -16,7 +16,7 @@ def main():
version_replacement = r"\g<1>" + major_minor_version + r".x"
# Read the LICENSE file
with open("codeflash/LICENSE", "r", encoding="utf8") as file:
with open("codeflash/LICENSE", encoding="utf8") as file:
license_text = file.read()
# Replace the version in the LICENSE file

View file

@ -114,7 +114,7 @@ def comparator(orig: Any, new: Any) -> bool:
return orig.equals(new)
if HAS_PANDAS and isinstance(
orig, (pandas.CategoricalDtype, pandas.Interval, pandas.Period)
orig, (pandas.CategoricalDtype, pandas.Interval, pandas.Period),
):
return orig == new
@ -145,6 +145,6 @@ def comparator(orig: Any, new: Any) -> bool:
logging.warning(f"Unknown comparator input type: {type(orig)}")
return True
except RecursionError as e:
logging.error(f"RecursionError while comparing objects: {e}")
logging.exception(f"RecursionError while comparing objects: {e}")
sentry_sdk.capture_exception(e)
return False

View file

@ -11,21 +11,21 @@ import sentry_sdk
from junitparser.xunit2 import JUnitXml
from codeflash.code_utils.code_utils import (
module_name_from_file_path,
get_run_tmp_file,
module_name_from_file_path,
)
from codeflash.discovery.discover_unit_tests import discover_parameters_unittest
from codeflash.verification.test_results import (
TestResults,
FunctionTestInvocation,
TestType,
InvocationId,
TestResults,
TestType,
)
from codeflash.verification.verification_utils import TestConfig
def parse_test_return_values_bin(
file_location: str, test_framework: str, test_type: TestType, test_file_path: str
file_location: str, test_framework: str, test_type: TestType, test_file_path: str,
) -> TestResults:
test_results = TestResults()
if not os.path.exists(file_location):
@ -47,7 +47,7 @@ def parse_test_return_values_bin(
try:
test_pickle = pickle.loads(file.read(len_next))
except Exception as e:
logging.error(f"Failed to load pickle file. Exception: {e}")
logging.exception(f"Failed to load pickle file. Exception: {e}")
return test_results
len_next = file.read(4)
len_next = int.from_bytes(len_next, byteorder="big")
@ -65,7 +65,7 @@ def parse_test_return_values_bin(
test_framework=test_framework,
test_type=test_type,
return_value=test_pickle,
)
),
)
# Hardcoding the test result to True because the test did execute and we are only interested in the return values,
# the did_pass comes from the xml results file
@ -104,7 +104,7 @@ def parse_sqlite_test_results(
test_framework=test_config.test_framework,
test_type=test_type,
return_value=None,
)
),
)
# return_value is only None temporarily as this is only being used for the existing tests. This should generalize
# to read the return_value from the sqlite file as well.
@ -136,7 +136,7 @@ def parse_test_xml(
for testcase in suite:
class_name = testcase.classname
file_name = suite._elem.attrib.get(
"file"
"file",
) # file_path_from_module_name(generated_tests_path, test_config.project_root_path)
if (
file_name == f"unittest{os.sep}loader.py"
@ -145,10 +145,10 @@ def parse_test_xml(
and suite.tests == 1
):
# This means that the test failed to load, so we don't want to crash on it
logging.info(f"Test failed to load, skipping it.")
logging.info("Test failed to load, skipping it.")
if run_result is not None:
logging.info(
f"Test log - STDOUT : {run_result.stdout.decode()} \n STDERR : {run_result.stderr.decode()}"
f"Test log - STDOUT : {run_result.stdout.decode()} \n STDERR : {run_result.stderr.decode()}",
)
return test_results
file_name = test_py_file_path
@ -170,7 +170,7 @@ def parse_test_xml(
xml_file_contents = open(test_xml_file_path).read()
scope.set_extra("file", xml_file_contents)
sentry_sdk.capture_message(
f"testcase.name is None in parse_test_xml for testcase {repr(testcase)} in file {xml_file_contents}"
f"testcase.name is None in parse_test_xml for testcase {testcase!r} in file {xml_file_contents}",
)
continue
# Parse test timing
@ -193,19 +193,19 @@ def parse_test_xml(
did_pass=result,
test_type=test_type,
return_value=None,
)
),
)
if len(test_results) == 0:
logging.info(f"Test '{test_py_file_path}' failed to run, skipping it")
if run_result is not None:
logging.info(
f"Test log - STDOUT : {run_result.stdout.decode()} \n STDERR : {run_result.stderr.decode()}"
f"Test log - STDOUT : {run_result.stdout.decode()} \n STDERR : {run_result.stderr.decode()}",
)
return test_results
def merge_test_results(
xml_test_results: TestResults, bin_test_results: TestResults, test_framework: str
xml_test_results: TestResults, bin_test_results: TestResults, test_framework: str,
) -> TestResults:
merged_test_results = TestResults()
@ -225,7 +225,7 @@ def merge_test_results(
if test_framework == "unittest":
test_function_name = result.id.test_function_name
is_parameterized, new_test_function_name, _ = discover_parameters_unittest(
test_function_name
test_function_name,
)
if is_parameterized: # handle parameterized test
test_function_name = new_test_function_name
@ -266,7 +266,7 @@ def merge_test_results(
did_pass=xml_result.did_pass,
test_type=xml_result.test_type,
return_value=result_bin.return_value,
)
),
)
else:
for i in range(len(bin_results.test_results)):
@ -290,7 +290,7 @@ def merge_test_results(
did_pass=bin_result.did_pass,
test_type=bin_result.test_type,
return_value=bin_result.return_value,
)
),
)
return merged_test_results
@ -321,10 +321,10 @@ def parse_test_results(
test_file_path=test_py_path,
)
except AttributeError as e:
logging.error(e)
logging.exception(e)
test_results_bin_file = TestResults()
pathlib.Path(
get_run_tmp_file(f"test_return_values_{optimization_iteration}.bin")
get_run_tmp_file(f"test_return_values_{optimization_iteration}.bin"),
).unlink(missing_ok=True)
elif test_type == TestType.EXISTING_UNIT_TEST:
try:
@ -335,7 +335,7 @@ def parse_test_results(
test_config=test_config,
)
except AttributeError as e:
logging.error(e)
logging.exception(e)
test_results_bin_file = TestResults()
else:
raise ValueError(f"Invalid test type: {test_type}")
@ -343,13 +343,13 @@ def parse_test_results(
# We Probably want to remove deleting this file here later, because we want to preserve the reference to the
# pickle blob in the test_results
pathlib.Path(get_run_tmp_file(f"test_return_values_{optimization_iteration}.bin")).unlink(
missing_ok=True
missing_ok=True,
)
pathlib.Path(get_run_tmp_file(f"test_return_values_{optimization_iteration}.sqlite")).unlink(
missing_ok=True
missing_ok=True,
)
merged_results = merge_test_results(
test_results_xml, test_results_bin_file, test_config.test_framework
test_results_xml, test_results_bin_file, test_config.test_framework,
)
return merged_results

View file

@ -1,5 +1,5 @@
import subprocess
from typing import Tuple, Optional
from typing import Optional, Tuple
from codeflash.code_utils.code_utils import get_run_tmp_file
@ -28,7 +28,7 @@ def run_tests(
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
cwd=cwd,
env=test_env,
env=test_env, check=False,
)
elif test_framework == "unittest":
result_file_path = get_run_tmp_file("unittest_results.xml")
@ -40,7 +40,7 @@ def run_tests(
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
cwd=cwd,
env=test_env,
env=test_env, check=False,
)
else:
raise ValueError("Invalid test framework -- I only support Pytest and Unittest currently.")

View file

@ -5,7 +5,7 @@ from pydantic.dataclasses import dataclass
def get_test_file_path(
test_dir: str, function_name: str, iteration: int = 0, test_type: str = "unit"
test_dir: str, function_name: str, iteration: int = 0, test_type: str = "unit",
) -> str:
assert test_type in ["unit", "inspired", "replay"]
function_name = function_name.replace(".", "_")
@ -39,11 +39,9 @@ class ModifyInspiredTests(ast.NodeTransformer):
def visit_Import(self, node: ast.Import):
self.import_list.append(node)
return None
def visit_ImportFrom(self, node: ast.ImportFrom):
self.import_list.append(node)
return None
def visit_ClassDef(self, node: ast.ClassDef):
if self.test_framework != "unittest":

View file

@ -34,7 +34,7 @@ def generate_tests(
instrumented_test_source = module.CACHED_INSTRUMENTED_TESTS
path = get_run_tmp_file("").replace("\\", "\\\\") # Escape backslash for windows paths
instrumented_test_source = instrumented_test_source.replace(
"{codeflash_run_tmp_dir_client_side}", path
"{codeflash_run_tmp_dir_client_side}", path,
)
logging.info(f"Using cached tests from {module_path}.CACHED_TESTS")
else:
@ -56,7 +56,7 @@ def generate_tests(
generated_test_source, instrumented_test_source = response
path = get_run_tmp_file("").replace("\\", "\\\\") # Escape backslash for windows paths
instrumented_test_source = instrumented_test_source.replace(
"{codeflash_run_tmp_dir_client_side}", path
"{codeflash_run_tmp_dir_client_side}", path,
)
else:
logging.error(f"Failed to generate and instrument tests for {function_to_optimize.function_name}")
@ -77,7 +77,7 @@ def merge_unit_tests(unit_test_source: str, inspired_unit_tests: str, test_frame
inspired_unit_tests_ast = ast.parse(inspired_unit_tests)
unit_test_source_ast = ast.parse(unit_test_source)
except SyntaxError as e:
logging.error(f"Syntax error in code: {e}")
logging.exception(f"Syntax error in code: {e}")
return unit_test_source
import_list: list[ast.stmt] = list()
modified_ast = ModifyInspiredTests(import_list, test_framework).visit(inspired_unit_tests_ast)

View file

@ -60,7 +60,7 @@ def main():
stdout,
)
num_unit_tests = int(unit_test_search.group(1))
assert num_unit_tests > 0, f"Could not find existing unit tests"
assert num_unit_tests > 0, "Could not find existing unit tests"
if __name__ == "__main__":

View file

@ -2,6 +2,7 @@ import datetime
import decimal
import pytest
from codeflash.verification.comparator import comparator
from codeflash.verification.equivalence import compare_results
from codeflash.verification.test_results import (

View file

@ -1,11 +1,11 @@
from codeflash.code_utils.formatter import format_code, sort_imports
import os
import tempfile
from codeflash.code_utils.formatter import format_code, sort_imports
def test_remove_duplicate_imports():
"""
Test that duplicate imports are removed when should_sort_imports is True.
"""Test that duplicate imports are removed when should_sort_imports is True.
"""
with tempfile.NamedTemporaryFile(delete=False) as tmp:
tmp.write(b"import os\nimport os\n")
@ -17,8 +17,7 @@ def test_remove_duplicate_imports():
def test_remove_multiple_duplicate_imports():
"""
Test that multiple duplicate imports are removed when should_sort_imports is True.
"""Test that multiple duplicate imports are removed when should_sort_imports is True.
"""
with tempfile.NamedTemporaryFile(delete=False) as tmp:
tmp.write(b"import sys\nimport os\nimport sys\n")
@ -30,8 +29,7 @@ def test_remove_multiple_duplicate_imports():
def test_sorting_imports():
"""
Test that imports are sorted when should_sort_imports is True.
"""Test that imports are sorted when should_sort_imports is True.
"""
with tempfile.NamedTemporaryFile(delete=False) as tmp:
tmp.write(b"import sys\nimport unittest\nimport os\n")
@ -43,8 +41,7 @@ def test_sorting_imports():
def test_no_sorting_imports():
"""
Test that imports are not sorted when should_sort_imports is False.
"""Test that imports are not sorted when should_sort_imports is False.
"""
with tempfile.NamedTemporaryFile(delete=False) as tmp:
tmp.write(b"import sys\nimport unittest\nimport os\n")
@ -56,15 +53,14 @@ def test_no_sorting_imports():
def test_sort_imports_without_formatting():
"""
Test that imports are sorted when formatting is disabled and should_sort_imports is True.
"""Test that imports are sorted when formatting is disabled and should_sort_imports is True.
"""
with tempfile.NamedTemporaryFile(delete=False) as tmp:
tmp.write(b"import sys\nimport unittest\nimport os\n")
tmp_path = tmp.name
new_code = format_code(
formatter_cmd="disabled", imports_sort_cmd="isort", should_sort_imports=True, path=tmp_path
formatter_cmd="disabled", imports_sort_cmd="isort", should_sort_imports=True, path=tmp_path,
)
os.remove(tmp_path)
assert new_code == "import os\nimport sys\nimport unittest\n"

View file

@ -1,7 +1,7 @@
import pathlib
from dataclasses import dataclass
from codeflash.discovery.functions_to_optimize import FunctionToOptimize, FunctionParent
from codeflash.discovery.functions_to_optimize import FunctionParent, FunctionToOptimize
from codeflash.optimization.function_context import get_function_variables_definitions

View file

@ -15,7 +15,7 @@ def test_function_eligible_for_optimization() -> None:
f.write(function)
f.flush()
functions_found = find_all_functions_in_file(f.name)
assert "test_function_eligible_for_optimization" == functions_found[f.name][0].function_name
assert functions_found[f.name][0].function_name == "test_function_eligible_for_optimization"
# Has no return statement
function = """def test_function_not_eligible_for_optimization():

View file

@ -1,18 +1,18 @@
import ast
import os.path
import pathlib
import pytest
import sys
import tempfile
import pytest
from codeflash.code_utils.code_utils import get_run_tmp_file
from codeflash.code_utils.config_consts import INDIVIDUAL_TEST_TIMEOUT
from codeflash.code_utils.instrument_existing_tests import inject_profiling_into_existing_test
from codeflash.code_utils.instrument_existing_tests import InjectPerfOnly, inject_profiling_into_existing_test
from codeflash.verification.parse_test_output import parse_test_results
from codeflash.verification.test_results import TestType
from codeflash.verification.test_runner import run_tests
from codeflash.verification.verification_utils import TestConfig
from codeflash.code_utils.instrument_existing_tests import InjectPerfOnly
def test_perfinjector_bubble_sort():
@ -86,7 +86,7 @@ class TestPigLatin(unittest.TestCase):
f.write(code)
f.flush()
success, new_test = inject_profiling_into_existing_test(
f.name, "sorter", os.path.dirname(f.name)
f.name, "sorter", os.path.dirname(f.name),
)
assert success
assert new_test == expected.format(
@ -161,7 +161,7 @@ def test_prepare_image_for_yolo():
f.flush()
success, new_test = inject_profiling_into_existing_test(
f.name, "prepare_image_for_yolo", os.path.dirname(f.name)
f.name, "prepare_image_for_yolo", os.path.dirname(f.name),
)
assert success
assert new_test == expected.format(

View file

@ -2,13 +2,11 @@ import os
import unittest
from unittest.mock import mock_open, patch
from returns.result import Success, Failure
from codeflash.code_utils.shell_utils import save_api_key_to_rc, read_api_key_from_shell_config
from codeflash.code_utils.shell_utils import read_api_key_from_shell_config, save_api_key_to_rc
from returns.result import Failure, Success
class TestShellUtils(unittest.TestCase):
@patch(
"codeflash.code_utils.shell_utils.open",
new_callable=mock_open,
@ -20,7 +18,7 @@ class TestShellUtils(unittest.TestCase):
api_key = "cf-12345"
result = save_api_key_to_rc(api_key)
self.assertTrue(isinstance(result, Success))
mock_file.assert_called_with("/fake/path/.bashrc", "r", encoding="utf8")
mock_file.assert_called_with("/fake/path/.bashrc", encoding="utf8")
handle = mock_file()
handle.write.assert_called_once()
handle.truncate.assert_called_once()
@ -42,7 +40,6 @@ class TestShellUtils(unittest.TestCase):
# unit tests
class TestReadApiKeyFromShellConfig(unittest.TestCase):
def setUp(self):
"""Setup a temporary shell configuration file for testing."""
self.test_rc_path = "test_shell_rc"
@ -59,10 +56,11 @@ class TestReadApiKeyFromShellConfig(unittest.TestCase):
with patch("codeflash.code_utils.shell_utils.get_shell_rc_path") as mock_get_shell_rc_path:
mock_get_shell_rc_path.return_value = self.test_rc_path
with patch(
"builtins.open", mock_open(read_data=f'export CODEFLASH_API_KEY="{self.api_key}"\n')
"builtins.open",
mock_open(read_data=f'export CODEFLASH_API_KEY="{self.api_key}"\n'),
) as mock_file:
self.assertEqual(read_api_key_from_shell_config(), self.api_key)
mock_file.assert_called_once_with(self.test_rc_path, "r", encoding="utf8")
mock_file.assert_called_once_with(self.test_rc_path, encoding="utf8")
@patch("codeflash.code_utils.shell_utils.get_shell_rc_path")
def test_no_api_key(self, mock_get_shell_rc_path):
@ -70,7 +68,7 @@ class TestReadApiKeyFromShellConfig(unittest.TestCase):
mock_get_shell_rc_path.return_value = self.test_rc_path
with patch("builtins.open", mock_open(read_data="# No API key here\n")) as mock_file:
self.assertIsNone(read_api_key_from_shell_config())
mock_file.assert_called_once_with(self.test_rc_path, "r", encoding="utf8")
mock_file.assert_called_once_with(self.test_rc_path, encoding="utf8")
@patch("codeflash.code_utils.shell_utils.get_shell_rc_path")
def test_malformed_api_key_export(self, mock_get_shell_rc_path):
@ -83,7 +81,8 @@ class TestReadApiKeyFromShellConfig(unittest.TestCase):
result = read_api_key_from_shell_config()
self.assertIsNone(result)
with patch(
"builtins.open", mock_open(read_data=f"export CODEFLASH_API_KEY=sk-{self.api_key}\n")
"builtins.open",
mock_open(read_data=f"export CODEFLASH_API_KEY=sk-{self.api_key}\n"),
):
result = read_api_key_from_shell_config()
self.assertIsNone(result)
@ -95,7 +94,7 @@ class TestReadApiKeyFromShellConfig(unittest.TestCase):
with patch(
"builtins.open",
mock_open(
read_data=f'export CODEFLASH_API_KEY="cf-firstkey"\nexport CODEFLASH_API_KEY="{self.api_key}"\n'
read_data=f'export CODEFLASH_API_KEY="cf-firstkey"\nexport CODEFLASH_API_KEY="{self.api_key}"\n',
),
):
self.assertEqual(read_api_key_from_shell_config(), self.api_key)
@ -107,7 +106,7 @@ class TestReadApiKeyFromShellConfig(unittest.TestCase):
with patch(
"builtins.open",
mock_open(
read_data=f'# Setting API Key\nexport CODEFLASH_API_KEY="{self.api_key}"\n# Done\n'
read_data=f'# Setting API Key\nexport CODEFLASH_API_KEY="{self.api_key}"\n# Done\n',
),
):
self.assertEqual(read_api_key_from_shell_config(), self.api_key)
@ -117,7 +116,8 @@ class TestReadApiKeyFromShellConfig(unittest.TestCase):
"""Test with API key export in a comment."""
mock_get_shell_rc_path.return_value = self.test_rc_path
with patch(
"builtins.open", mock_open(read_data=f'# export CODEFLASH_API_KEY="{self.api_key}"\n')
"builtins.open",
mock_open(read_data=f'# export CODEFLASH_API_KEY="{self.api_key}"\n'),
):
self.assertIsNone(read_api_key_from_shell_config())

View file

@ -27,7 +27,7 @@ class TestUnittestRunnerSorter(unittest.TestCase):
"""
cur_dir_path = os.path.dirname(os.path.abspath(__file__))
config = TestConfig(
tests_root=cur_dir_path, project_root_path=cur_dir_path, test_framework="unittest"
tests_root=cur_dir_path, project_root_path=cur_dir_path, test_framework="unittest",
)
with tempfile.NamedTemporaryFile(prefix="test_xx", suffix=".py", dir=cur_dir_path) as fp:
@ -56,7 +56,7 @@ def test_sort():
"""
cur_dir_path = os.path.dirname(os.path.abspath(__file__))
config = TestConfig(
tests_root=cur_dir_path, project_root_path=cur_dir_path, test_framework="pytest"
tests_root=cur_dir_path, project_root_path=cur_dir_path, test_framework="pytest",
)
with tempfile.NamedTemporaryFile(prefix="test_xx", suffix=".py", dir=cur_dir_path) as fp:
fp.write(code.encode("utf-8"))

View file

@ -10,7 +10,7 @@ def test_unit_test_discovery_pytest():
project_path = pathlib.Path(__file__).parent.parent.resolve() / "code_to_optimize"
tests_path = project_path / "tests" / "pytest"
test_config = TestConfig(
tests_root=str(tests_path), project_root_path=str(project_path), test_framework="pytest"
tests_root=str(tests_path), project_root_path=str(project_path), test_framework="pytest",
)
tests = discover_unit_tests(test_config)
assert len(tests) > 0
@ -21,7 +21,7 @@ def test_unit_test_discovery_unittest():
project_path = pathlib.Path(__file__).parent.parent.resolve() / "code_to_optimize"
test_path = project_path / "tests" / "unittest"
test_config = TestConfig(
tests_root=str(project_path), project_root_path=str(project_path), test_framework="unittest"
tests_root=str(project_path), project_root_path=str(project_path), test_framework="unittest",
)
os.chdir(str(project_path))
tests = discover_unit_tests(test_config)
@ -47,7 +47,7 @@ def test_discover_tests_pytest_with_temp_dir_root():
# Create a TestConfig with the temporary directory as the root
test_config = TestConfig(
tests_root=str(tmpdirname), project_root_path=str(tmpdirname), test_framework="pytest"
tests_root=str(tmpdirname), project_root_path=str(tmpdirname), test_framework="pytest",
)
# Discover tests
@ -110,7 +110,7 @@ def test_discover_tests_pytest_with_multi_level_dirs():
# Create a TestConfig with the temporary directory as the root
test_config = TestConfig(
tests_root=str(tmpdirname), project_root_path=str(tmpdirname), test_framework="pytest"
tests_root=str(tmpdirname), project_root_path=str(tmpdirname), test_framework="pytest",
)
# Discover tests
@ -120,10 +120,10 @@ def test_discover_tests_pytest_with_multi_level_dirs():
assert len(discovered_tests) == 3
assert discovered_tests["root_code.root_function"][0].test_file == str(root_test_file_path)
assert discovered_tests["level1_code.level1_function"][0].test_file == str(
level1_test_file_path
level1_test_file_path,
)
assert discovered_tests["level2_code.level2_function"][0].test_file == str(
level2_test_file_path
level2_test_file_path,
)
@ -194,7 +194,7 @@ def test_discover_tests_pytest_dirs():
# Create a TestConfig with the temporary directory as the root
test_config = TestConfig(
tests_root=str(tmpdirname), project_root_path=str(tmpdirname), test_framework="pytest"
tests_root=str(tmpdirname), project_root_path=str(tmpdirname), test_framework="pytest",
)
# Discover tests
@ -204,13 +204,13 @@ def test_discover_tests_pytest_dirs():
assert len(discovered_tests) == 4
assert discovered_tests["root_code.root_function"][0].test_file == str(root_test_file_path)
assert discovered_tests["level1_code.level1_function"][0].test_file == str(
level1_test_file_path
level1_test_file_path,
)
assert discovered_tests["level2_code.level2_function"][0].test_file == str(
level2_test_file_path
level2_test_file_path,
)
assert discovered_tests["level3_code.level3_function"][0].test_file == str(
level3_test_file_path
level3_test_file_path,
)
@ -219,7 +219,7 @@ def test_discover_tests_pytest_with_class():
# Create a code file with a class
code_file_path = pathlib.Path(tmpdirname) / "some_class_code.py"
code_file_content = (
"class SomeClass:\n" " def some_method(self):\n" " return True\n"
"class SomeClass:\n def some_method(self):\n return True\n"
)
code_file_path.write_text(code_file_content)
@ -235,7 +235,7 @@ def test_discover_tests_pytest_with_class():
# Create a TestConfig with the temporary directory as the root
test_config = TestConfig(
tests_root=str(tmpdirname), project_root_path=str(tmpdirname), test_framework="pytest"
tests_root=str(tmpdirname), project_root_path=str(tmpdirname), test_framework="pytest",
)
# Discover tests
@ -244,7 +244,7 @@ def test_discover_tests_pytest_with_class():
# Check if the test class and method are discovered
assert len(discovered_tests) == 1
assert discovered_tests["some_class_code.SomeClass.some_method"][0].test_file == str(
test_file_path
test_file_path,
)
@ -277,7 +277,7 @@ def test_discover_tests_with_code_in_dir_and_test_in_subdir():
# Create a TestConfig with the code directory as the root
test_config = TestConfig(
tests_root=str(test_subdir), project_root_path=str(tmpdirname), test_framework="pytest"
tests_root=str(test_subdir), project_root_path=str(tmpdirname), test_framework="pytest",
)
# Discover tests