mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
round 1
This commit is contained in:
parent
83d2c0b385
commit
5cd94cdf64
18 changed files with 470 additions and 391 deletions
|
|
@ -1,13 +1,14 @@
|
|||
import logging
|
||||
import os
|
||||
import sys
|
||||
from argparse import SUPPRESS, ArgumentParser, Namespace
|
||||
from pathlib import Path
|
||||
|
||||
import git
|
||||
|
||||
from codeflash.cli_cmds import logging_config
|
||||
from codeflash.cli_cmds.cli_common import apologize_and_exit
|
||||
from codeflash.cli_cmds.cmd_init import init_codeflash, install_github_actions
|
||||
from codeflash.cli_cmds.console import logger
|
||||
from codeflash.code_utils import env_utils
|
||||
from codeflash.code_utils.config_parser import parse_config_file
|
||||
from codeflash.code_utils.git_utils import (
|
||||
|
|
@ -17,7 +18,6 @@ from codeflash.code_utils.git_utils import (
|
|||
get_repo_owner_and_name,
|
||||
)
|
||||
from codeflash.code_utils.github_utils import get_github_secrets_page_url, require_github_app_or_exit
|
||||
from codeflash.cli_cmds.console import logger
|
||||
from codeflash.version import __version__ as version
|
||||
|
||||
|
||||
|
|
@ -104,18 +104,18 @@ def process_and_validate_cmd_args(args: Namespace) -> Namespace:
|
|||
logger.error("If you specify a --function, you must specify the --file it is in")
|
||||
sys.exit(1)
|
||||
if args.file:
|
||||
if not os.path.exists(args.file):
|
||||
if not Path(args.file).exists():
|
||||
logger.error(f"File {args.file} does not exist")
|
||||
sys.exit(1)
|
||||
args.file = os.path.realpath(args.file)
|
||||
args.file = Path(args.file).resolve()
|
||||
if not args.no_pr:
|
||||
owner, repo = get_repo_owner_and_name()
|
||||
require_github_app_or_exit(owner, repo)
|
||||
if args.replay_test:
|
||||
if not os.path.isfile(args.replay_test):
|
||||
if not Path(args.replay_test).is_file():
|
||||
logger.error(f"Replay test file {args.replay_test} does not exist")
|
||||
sys.exit(1)
|
||||
args.replay_test = os.path.realpath(args.replay_test)
|
||||
args.replay_test = Path(args.replay_test).resolve()
|
||||
|
||||
return args
|
||||
|
||||
|
|
@ -142,11 +142,11 @@ def process_pyproject_config(args: Namespace) -> Namespace:
|
|||
or not hasattr(args, key.replace("-", "_"))
|
||||
):
|
||||
setattr(args, key.replace("-", "_"), pyproject_config[key])
|
||||
assert args.module_root is not None and os.path.isdir(
|
||||
args.module_root,
|
||||
assert (
|
||||
args.module_root is not None and Path(args.module_root).is_dir()
|
||||
), f"--module-root {args.module_root} must be a valid directory"
|
||||
assert args.tests_root is not None and os.path.isdir(
|
||||
args.tests_root,
|
||||
assert (
|
||||
args.tests_root is not None and Path(args.tests_root).is_dir()
|
||||
), f"--tests-root {args.tests_root} must be a valid directory"
|
||||
|
||||
assert not (env_utils.get_pr_number() is not None and not env_utils.ensure_codeflash_api_key()), (
|
||||
|
|
@ -160,24 +160,23 @@ def process_pyproject_config(args: Namespace) -> Namespace:
|
|||
if hasattr(args, "ignore_paths") and args.ignore_paths is not None:
|
||||
normalized_ignore_paths = []
|
||||
for path in args.ignore_paths:
|
||||
assert os.path.exists(
|
||||
path,
|
||||
), f"ignore-paths config must be a valid path. Path {path} does not exist"
|
||||
normalized_ignore_paths.append(os.path.realpath(path))
|
||||
path_obj = Path(path)
|
||||
assert path_obj.exists(), f"ignore-paths config must be a valid path. Path {path} does not exist"
|
||||
normalized_ignore_paths.append(path_obj.resolve())
|
||||
args.ignore_paths = normalized_ignore_paths
|
||||
# Project root path is one level above the specified directory, because that's where the module can be imported from
|
||||
args.module_root = os.path.realpath(args.module_root)
|
||||
args.module_root = Path(args.module_root).resolve()
|
||||
# If module-root is "." then all imports are relatives to it.
|
||||
# in this case, the ".." becomes outside project scope, causing issues with un-importable paths
|
||||
args.project_root = project_root_from_module_root(args.module_root, pyproject_file_path)
|
||||
args.tests_root = os.path.realpath(args.tests_root)
|
||||
args.tests_root = Path(args.tests_root).resolve()
|
||||
return handle_optimize_all_arg_parsing(args)
|
||||
|
||||
|
||||
def project_root_from_module_root(module_root: str, pyproject_file_path: str) -> str:
|
||||
if os.path.dirname(pyproject_file_path) == module_root:
|
||||
def project_root_from_module_root(module_root: Path, pyproject_file_path: Path) -> Path:
|
||||
if pyproject_file_path.parent == module_root:
|
||||
return module_root
|
||||
return os.path.realpath(os.path.join(module_root, ".."))
|
||||
return module_root.parent.resolve()
|
||||
|
||||
|
||||
def handle_optimize_all_arg_parsing(args: Namespace) -> Namespace:
|
||||
|
|
@ -203,5 +202,5 @@ def handle_optimize_all_arg_parsing(args: Namespace) -> Namespace:
|
|||
# The default behavior of --all is to optimize everything in args.module_root
|
||||
args.all = args.module_root
|
||||
else:
|
||||
args.all = os.path.realpath(args.all)
|
||||
args.all = Path(args.all).resolve()
|
||||
return args
|
||||
|
|
|
|||
|
|
@ -2,11 +2,11 @@ from __future__ import annotations
|
|||
|
||||
import ast
|
||||
import os
|
||||
import pathlib
|
||||
import re
|
||||
import subprocess
|
||||
import sys
|
||||
from argparse import Namespace
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import click
|
||||
|
|
@ -107,7 +107,7 @@ def ask_run_end_to_end_test(args: Namespace) -> None:
|
|||
|
||||
|
||||
def collect_setup_info() -> SetupInfo:
|
||||
curdir = os.getcwd()
|
||||
curdir = Path.cwd()
|
||||
# Check if the cwd is writable
|
||||
if not os.access(curdir, os.W_OK):
|
||||
click.echo(
|
||||
|
|
@ -170,20 +170,22 @@ def collect_setup_info() -> SetupInfo:
|
|||
)
|
||||
|
||||
if tests_root_answer == create_for_me_option:
|
||||
tests_root = os.path.join(curdir, default_tests_subdir)
|
||||
os.mkdir(tests_root)
|
||||
click.echo(f"✅ Created directory {tests_root}{os.pathsep}{LF}")
|
||||
tests_root = Path(curdir) / default_tests_subdir
|
||||
tests_root.mkdir()
|
||||
click.echo(f"✅ Created directory {tests_root}{os.path.sep}{LF}")
|
||||
elif tests_root_answer == custom_dir_option:
|
||||
custom_tests_root_answer = inquirer_wrapper_path(
|
||||
"path",
|
||||
message=f"Enter the path to your tests directory inside {os.path.abspath('.') + os.path.sep} ",
|
||||
message=f"Enter the path to your tests directory inside {Path(curdir).resolve()}{os.path.sep} ",
|
||||
path_type=inquirer.Path.DIRECTORY,
|
||||
exists=True,
|
||||
)
|
||||
tests_root = custom_tests_root_answer["path"] if custom_tests_root_answer else apologize_and_exit()
|
||||
tests_root = (
|
||||
Path(custom_tests_root_answer["path"]) if custom_tests_root_answer else apologize_and_exit()
|
||||
)
|
||||
else:
|
||||
tests_root = tests_root_answer
|
||||
tests_root = os.path.relpath(tests_root, curdir)
|
||||
tests_root = Path(tests_root_answer)
|
||||
tests_root = tests_root.relative_to(curdir)
|
||||
ph("cli-tests-root-provided")
|
||||
|
||||
# Autodiscover test framework
|
||||
|
|
@ -224,7 +226,7 @@ def collect_setup_info() -> SetupInfo:
|
|||
)
|
||||
|
||||
|
||||
def detect_test_framework(curdir, tests_root) -> Optional[str]:
|
||||
def detect_test_framework(curdir: Path, tests_root: Path) -> Optional[str]:
|
||||
test_framework = None
|
||||
pytest_files = ["pytest.ini", "pyproject.toml", "tox.ini", "setup.cfg"]
|
||||
pytest_config_patterns = {
|
||||
|
|
@ -234,9 +236,9 @@ def detect_test_framework(curdir, tests_root) -> Optional[str]:
|
|||
"setup.cfg": "[tool:pytest]",
|
||||
}
|
||||
for pytest_file in pytest_files:
|
||||
file_path = os.path.join(curdir, pytest_file)
|
||||
if os.path.exists(file_path):
|
||||
with open(file_path, encoding="utf8") as file:
|
||||
file_path = curdir / pytest_file
|
||||
if file_path.exists():
|
||||
with file_path.open(encoding="utf8") as file:
|
||||
contents = file.read()
|
||||
if pytest_config_patterns[pytest_file] in contents:
|
||||
test_framework = "pytest"
|
||||
|
|
@ -244,9 +246,9 @@ def detect_test_framework(curdir, tests_root) -> Optional[str]:
|
|||
test_framework = "pytest"
|
||||
else:
|
||||
# Check if any python files contain a class that inherits from unittest.TestCase
|
||||
for filename in os.listdir(tests_root):
|
||||
if filename.endswith(".py"):
|
||||
with open(os.path.join(tests_root, filename), encoding="utf8") as file:
|
||||
for filename in tests_root.iterdir():
|
||||
if filename.suffix == ".py":
|
||||
with filename.open(encoding="utf8") as file:
|
||||
contents = file.read()
|
||||
try:
|
||||
node = ast.parse(contents)
|
||||
|
|
@ -271,14 +273,13 @@ def detect_test_framework(curdir, tests_root) -> Optional[str]:
|
|||
def check_for_toml_or_setup_file() -> Optional[str]:
|
||||
click.echo()
|
||||
click.echo("Checking for pyproject.toml or setup.py ...\r", nl=False)
|
||||
curdir = os.getcwd()
|
||||
pyproject_toml_path = os.path.join(curdir, "pyproject.toml")
|
||||
setup_py_path = os.path.join(curdir, "setup.py")
|
||||
curdir = Path.cwd()
|
||||
pyproject_toml_path = curdir / "pyproject.toml"
|
||||
setup_py_path = curdir / "setup.py"
|
||||
project_name = None
|
||||
if os.path.exists(pyproject_toml_path):
|
||||
if pyproject_toml_path.exists():
|
||||
try:
|
||||
with open(pyproject_toml_path, encoding="utf8") as f:
|
||||
pyproject_toml_content = f.read()
|
||||
pyproject_toml_content = pyproject_toml_path.read_text(encoding="utf8")
|
||||
project_name = tomlkit.parse(pyproject_toml_content)["tool"]["poetry"]["name"]
|
||||
click.echo(f"✅ I found a pyproject.toml for your project {project_name}.")
|
||||
ph("cli-pyproject-toml-found-name")
|
||||
|
|
@ -286,9 +287,8 @@ def check_for_toml_or_setup_file() -> Optional[str]:
|
|||
click.echo("✅ I found a pyproject.toml for your project.")
|
||||
ph("cli-pyproject-toml-found")
|
||||
else:
|
||||
if os.path.exists(setup_py_path):
|
||||
with open(setup_py_path, encoding="utf8") as f:
|
||||
setup_py_content = f.read()
|
||||
if setup_py_path.exists():
|
||||
setup_py_content = setup_py_path.read_text(encoding="utf8")
|
||||
project_name_match = re.search(
|
||||
r"setup\s*\([^)]*?name\s*=\s*['\"](.*?)['\"]",
|
||||
setup_py_content,
|
||||
|
|
@ -321,11 +321,10 @@ def check_for_toml_or_setup_file() -> Optional[str]:
|
|||
new_pyproject_toml = tomlkit.document()
|
||||
new_pyproject_toml["tool"] = {"codeflash": {}}
|
||||
try:
|
||||
with open(pyproject_toml_path, "w", encoding="utf8") as pyproject_file:
|
||||
pyproject_file.write(tomlkit.dumps(new_pyproject_toml))
|
||||
pyproject_toml_path.write_text(tomlkit.dumps(new_pyproject_toml), encoding="utf8")
|
||||
|
||||
# Check if the pyproject.toml file was created
|
||||
if os.path.exists(pyproject_toml_path):
|
||||
if pyproject_toml_path.exists():
|
||||
click.echo(
|
||||
f"✅ Created a pyproject.toml file at {pyproject_toml_path}",
|
||||
)
|
||||
|
|
@ -356,9 +355,9 @@ def install_github_actions() -> None:
|
|||
owner, repo_name = get_repo_owner_and_name(repo)
|
||||
require_github_app_or_exit(owner, repo_name)
|
||||
|
||||
git_root = repo.git.rev_parse("--show-toplevel")
|
||||
workflows_path = os.path.join(git_root, ".github", "workflows")
|
||||
optimize_yaml_path = os.path.join(workflows_path, "codeflash-optimize.yaml")
|
||||
git_root = Path(repo.git.rev_parse("--show-toplevel"))
|
||||
workflows_path = git_root / ".github" / "workflows"
|
||||
optimize_yaml_path = workflows_path / "codeflash-optimize.yaml"
|
||||
|
||||
confirm_creation_yes = inquirer_wrapper(
|
||||
inquirer.confirm,
|
||||
|
|
@ -373,7 +372,7 @@ def install_github_actions() -> None:
|
|||
click.echo("⏩️ Exiting workflow creation.")
|
||||
ph("cli-github-workflow-skipped")
|
||||
apologize_and_exit()
|
||||
os.makedirs(workflows_path, exist_ok=True)
|
||||
workflows_path.mkdir(parents=True, exist_ok=True)
|
||||
from importlib.resources import files
|
||||
|
||||
py_version = sys.version_info
|
||||
|
|
@ -387,13 +386,13 @@ def install_github_actions() -> None:
|
|||
"{{ python_version }}",
|
||||
python_version_string,
|
||||
)
|
||||
with open(optimize_yaml_path, "w", encoding="utf8") as optimize_yml_file:
|
||||
with optimize_yaml_path.open("w", encoding="utf8") as optimize_yml_file:
|
||||
optimize_yml_file.write(optimize_yml_content)
|
||||
click.echo(f"✅ Created {optimize_yaml_path}{LF}")
|
||||
click.prompt(
|
||||
f"Next, you'll need to add your CODEFLASH_API_KEY as a secret to your GitHub repo.{LF}"
|
||||
+ f"Press Enter to open your repo's secrets page at {get_github_secrets_page_url(repo)} ...{LF}"
|
||||
+ f"Then, click 'New repository secret' to add your api key with the variable name CODEFLASH_API_KEY.{LF}",
|
||||
f"Press Enter to open your repo's secrets page at {get_github_secrets_page_url(repo)} ...{LF}"
|
||||
f"Then, click 'New repository secret' to add your api key with the variable name CODEFLASH_API_KEY.{LF}",
|
||||
default="",
|
||||
type=click.STRING,
|
||||
prompt_suffix="",
|
||||
|
|
@ -402,7 +401,7 @@ def install_github_actions() -> None:
|
|||
click.launch(get_github_secrets_page_url(repo))
|
||||
click.echo(
|
||||
"🐙 I opened your Github secrets page! Note: if you see a 404, you probably don't have access to this "
|
||||
+ "repo's secrets; ask a repo admin to add it for you, or (not super recommended) you can temporarily "
|
||||
"repo's secrets; ask a repo admin to add it for you, or (not super recommended) you can temporarily "
|
||||
f"hard-code your api key into the workflow file.{LF}",
|
||||
)
|
||||
click.pause()
|
||||
|
|
@ -410,7 +409,7 @@ def install_github_actions() -> None:
|
|||
click.prompt(
|
||||
f"Finally, for the workflow to work, you'll need to edit the workflow file to install the right "
|
||||
f"Python version and any project dependencies.{LF}"
|
||||
+ f"Press Enter to open {optimize_yaml_path} in your editor.{LF}",
|
||||
f"Press Enter to open {optimize_yaml_path} in your editor.{LF}",
|
||||
default="",
|
||||
type=click.STRING,
|
||||
prompt_suffix="",
|
||||
|
|
@ -419,7 +418,7 @@ def install_github_actions() -> None:
|
|||
click.launch(optimize_yaml_path)
|
||||
click.echo(
|
||||
"📝 I opened the workflow file in your editor! You'll need to edit the steps that install the right Python "
|
||||
+ f"version and any project dependencies. See the comments in the file for more details.{LF}",
|
||||
f"version and any project dependencies. See the comments in the file for more details.{LF}",
|
||||
)
|
||||
click.pause()
|
||||
click.echo()
|
||||
|
|
@ -434,9 +433,9 @@ def install_github_actions() -> None:
|
|||
|
||||
# Create or update the pyproject.toml file with the Codeflash dependency & configuration
|
||||
def configure_pyproject_toml(setup_info: SetupInfo) -> None:
|
||||
toml_path = os.path.join(os.getcwd(), "pyproject.toml")
|
||||
toml_path = Path.cwd() / "pyproject.toml"
|
||||
try:
|
||||
with open(toml_path, encoding="utf8") as pyproject_file:
|
||||
with toml_path.open(encoding="utf8") as pyproject_file:
|
||||
pyproject_data = tomlkit.parse(pyproject_file.read())
|
||||
except FileNotFoundError:
|
||||
click.echo(
|
||||
|
|
@ -470,7 +469,7 @@ def configure_pyproject_toml(setup_info: SetupInfo) -> None:
|
|||
pyproject_data["tool"] = tool_section
|
||||
|
||||
click.echo("Writing Codeflash configuration ...\r", nl=False)
|
||||
with open(toml_path, "w", encoding="utf8") as pyproject_file:
|
||||
with toml_path.open("w", encoding="utf8") as pyproject_file:
|
||||
pyproject_file.write(tomlkit.dumps(pyproject_data))
|
||||
click.echo(f"✅ Added Codeflash configuration to {toml_path}")
|
||||
click.echo()
|
||||
|
|
@ -488,8 +487,8 @@ def install_github_app() -> None:
|
|||
else:
|
||||
click.prompt(
|
||||
f"Finally, you'll need install the Codeflash GitHub app by choosing the repository you want to install Codeflash on.{LF}"
|
||||
+ f"I will attempt to open the github app page - https://github.com/apps/codeflash-ai/installations/select_target {LF}"
|
||||
+ f"Press Enter to open the page to let you install the app ...{LF}",
|
||||
f"I will attempt to open the github app page - https://github.com/apps/codeflash-ai/installations/select_target {LF}"
|
||||
f"Press Enter to open the page to let you install the app ...{LF}",
|
||||
default="",
|
||||
type=click.STRING,
|
||||
prompt_suffix="",
|
||||
|
|
@ -601,7 +600,7 @@ def enter_api_key_and_save_to_rc() -> None:
|
|||
os.environ["CODEFLASH_API_KEY"] = api_key
|
||||
|
||||
|
||||
def create_bubble_sort_file_and_test(args: Namespace) -> None:
|
||||
def create_bubble_sort_file_and_test(args: Namespace) -> tuple[str, str]:
|
||||
bubble_sort_content = """def sorter(arr):
|
||||
for i in range(len(arr)):
|
||||
for j in range(len(arr) - 1):
|
||||
|
|
@ -613,10 +612,10 @@ def create_bubble_sort_file_and_test(args: Namespace) -> None:
|
|||
"""
|
||||
if args.test_framework == "unittest":
|
||||
bubble_sort_test_content = f"""import unittest
|
||||
from {os.path.basename(args.module_root)}.bubble_sort import sorter
|
||||
from {Path(args.module_root).name}.bubble_sort import sorter
|
||||
|
||||
class TestBubbleSort(unittest.TestCase):
|
||||
def test_sort(self):
|
||||
class TestBubbleSort(unittest.TestCase):
|
||||
def test_sort(self):
|
||||
input = [5, 4, 3, 2, 1, 0]
|
||||
output = sorter(input)
|
||||
self.assertEqual(output, [0, 1, 2, 3, 4, 5])
|
||||
|
|
@ -628,9 +627,9 @@ class TestBubbleSort(unittest.TestCase):
|
|||
input = list(reversed(range(100)))
|
||||
output = sorter(input)
|
||||
self.assertEqual(output, list(range(100)))
|
||||
"""
|
||||
"""
|
||||
elif args.test_framework == "pytest":
|
||||
bubble_sort_test_content = f"""from {os.path.basename(args.module_root)}.bubble_sort import sorter
|
||||
bubble_sort_test_content = f"""from {Path(args.module_root).name}.bubble_sort import sorter
|
||||
|
||||
def test_sort():
|
||||
input = [5, 4, 3, 2, 1, 0]
|
||||
|
|
@ -648,15 +647,17 @@ def test_sort():
|
|||
else:
|
||||
click.echo(f"❌ Unsupported test framework: {args.test_framework}")
|
||||
apologize_and_exit()
|
||||
bubble_sort_path = os.path.join(args.module_root, "bubble_sort.py")
|
||||
with open(bubble_sort_path, "w", encoding="utf8") as bubble_sort_file:
|
||||
bubble_sort_file.write(bubble_sort_content)
|
||||
bubble_sort_test_path = os.path.join(args.tests_root, "test_bubble_sort.py")
|
||||
with open(bubble_sort_test_path, "w", encoding="utf8") as bubble_sort_test_file:
|
||||
bubble_sort_test_file.write(bubble_sort_test_content)
|
||||
|
||||
bubble_sort_path = Path(args.module_root) / "bubble_sort.py"
|
||||
bubble_sort_path.write_text(bubble_sort_content, encoding="utf8")
|
||||
|
||||
bubble_sort_test_path = Path(args.tests_root) / "test_bubble_sort.py"
|
||||
bubble_sort_test_path.write_text(bubble_sort_test_content, encoding="utf8")
|
||||
|
||||
click.echo(f"✅ Created {bubble_sort_path}")
|
||||
click.echo(f"✅ Created {bubble_sort_test_path}")
|
||||
return bubble_sort_path, bubble_sort_test_path
|
||||
|
||||
return str(bubble_sort_path), str(bubble_sort_test_path)
|
||||
|
||||
|
||||
def run_end_to_end_test(args: Namespace, bubble_sort_path: str, bubble_sort_test_path: str) -> None:
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import libcst as cst
|
||||
|
|
@ -9,8 +10,8 @@ from libcst.codemod import CodemodContext
|
|||
from libcst.codemod.visitors import AddImportsVisitor, GatherImportsVisitor, RemoveImportsVisitor
|
||||
from libcst.helpers import calculate_module_and_package
|
||||
|
||||
from codeflash.discovery.functions_to_optimize import FunctionParent
|
||||
from codeflash.cli_cmds.console import logger
|
||||
from codeflash.discovery.functions_to_optimize import FunctionParent
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from libcst.helpers import ModuleNameAndPackage
|
||||
|
|
@ -43,9 +44,9 @@ def delete___future___aliased_imports(module_code: str) -> str:
|
|||
def add_needed_imports_from_module(
|
||||
src_module_code: str,
|
||||
dst_module_code: str,
|
||||
src_path: str,
|
||||
dst_path: str,
|
||||
project_root: str,
|
||||
src_path: Path,
|
||||
dst_path: Path,
|
||||
project_root: Path,
|
||||
helper_functions: list[FunctionSource] | None = None,
|
||||
) -> str:
|
||||
"""Add all needed and used source module code imports to the destination module code, and return it."""
|
||||
|
|
@ -57,13 +58,13 @@ def add_needed_imports_from_module(
|
|||
dst_module_and_package: ModuleNameAndPackage = calculate_module_and_package(project_root, dst_path)
|
||||
|
||||
dst_context: CodemodContext = CodemodContext(
|
||||
filename=src_path,
|
||||
filename=src_path.name,
|
||||
full_module_name=dst_module_and_package.name,
|
||||
full_package_name=dst_module_and_package.package,
|
||||
)
|
||||
gatherer: GatherImportsVisitor = GatherImportsVisitor(
|
||||
CodemodContext(
|
||||
filename=src_path,
|
||||
filename=src_path.name,
|
||||
full_module_name=src_module_and_package.name,
|
||||
full_package_name=src_module_and_package.package,
|
||||
),
|
||||
|
|
@ -129,7 +130,7 @@ def get_code(
|
|||
):
|
||||
return None, set()
|
||||
|
||||
file_path: str = functions_to_optimize[0].file_path
|
||||
file_path: Path = functions_to_optimize[0].file_path
|
||||
class_skeleton: set[tuple[int, int | None]] = set()
|
||||
contextual_dunder_methods: set[tuple[str, str]] = set()
|
||||
target_code: str = ""
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
from typing import IO
|
||||
from typing import IO, TYPE_CHECKING
|
||||
|
||||
import libcst as cst
|
||||
from libcst import FunctionDef
|
||||
|
|
@ -9,6 +9,9 @@ from libcst import FunctionDef
|
|||
from codeflash.code_utils.code_extractor import add_needed_imports_from_module
|
||||
from codeflash.discovery.functions_to_optimize import FunctionParent
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class OptimFunctionCollector(cst.CSTVisitor):
|
||||
METADATA_DEPENDENCIES = (cst.metadata.ParentNodeProvider,)
|
||||
|
|
@ -200,11 +203,11 @@ def replace_functions_and_add_imports(
|
|||
source_code: str,
|
||||
function_names: list[str],
|
||||
optimized_code: str,
|
||||
file_path_of_module_with_function_to_optimize: str,
|
||||
module_abspath: str,
|
||||
file_path_of_module_with_function_to_optimize: Path,
|
||||
module_abspath: Path,
|
||||
preexisting_objects: list[tuple[str, list[FunctionParent]]],
|
||||
contextual_functions: set[tuple[str, str]],
|
||||
project_root_path: str,
|
||||
project_root_path: Path,
|
||||
) -> str:
|
||||
return add_needed_imports_from_module(
|
||||
optimized_code,
|
||||
|
|
@ -224,11 +227,11 @@ def replace_functions_and_add_imports(
|
|||
def replace_function_definitions_in_module(
|
||||
function_names: list[str],
|
||||
optimized_code: str,
|
||||
file_path_of_module_with_function_to_optimize: str,
|
||||
module_abspath: str,
|
||||
file_path_of_module_with_function_to_optimize: Path,
|
||||
module_abspath: Path,
|
||||
preexisting_objects: list[tuple[str, list[FunctionParent]]],
|
||||
contextual_functions: set[tuple[str, str]],
|
||||
project_root_path: str,
|
||||
project_root_path: Path,
|
||||
) -> bool:
|
||||
""":param function_names: List of qualified (not fully qualified) function names (function_name or
|
||||
class_name.method_name).
|
||||
|
|
@ -241,8 +244,7 @@ def replace_function_definitions_in_module(
|
|||
:return:
|
||||
"""
|
||||
file: IO[str]
|
||||
with open(module_abspath, encoding="utf8") as file:
|
||||
source_code: str = file.read()
|
||||
source_code: str = module_abspath.read_text(encoding="utf8")
|
||||
new_code: str = replace_functions_and_add_imports(
|
||||
source_code,
|
||||
function_names,
|
||||
|
|
@ -255,8 +257,7 @@ def replace_function_definitions_in_module(
|
|||
)
|
||||
if is_zero_diff(source_code, new_code):
|
||||
return False
|
||||
with open(module_abspath, "w", encoding="utf8") as file:
|
||||
file.write(new_code)
|
||||
module_abspath.write_text(new_code, encoding="utf8")
|
||||
return True
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,27 +1,26 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
from codeflash.cli_cmds.console import logger
|
||||
import os
|
||||
import site
|
||||
from tempfile import TemporaryDirectory
|
||||
|
||||
|
||||
def module_name_from_file_path(file_path: str, project_root_path: str) -> str:
|
||||
relative_path = os.path.relpath(file_path, project_root_path)
|
||||
module_path = relative_path.replace(os.sep, ".")
|
||||
if module_path.lower().endswith(".py"):
|
||||
module_path = module_path[:-3]
|
||||
return module_path
|
||||
from codeflash.cli_cmds.console import logger
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def file_path_from_module_name(module_name: str, project_root_path: str) -> str:
|
||||
"""Get file path from module path"""
|
||||
return os.path.join(project_root_path, module_name.replace(".", os.sep) + ".py")
|
||||
def module_name_from_file_path(file_path: Path, project_root_path: Path) -> str:
|
||||
relative_path = file_path.relative_to(project_root_path)
|
||||
return relative_path.with_suffix("").as_posix().replace("/", ".")
|
||||
|
||||
|
||||
def file_path_from_module_name(module_name: str, project_root_path: Path) -> Path:
|
||||
"""Get file path from module path."""
|
||||
return project_root_path / (module_name.replace(".", os.sep) + ".py")
|
||||
|
||||
|
||||
def get_imports_from_file(
|
||||
file_path: str | None = None,
|
||||
file_path: Path | None = None,
|
||||
file_string: str | None = None,
|
||||
file_ast: ast.AST | None = None,
|
||||
) -> list[ast.Import | ast.ImportFrom]:
|
||||
|
|
@ -29,19 +28,18 @@ def get_imports_from_file(
|
|||
sum([file_path is not None, file_string is not None, file_ast is not None]) == 1
|
||||
), "Must provide exactly one of file_path, file_string, or file_ast"
|
||||
if file_path:
|
||||
with open(file_path, encoding="utf8") as file:
|
||||
with file_path.open(encoding="utf8") as file:
|
||||
file_string = file.read()
|
||||
if file_ast is None:
|
||||
if file_string is None:
|
||||
logger.error("file_string cannot be None when file_ast is not provided")
|
||||
return []
|
||||
try:
|
||||
file_ast = ast.parse(file_string)
|
||||
except SyntaxError as e:
|
||||
logger.exception(f"Syntax error in code: {e}")
|
||||
return []
|
||||
imports = []
|
||||
for node in ast.walk(file_ast):
|
||||
if isinstance(node, (ast.Import, ast.ImportFrom)):
|
||||
imports.append(node)
|
||||
return imports
|
||||
return [node for node in ast.walk(file_ast) if isinstance(node, (ast.Import, ast.ImportFrom))]
|
||||
|
||||
|
||||
def get_all_function_names(code: str) -> tuple[bool, list[str]]:
|
||||
|
|
@ -51,34 +49,27 @@ def get_all_function_names(code: str) -> tuple[bool, list[str]]:
|
|||
logger.exception(f"Syntax error in code: {e}")
|
||||
return False, []
|
||||
|
||||
function_names = []
|
||||
for node in ast.walk(module):
|
||||
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
|
||||
function_names.append(node.name)
|
||||
function_names = [
|
||||
node.name for node in ast.walk(module) if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
|
||||
]
|
||||
return True, function_names
|
||||
|
||||
|
||||
def get_run_tmp_file(file_path: str) -> str:
|
||||
def get_run_tmp_file(file_path: Path) -> Path:
|
||||
if not hasattr(get_run_tmp_file, "tmpdir"):
|
||||
get_run_tmp_file.tmpdir = TemporaryDirectory(prefix="codeflash_")
|
||||
return os.path.join(get_run_tmp_file.tmpdir.name, file_path)
|
||||
return Path(get_run_tmp_file.tmpdir.name) / file_path
|
||||
|
||||
|
||||
def path_belongs_to_site_packages(file_path: str) -> bool:
|
||||
site_packages = site.getsitepackages()
|
||||
for site_package_path in site_packages:
|
||||
if file_path.startswith(site_package_path + os.sep):
|
||||
return True
|
||||
return False
|
||||
def path_belongs_to_site_packages(file_path: Path) -> bool:
|
||||
site_packages = [Path(p) for p in site.getsitepackages()]
|
||||
return any(file_path.resolve().is_relative_to(site_package_path) for site_package_path in site_packages)
|
||||
|
||||
|
||||
def is_class_defined_in_file(class_name: str, file_path: str) -> bool:
|
||||
if not os.path.exists(file_path):
|
||||
def is_class_defined_in_file(class_name: str, file_path: Path) -> bool:
|
||||
if not file_path.exists():
|
||||
return False
|
||||
with open(file_path) as file:
|
||||
with file_path.open(encoding="utf8") as file:
|
||||
source = file.read()
|
||||
tree = ast.parse(source)
|
||||
for node in ast.walk(tree):
|
||||
if isinstance(node, ast.ClassDef) and node.name == class_name:
|
||||
return True
|
||||
return False
|
||||
return any(isinstance(node, ast.ClassDef) and node.name == class_name for node in ast.walk(tree))
|
||||
|
|
|
|||
|
|
@ -1,56 +1,53 @@
|
|||
import os
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import tomlkit
|
||||
|
||||
|
||||
def find_pyproject_toml(config_file=None):
|
||||
def find_pyproject_toml(config_file: Path | None = None) -> Path:
|
||||
# Find the pyproject.toml file on the root of the project
|
||||
|
||||
if config_file is not None:
|
||||
if not config_file.lower().endswith(".toml"):
|
||||
raise ValueError(
|
||||
f"Config file {config_file} is not a valid toml file. Please recheck the path to pyproject.toml",
|
||||
)
|
||||
if not os.path.exists(config_file):
|
||||
raise ValueError(
|
||||
f"Config file {config_file} does not exist. Please recheck the path to pyproject.toml",
|
||||
)
|
||||
config_file = Path(config_file)
|
||||
if config_file.suffix.lower() != ".toml":
|
||||
msg = f"Config file {config_file} is not a valid toml file. Please recheck the path to pyproject.toml"
|
||||
raise ValueError(msg)
|
||||
if not config_file.exists():
|
||||
msg = f"Config file {config_file} does not exist. Please recheck the path to pyproject.toml"
|
||||
raise ValueError(msg)
|
||||
return config_file
|
||||
|
||||
else:
|
||||
dir_path = os.getcwd()
|
||||
dir_path = Path.cwd()
|
||||
|
||||
while os.path.dirname(dir_path) != dir_path:
|
||||
config_file = os.path.join(dir_path, "pyproject.toml")
|
||||
if os.path.exists(config_file):
|
||||
while dir_path != dir_path.parent:
|
||||
config_file = dir_path / "pyproject.toml"
|
||||
if config_file.exists():
|
||||
return config_file
|
||||
# Search for pyproject.toml in the parent directories
|
||||
dir_path = os.path.dirname(dir_path)
|
||||
raise ValueError(
|
||||
f"Could not find pyproject.toml in the current directory {os.getcwd()} or any of the parent directories. Please create it by running `poetry init`, or pass the path to pyproject.toml with the --config-file argument.",
|
||||
)
|
||||
dir_path = dir_path.parent
|
||||
msg = f"Could not find pyproject.toml in the current directory {Path.cwd()} or any of the parent directories. Please create it by running `poetry init`, or pass the path to pyproject.toml with the --config-file argument."
|
||||
|
||||
raise ValueError(msg)
|
||||
|
||||
|
||||
def parse_config_file(config_file_path: str = None) -> tuple[dict[str, Any], str]:
|
||||
def parse_config_file(config_file_path: Path | None = None) -> tuple[dict[str, Any], Path]:
|
||||
config_file_path = find_pyproject_toml(config_file_path)
|
||||
try:
|
||||
with open(config_file_path, "rb") as f:
|
||||
with config_file_path.open("rb") as f:
|
||||
data = tomlkit.parse(f.read())
|
||||
except tomlkit.exceptions.ParseError as e:
|
||||
raise ValueError(
|
||||
f"Error while parsing the config file {config_file_path}. Please recheck the file for syntax errors. Error: {e}",
|
||||
)
|
||||
msg = f"Error while parsing the config file {config_file_path}. Please recheck the file for syntax errors. Error: {e}"
|
||||
raise ValueError(msg) from e
|
||||
|
||||
try:
|
||||
tool = data["tool"]
|
||||
assert isinstance(tool, dict)
|
||||
config = tool["codeflash"]
|
||||
except tomlkit.exceptions.NonExistentKey:
|
||||
raise ValueError(
|
||||
f"Could not find the 'codeflash' block in the config file {config_file_path}. "
|
||||
f"Please run 'codeflash init' to create the config file.",
|
||||
)
|
||||
except tomlkit.exceptions.NonExistentKey as e:
|
||||
msg = f"Could not find the 'codeflash' block in the config file {config_file_path}. Please run 'codeflash init' to create the config file."
|
||||
raise ValueError(msg) from e
|
||||
assert isinstance(config, dict)
|
||||
|
||||
# default values:
|
||||
|
|
@ -79,9 +76,7 @@ def parse_config_file(config_file_path: str = None) -> tuple[dict[str, Any], str
|
|||
config[key] = bool_keys[key]
|
||||
for key in path_keys:
|
||||
if key in config:
|
||||
config[key] = os.path.realpath(
|
||||
os.path.join(os.path.dirname(config_file_path), config[key]),
|
||||
)
|
||||
config[key] = str((Path(config_file_path).parent / Path(config[key])).resolve())
|
||||
for key in list_str_keys:
|
||||
if key in config:
|
||||
config[key] = [str(cmd) for cmd in config[key]]
|
||||
|
|
@ -90,10 +85,7 @@ def parse_config_file(config_file_path: str = None) -> tuple[dict[str, Any], str
|
|||
|
||||
for key in path_list_keys:
|
||||
if key in config:
|
||||
config[key] = [
|
||||
os.path.realpath(os.path.join(os.path.dirname(config_file_path), path))
|
||||
for path in config[key]
|
||||
]
|
||||
config[key] = [(Path(config_file_path).parent / path).resolve() for path in config[key]]
|
||||
else: # Default to empty list
|
||||
config[key] = []
|
||||
|
||||
|
|
|
|||
|
|
@ -1,28 +1,31 @@
|
|||
from codeflash.cli_cmds.console import logger
|
||||
import os.path
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import shlex
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
|
||||
import isort
|
||||
|
||||
from codeflash.cli_cmds.console import logger
|
||||
|
||||
|
||||
def format_code(
|
||||
formatter_cmds: list[str],
|
||||
path: str,
|
||||
) -> str:
|
||||
path: Path,
|
||||
) -> str | None:
|
||||
# TODO: Only allow a particular whitelist of formatters here to prevent arbitrary code execution
|
||||
if not os.path.exists(path):
|
||||
if not path.exists():
|
||||
logger.error(f"File {path} does not exist. Cannot format the file.")
|
||||
return None
|
||||
if formatter_cmds[0].lower() == "disabled":
|
||||
with open(path, encoding="utf8") as f:
|
||||
new_code = f.read()
|
||||
new_code = path.read_text(encoding="utf8")
|
||||
return new_code
|
||||
file_token = "$file"
|
||||
|
||||
for command in formatter_cmds:
|
||||
formatter_cmd_list = shlex.split(command, posix=os.name != "nt")
|
||||
formatter_cmd_list = [path if chunk == file_token else chunk for chunk in formatter_cmd_list]
|
||||
formatter_cmd_list = [str(path) if chunk == file_token else chunk for chunk in formatter_cmd_list]
|
||||
logger.info(f"Formatting code with {' '.join(formatter_cmd_list)} ...")
|
||||
|
||||
try:
|
||||
|
|
@ -40,8 +43,7 @@ def format_code(
|
|||
else:
|
||||
logger.error(f"Failed to format code with {' '.join(formatter_cmd_list)}")
|
||||
|
||||
with open(path, encoding="utf8") as f:
|
||||
new_code = f.read()
|
||||
new_code = path.read_text(encoding="utf8")
|
||||
|
||||
return new_code
|
||||
|
||||
|
|
|
|||
|
|
@ -1,9 +1,9 @@
|
|||
from __future__ import annotations
|
||||
from codeflash.cli_cmds.console import logger
|
||||
import os
|
||||
|
||||
import sys
|
||||
import time
|
||||
from io import StringIO
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import git
|
||||
|
|
@ -12,10 +12,11 @@ from git import Repo
|
|||
from unidiff import PatchSet
|
||||
|
||||
from codeflash.cli_cmds.cli_common import inquirer_wrapper
|
||||
from codeflash.cli_cmds.console import logger
|
||||
|
||||
|
||||
def get_git_diff(
|
||||
repo_directory: str = os.getcwd(),
|
||||
repo_directory: Path = Path.cwd(),
|
||||
uncommitted_changes: bool = False,
|
||||
) -> dict[str, list[int]]:
|
||||
repository = git.Repo(repo_directory, search_parent_directories=True)
|
||||
|
|
@ -37,11 +38,12 @@ def get_git_diff(
|
|||
patch_set = PatchSet(StringIO(uni_diff_text))
|
||||
change_list: dict[str, list[int]] = {} # list of changes
|
||||
for patched_file in patch_set:
|
||||
file_path: str = patched_file.path # file name
|
||||
if not file_path.endswith(".py"):
|
||||
file_path: Path = Path(patched_file.path)
|
||||
if file_path.suffix != ".py":
|
||||
continue
|
||||
file_path = os.path.join(repository.working_dir, file_path)
|
||||
logger.debug("file name :" + file_path)
|
||||
file_path = Path(repository.working_dir) / file_path
|
||||
logger.debug(f"file name: {file_path}")
|
||||
|
||||
add_line_no: list[int] = [
|
||||
line.target_line_no
|
||||
for hunk in patched_file
|
||||
|
|
@ -49,15 +51,16 @@ def get_git_diff(
|
|||
if line.is_added and line.value.strip() != ""
|
||||
] # the row number of deleted lines
|
||||
|
||||
logger.debug("added lines : " + str(add_line_no))
|
||||
logger.debug(f"added lines: {add_line_no}")
|
||||
|
||||
del_line_no: list[int] = [
|
||||
line.source_line_no
|
||||
for hunk in patched_file
|
||||
for line in hunk
|
||||
if line.is_removed and line.value.strip() != ""
|
||||
] # the row number of added liens
|
||||
] # the row number of added lines
|
||||
|
||||
logger.debug("deleted lines : " + str(del_line_no))
|
||||
logger.debug(f"deleted lines: {del_line_no}")
|
||||
|
||||
change_list[file_path] = add_line_no
|
||||
return change_list
|
||||
|
|
|
|||
|
|
@ -1,15 +1,16 @@
|
|||
from codeflash.cli_cmds.console import logger
|
||||
import os
|
||||
import re
|
||||
import shlex
|
||||
import unittest
|
||||
from collections import defaultdict
|
||||
from multiprocessing import Process, Queue
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import jedi
|
||||
from pydantic.dataclasses import dataclass
|
||||
|
||||
from codeflash.cli_cmds.console import logger
|
||||
from codeflash.code_utils.code_utils import module_name_from_file_path
|
||||
from codeflash.verification.test_results import TestType
|
||||
from codeflash.verification.verification_utils import TestConfig
|
||||
|
|
@ -139,7 +140,7 @@ def discover_tests_unittest(
|
|||
cfg: TestConfig,
|
||||
discover_only_these_tests: Optional[List[str]] = None,
|
||||
) -> Dict[str, List[FunctionCalledInTest]]:
|
||||
tests_root = cfg.tests_root
|
||||
tests_root = Path(cfg.tests_root)
|
||||
loader = unittest.TestLoader()
|
||||
tests = loader.discover(str(tests_root))
|
||||
file_to_test_map = defaultdict(list)
|
||||
|
|
@ -151,18 +152,18 @@ def discover_tests_unittest(
|
|||
_test.__class__.__qualname__,
|
||||
)
|
||||
|
||||
_test_module_path = _test_module.replace(".", os.sep)
|
||||
_test_module_path = os.path.normpath(os.path.join(str(tests_root), _test_module_path) + ".py")
|
||||
if not os.path.exists(_test_module_path) or (
|
||||
discover_only_these_tests and _test_module_path not in discover_only_these_tests
|
||||
_test_module_path = Path(_test_module.replace(".", os.sep)).with_suffix(".py")
|
||||
_test_module_path = tests_root / _test_module_path
|
||||
if not _test_module_path.exists() or (
|
||||
discover_only_these_tests and str(_test_module_path) not in discover_only_these_tests
|
||||
):
|
||||
return None
|
||||
if "__replay_test" in _test_module_path:
|
||||
if "__replay_test" in str(_test_module_path):
|
||||
test_type = TestType.REPLAY_TEST
|
||||
else:
|
||||
test_type = TestType.EXISTING_UNIT_TEST
|
||||
return TestsInFile(
|
||||
test_file=_test_module_path,
|
||||
test_file=str(_test_module_path),
|
||||
test_suite=_test_suite_name,
|
||||
test_function=_test_function,
|
||||
test_type=test_type,
|
||||
|
|
@ -186,11 +187,11 @@ def discover_tests_unittest(
|
|||
continue
|
||||
details = get_test_details(test_2)
|
||||
if details is not None:
|
||||
file_to_test_map[details.test_file].append(details)
|
||||
file_to_test_map[str(details.test_file)].append(details)
|
||||
else:
|
||||
details = get_test_details(test)
|
||||
if details is not None:
|
||||
file_to_test_map[details.test_file].append(details)
|
||||
file_to_test_map[str(details.test_file)].append(details)
|
||||
return process_test_files(file_to_test_map, cfg)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -128,7 +128,7 @@ class FunctionToOptimize:
|
|||
"""
|
||||
|
||||
function_name: str
|
||||
file_path: str
|
||||
file_path: Path
|
||||
parents: list[FunctionParent] # list[ClassDef | FunctionDef | AsyncFunctionDef]
|
||||
starting_line: Optional[int] = None
|
||||
ending_line: Optional[int] = None
|
||||
|
|
@ -150,20 +150,20 @@ class FunctionToOptimize:
|
|||
def qualified_name(self) -> str:
|
||||
return self.function_name if self.parents == [] else f"{self.parents[0].name}.{self.function_name}"
|
||||
|
||||
def qualified_name_with_modules_from_root(self, project_root_path: str) -> str:
|
||||
def qualified_name_with_modules_from_root(self, project_root_path: Path) -> str:
|
||||
return f"{module_name_from_file_path(self.file_path, project_root_path)}.{self.qualified_name}"
|
||||
|
||||
|
||||
def get_functions_to_optimize(
|
||||
optimize_all: str | None,
|
||||
replay_test: str | None,
|
||||
file: str | None,
|
||||
file: Path | None,
|
||||
only_get_this_function: str | None,
|
||||
test_cfg: TestConfig,
|
||||
ignore_paths: list[str],
|
||||
project_root: str,
|
||||
module_root: str,
|
||||
) -> tuple[dict[str, list[FunctionToOptimize]], int]:
|
||||
ignore_paths: list[Path],
|
||||
project_root: Path,
|
||||
module_root: Path,
|
||||
) -> tuple[dict[Path, list[FunctionToOptimize]], int]:
|
||||
assert (
|
||||
sum(
|
||||
[ # Ensure only one of the options is provided
|
||||
|
|
@ -191,9 +191,8 @@ def get_functions_to_optimize(
|
|||
if only_get_this_function is not None:
|
||||
split_function = only_get_this_function.split(".")
|
||||
if len(split_function) > 2:
|
||||
raise ValueError(
|
||||
"Function name should be in the format 'function_name' or 'class_name.function_name'",
|
||||
)
|
||||
msg = "Function name should be in the format 'function_name' or 'class_name.function_name'"
|
||||
raise ValueError(msg)
|
||||
if len(split_function) == 2:
|
||||
class_name, only_function_name = split_function
|
||||
else:
|
||||
|
|
@ -206,10 +205,8 @@ def get_functions_to_optimize(
|
|||
):
|
||||
found_function = fn
|
||||
if found_function is None:
|
||||
raise ValueError(
|
||||
f"Function {only_function_name} not found in file {file} or"
|
||||
f" the function does not have a 'return' statement.",
|
||||
)
|
||||
msg = f"Function {only_function_name} not found in file {file} or the function does not have a 'return' statement."
|
||||
raise ValueError(msg)
|
||||
functions[file] = [found_function]
|
||||
else:
|
||||
logger.info("Finding all functions modified in the current git diff ...")
|
||||
|
|
@ -229,64 +226,60 @@ def get_functions_to_optimize(
|
|||
def get_functions_within_git_diff() -> dict[str, list[FunctionToOptimize]]:
|
||||
modified_lines: dict[str, list[int]] = get_git_diff(uncommitted_changes=False)
|
||||
modified_functions: dict[str, list[FunctionToOptimize]] = {}
|
||||
for path in modified_lines:
|
||||
if not os.path.exists(path):
|
||||
for path_str in modified_lines:
|
||||
path = Path(path_str)
|
||||
if not path.exists():
|
||||
continue
|
||||
with open(path, encoding="utf8") as f:
|
||||
with path.open(encoding="utf8") as f:
|
||||
file_content = f.read()
|
||||
try:
|
||||
wrapper = cst.metadata.MetadataWrapper(cst.parse_module(file_content))
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
continue
|
||||
function_lines = FunctionVisitor(file_path=path)
|
||||
function_lines = FunctionVisitor(file_path=str(path))
|
||||
wrapper.visit(function_lines)
|
||||
modified_functions[path] = [
|
||||
modified_functions[str(path)] = [
|
||||
function_to_optimize
|
||||
for function_to_optimize in function_lines.functions
|
||||
if (start_line := function_to_optimize.starting_line) is not None
|
||||
and (end_line := function_to_optimize.ending_line) is not None
|
||||
and any(start_line <= line <= end_line for line in modified_lines[path])
|
||||
and any(start_line <= line <= end_line for line in modified_lines[path_str])
|
||||
]
|
||||
return modified_functions
|
||||
|
||||
|
||||
def get_all_files_and_functions(module_root_path: str) -> dict[str, list[FunctionToOptimize]]:
|
||||
def get_all_files_and_functions(module_root_path: Path) -> dict[str, list[FunctionToOptimize]]:
|
||||
functions: dict[str, list[FunctionToOptimize]] = {}
|
||||
for root, dirs, files in os.walk(module_root_path):
|
||||
for file in files:
|
||||
if not file.endswith(".py"):
|
||||
continue
|
||||
file_path = os.path.join(root, file)
|
||||
|
||||
# Find all the functions in the file
|
||||
functions.update(find_all_functions_in_file(file_path))
|
||||
module_root_path = Path(module_root_path)
|
||||
for file_path in module_root_path.rglob("*.py"):
|
||||
# Find all the functions in the file
|
||||
functions.update(find_all_functions_in_file(str(file_path)))
|
||||
# Randomize the order of the files to optimize to avoid optimizing the same file in the same order every time.
|
||||
# Helpful if an optimize-all run is stuck and we restart it.
|
||||
files_list = list(functions.items())
|
||||
random.shuffle(files_list)
|
||||
functions_shuffled = dict(files_list)
|
||||
return functions_shuffled
|
||||
return dict(files_list)
|
||||
|
||||
|
||||
def find_all_functions_in_file(file_path: str) -> dict[str, list[FunctionToOptimize]]:
|
||||
def find_all_functions_in_file(file_path: Path) -> dict[str, list[FunctionToOptimize]]:
|
||||
functions: dict[str, list[FunctionToOptimize]] = {}
|
||||
with open(file_path, encoding="utf8") as f:
|
||||
with file_path.open(encoding="utf8") as f:
|
||||
try:
|
||||
ast_module = ast.parse(f.read())
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
return functions
|
||||
function_name_visitor = FunctionWithReturnStatement(file_path)
|
||||
function_name_visitor = FunctionWithReturnStatement(str(file_path))
|
||||
function_name_visitor.visit(ast_module)
|
||||
functions[file_path] = function_name_visitor.functions
|
||||
functions[str(file_path)] = function_name_visitor.functions
|
||||
return functions
|
||||
|
||||
|
||||
def get_all_replay_test_functions(
|
||||
replay_test: str,
|
||||
test_cfg: TestConfig,
|
||||
project_root_path: str,
|
||||
project_root_path: Path,
|
||||
) -> dict[str, list[FunctionToOptimize]]:
|
||||
function_tests = discover_unit_tests(test_cfg, discover_only_these_tests=[replay_test])
|
||||
# Get the absolute file paths for each function, excluding class name if present
|
||||
|
|
@ -303,7 +296,7 @@ def get_all_replay_test_functions(
|
|||
if module_path_parts
|
||||
and is_class_defined_in_file(
|
||||
module_path_parts[-1],
|
||||
os.path.join(project_root_path, *module_path_parts[:-1]) + ".py",
|
||||
str(Path(project_root_path, *module_path_parts[:-1])) + ".py",
|
||||
)
|
||||
else None
|
||||
)
|
||||
|
|
@ -314,7 +307,7 @@ def get_all_replay_test_functions(
|
|||
else:
|
||||
function = function_name
|
||||
file_path_parts = module_path_parts
|
||||
file_path = os.path.join(project_root_path, *file_path_parts) + ".py"
|
||||
file_path = Path(project_root_path, *file_path_parts).with_suffix(".py")
|
||||
file_to_functions_map[file_path].append((function, function_name, class_name))
|
||||
for file_path, functions in file_to_functions_map.items():
|
||||
all_valid_functions: dict[str, list[FunctionToOptimize]] = find_all_functions_in_file(
|
||||
|
|
@ -349,8 +342,7 @@ def ignored_submodule_paths(module_root: str) -> list[str]:
|
|||
if is_git_repo(module_root):
|
||||
git_repo = git.Repo(module_root, search_parent_directories=True)
|
||||
return [
|
||||
os.path.realpath(os.path.join(git_repo.working_tree_dir, submodule.path))
|
||||
for submodule in git_repo.submodules
|
||||
Path(git_repo.working_tree_dir, submodule.path).resolve() for submodule in git_repo.submodules
|
||||
]
|
||||
return []
|
||||
|
||||
|
|
@ -358,7 +350,7 @@ def ignored_submodule_paths(module_root: str) -> list[str]:
|
|||
class TopLevelFunctionOrMethodVisitor(ast.NodeVisitor):
|
||||
def __init__(
|
||||
self,
|
||||
file_name: str,
|
||||
file_name: Path,
|
||||
function_or_method_name: str,
|
||||
class_name: str | None = None,
|
||||
line_no: int | None = None,
|
||||
|
|
@ -418,7 +410,7 @@ class TopLevelFunctionOrMethodVisitor(ast.NodeVisitor):
|
|||
|
||||
|
||||
def inspect_top_level_functions_or_methods(
|
||||
file_name: str,
|
||||
file_name: Path,
|
||||
function_or_method_name: str,
|
||||
class_name: str | None = None,
|
||||
line_no: int | None = None,
|
||||
|
|
@ -449,11 +441,11 @@ def inspect_top_level_functions_or_methods(
|
|||
def filter_functions(
|
||||
modified_functions: dict[str, list[FunctionToOptimize]],
|
||||
tests_root: str,
|
||||
ignore_paths: list[str],
|
||||
ignore_paths: list[Path],
|
||||
project_root: str,
|
||||
module_root: str,
|
||||
disable_logs: bool = False,
|
||||
) -> tuple[dict[str, list[FunctionToOptimize]], int]:
|
||||
) -> tuple[dict[Path, list[FunctionToOptimize]], int]:
|
||||
blocklist_funcs = get_blacklisted_functions()
|
||||
# Remove any function that we don't want to optimize
|
||||
|
||||
|
|
@ -522,37 +514,33 @@ def filter_functions(
|
|||
|
||||
|
||||
def filter_files_optimized(
|
||||
file_path: str,
|
||||
tests_root: str,
|
||||
ignore_paths: list[str],
|
||||
module_root: str,
|
||||
file_path: Path,
|
||||
tests_root: Path,
|
||||
ignore_paths: list[Path],
|
||||
module_root: Path,
|
||||
) -> bool:
|
||||
"""Optimized version of the filter_functions function above.
|
||||
|
||||
Takes in file paths and returns the count of files that are to be optimized.
|
||||
"""
|
||||
submodule_paths = None
|
||||
if file_path.startswith(tests_root + os.sep):
|
||||
if file_path.is_relative_to(tests_root):
|
||||
return False
|
||||
if file_path in ignore_paths or any(
|
||||
file_path.startswith(ignore_path + os.sep) for ignore_path in ignore_paths
|
||||
file_path.is_relative_to(ignore_path) for ignore_path in ignore_paths
|
||||
):
|
||||
return False
|
||||
if path_belongs_to_site_packages(file_path):
|
||||
return False
|
||||
if not file_path.startswith(module_root + os.sep):
|
||||
if not file_path.is_relative_to(module_root):
|
||||
return False
|
||||
if submodule_paths is None:
|
||||
submodule_paths = ignored_submodule_paths(module_root)
|
||||
if file_path in submodule_paths or any(
|
||||
file_path.startswith(submodule_path + os.sep) for submodule_path in submodule_paths
|
||||
):
|
||||
return False
|
||||
|
||||
return True
|
||||
return not (
|
||||
file_path in submodule_paths
|
||||
or any(file_path.is_relative_to(submodule_path) for submodule_path in submodule_paths)
|
||||
)
|
||||
|
||||
|
||||
def function_has_return_statement(function_node: Union[FunctionDef, AsyncFunctionDef]) -> bool:
|
||||
for node in ast.walk(function_node):
|
||||
if isinstance(node, ast.Return):
|
||||
return True
|
||||
return False
|
||||
def function_has_return_statement(function_node: FunctionDef | AsyncFunctionDef) -> bool:
|
||||
return any(isinstance(node, ast.Return) for node in ast.walk(function_node))
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any, Generator, Iterator, Optional
|
||||
|
||||
from jedi.api.classes import Name
|
||||
|
|
@ -17,7 +18,7 @@ from codeflash.verification.test_results import TestResults, TestType
|
|||
|
||||
@dataclass(frozen=True, config={"arbitrary_types_allowed": True})
|
||||
class FunctionSource:
|
||||
file_path: str
|
||||
file_path: Path
|
||||
qualified_name: str
|
||||
fully_qualified_name: str
|
||||
only_function_name: str
|
||||
|
|
@ -55,7 +56,7 @@ class GeneratedTestsList(BaseModel):
|
|||
|
||||
|
||||
class TestFile(BaseModel):
|
||||
instrumented_file_path: str
|
||||
instrumented_file_path: Path
|
||||
original_file_path: Optional[str] = None
|
||||
original_source: Optional[str] = None
|
||||
test_type: TestType
|
||||
|
|
|
|||
|
|
@ -4,16 +4,20 @@ import ast
|
|||
import os
|
||||
import re
|
||||
from collections import defaultdict
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import jedi
|
||||
import tiktoken
|
||||
from jedi.api.classes import Name
|
||||
|
||||
from codeflash.cli_cmds.console import logger
|
||||
from codeflash.code_utils.code_extractor import get_code
|
||||
from codeflash.code_utils.code_utils import module_name_from_file_path, path_belongs_to_site_packages
|
||||
from codeflash.discovery.functions_to_optimize import FunctionParent, FunctionToOptimize
|
||||
from codeflash.models.models import FunctionSource
|
||||
from codeflash.cli_cmds.console import logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def belongs_to_class(name: Name, class_name: str) -> bool:
|
||||
|
|
@ -36,12 +40,11 @@ def belongs_to_function(name: Name, function_name: str) -> bool:
|
|||
def get_type_annotation_context(
|
||||
function: FunctionToOptimize,
|
||||
jedi_script: jedi.Script,
|
||||
project_root_path: str,
|
||||
project_root_path: Path,
|
||||
) -> list[FunctionSource]:
|
||||
function_name: str = function.function_name
|
||||
file_path: str = function.file_path
|
||||
with open(file_path, encoding="utf8") as file:
|
||||
file_contents: str = file.read()
|
||||
file_path: Path = function.file_path
|
||||
file_contents: str = file_path.read_text(encoding="utf8")
|
||||
try:
|
||||
module: ast.Module = ast.parse(file_contents)
|
||||
except SyntaxError as e:
|
||||
|
|
@ -72,14 +75,15 @@ def get_type_annotation_context(
|
|||
logger.exception(f"Error while getting definition: {ex}")
|
||||
definition = []
|
||||
if definition: # TODO can be multiple definitions
|
||||
definition_path = str(definition[0].module_path)
|
||||
definition_path = definition[0].module_path
|
||||
# The definition is part of this project and not defined within the original function
|
||||
if (
|
||||
definition_path.startswith(project_root_path + os.sep)
|
||||
str(definition_path).startswith(str(project_root_path) + os.sep)
|
||||
and definition[0].full_name
|
||||
and not path_belongs_to_site_packages(definition_path)
|
||||
and not belongs_to_function(definition[0], function_name)
|
||||
):
|
||||
assert definition_path is not None
|
||||
source_code = get_code(
|
||||
[
|
||||
FunctionToOptimize(
|
||||
|
|
@ -164,7 +168,7 @@ def get_type_annotation_context(
|
|||
|
||||
def get_function_variables_definitions(
|
||||
function_to_optimize: FunctionToOptimize,
|
||||
project_root_path: str,
|
||||
project_root_path: Path,
|
||||
) -> tuple[list[FunctionSource], set[tuple[str, str]]]:
|
||||
function_name = function_to_optimize.function_name
|
||||
file_path = function_to_optimize.file_path
|
||||
|
|
@ -203,10 +207,11 @@ def get_function_variables_definitions(
|
|||
if definitions:
|
||||
# TODO: there can be multiple definitions, see how to handle such cases
|
||||
definition = definitions[0]
|
||||
definition_path = str(definition.module_path)
|
||||
definition_path = definition.module_path
|
||||
assert definition_path is not None
|
||||
# The definition is part of this project and not defined within the original function
|
||||
if (
|
||||
definition_path.startswith(project_root_path + os.sep)
|
||||
str(definition_path).startswith(str(project_root_path) + os.sep)
|
||||
and not path_belongs_to_site_packages(definition_path)
|
||||
and definition.full_name
|
||||
and not belongs_to_function(definition, function_name)
|
||||
|
|
@ -255,7 +260,7 @@ def get_function_variables_definitions(
|
|||
for source in sources:
|
||||
if (fully_qualified_name := source.fully_qualified_name) not in existing_fully_qualified_names:
|
||||
if not source.qualified_name.count("."):
|
||||
no_parent_sources[source.file_path][source.qualified_name].add(source)
|
||||
no_parent_sources[str(source.file_path)][source.qualified_name].add(source)
|
||||
else:
|
||||
parent_sources.add(source)
|
||||
existing_fully_qualified_names.add(fully_qualified_name)
|
||||
|
|
@ -263,7 +268,7 @@ def get_function_variables_definitions(
|
|||
source
|
||||
for source in parent_sources
|
||||
if source.file_path not in no_parent_sources
|
||||
or source.qualified_name.rpartition(".")[0] not in no_parent_sources[source.file_path]
|
||||
or source.qualified_name.rpartition(".")[0] not in no_parent_sources[str(source.file_path)]
|
||||
]
|
||||
deduped_no_parent_sources = [
|
||||
source
|
||||
|
|
@ -279,7 +284,7 @@ MAX_PROMPT_TOKENS = 4096 # 128000 # gpt-4-128k
|
|||
|
||||
def get_constrained_function_context_and_helper_functions(
|
||||
function_to_optimize: FunctionToOptimize,
|
||||
project_root_path: str,
|
||||
project_root_path: Path,
|
||||
code_to_optimize: str,
|
||||
max_tokens: int = MAX_PROMPT_TOKENS,
|
||||
) -> tuple[str, list[FunctionSource], set[tuple[str, str]]]:
|
||||
|
|
|
|||
|
|
@ -2,11 +2,11 @@ from __future__ import annotations
|
|||
|
||||
import concurrent.futures
|
||||
import os
|
||||
import pathlib
|
||||
import subprocess
|
||||
import time
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import isort
|
||||
|
|
@ -623,7 +623,7 @@ class Optimizer:
|
|||
def replace_function_and_helpers_with_optimized_code(
|
||||
self,
|
||||
code_context: CodeOptimizationContext,
|
||||
function_to_optimize_file_path: str,
|
||||
function_to_optimize_file_path: Path,
|
||||
optimized_code: str,
|
||||
qualified_function_name: str,
|
||||
) -> bool:
|
||||
|
|
@ -660,7 +660,7 @@ class Optimizer:
|
|||
def get_code_optimization_context(
|
||||
self,
|
||||
function_to_optimize: FunctionToOptimize,
|
||||
project_root: str,
|
||||
project_root: Path,
|
||||
original_source_code: str,
|
||||
) -> Result[CodeOptimizationContext, str]:
|
||||
code_to_optimize, contextual_dunder_methods = extract_code(
|
||||
|
|
@ -687,7 +687,7 @@ class Optimizer:
|
|||
optimizable_methods = [
|
||||
FunctionToOptimize(
|
||||
df.qualified_name.split(".")[-1],
|
||||
df.file_path,
|
||||
Path(df.file_path),
|
||||
[FunctionParent(df.qualified_name.split(".")[0], "ClassDef")],
|
||||
None,
|
||||
None,
|
||||
|
|
@ -730,18 +730,14 @@ class Optimizer:
|
|||
@staticmethod
|
||||
def cleanup_leftover_test_return_values() -> None:
|
||||
# remove leftovers from previous run
|
||||
pathlib.Path(get_run_tmp_file("test_return_values_0.bin")).unlink(
|
||||
missing_ok=True,
|
||||
)
|
||||
pathlib.Path(get_run_tmp_file("test_return_values_0.sqlite")).unlink(
|
||||
missing_ok=True,
|
||||
)
|
||||
get_run_tmp_file(Path("test_return_values_0.bin")).unlink(missing_ok=True)
|
||||
get_run_tmp_file(Path("test_return_values_0.sqlite")).unlink(missing_ok=True)
|
||||
|
||||
def instrument_existing_tests(
|
||||
self,
|
||||
function_to_optimize: FunctionToOptimize,
|
||||
function_to_tests: dict[str, list[TestsInFile]],
|
||||
) -> set[str]:
|
||||
) -> set[Path]:
|
||||
relevant_test_files_count = 0
|
||||
unique_instrumented_test_files = set()
|
||||
|
||||
|
|
@ -767,10 +763,8 @@ class Optimizer:
|
|||
)
|
||||
if not success:
|
||||
continue
|
||||
new_test_path = (
|
||||
f"{os.path.splitext(test_file)[0]}__perfinstrumented{os.path.splitext(test_file)[1]}"
|
||||
)
|
||||
with pathlib.Path(new_test_path).open("w", encoding="utf8") as f:
|
||||
new_test_path = Path(test_file).with_suffix(f"__perfinstrumented{Path(test_file).suffix}")
|
||||
with new_test_path.open("w", encoding="utf8") as f:
|
||||
f.write(injected_test)
|
||||
unique_instrumented_test_files.add(new_test_path)
|
||||
if not self.test_files.get_by_original_file_path(test_file):
|
||||
|
|
@ -793,7 +787,7 @@ class Optimizer:
|
|||
code_to_optimize_with_helpers: str,
|
||||
function_to_optimize: FunctionToOptimize,
|
||||
helper_functions: list[FunctionSource],
|
||||
module_path: str,
|
||||
module_path: Path,
|
||||
function_trace_id: str,
|
||||
run_experiment: bool = False,
|
||||
) -> Result[tuple[GeneratedTestsList, OptimizationSet], str]:
|
||||
|
|
@ -995,12 +989,12 @@ class Optimizer:
|
|||
|
||||
first_test_types = []
|
||||
first_test_functions = []
|
||||
pathlib.Path(
|
||||
get_run_tmp_file(f"test_return_values_{optimization_candidate_index}.bin"),
|
||||
).unlink(missing_ok=True)
|
||||
pathlib.Path(
|
||||
get_run_tmp_file(f"test_return_values_{optimization_candidate_index}.sqlite"),
|
||||
).unlink(missing_ok=True)
|
||||
Path(get_run_tmp_file(f"test_return_values_{optimization_candidate_index}.bin")).unlink(
|
||||
missing_ok=True,
|
||||
)
|
||||
Path(get_run_tmp_file(f"test_return_values_{optimization_candidate_index}.sqlite")).unlink(
|
||||
missing_ok=True,
|
||||
)
|
||||
|
||||
for test_file in instrumented_unittests_created_for_function:
|
||||
relevant_tests_in_file = [
|
||||
|
|
@ -1069,12 +1063,10 @@ class Optimizer:
|
|||
)
|
||||
if best_runtime_until_now is None or total_candidate_timing < best_runtime_until_now:
|
||||
best_test_results = candidate_results
|
||||
pathlib.Path(get_run_tmp_file(f"test_return_values_{optimization_candidate_index}.bin")).unlink(
|
||||
Path(get_run_tmp_file(f"test_return_values_{optimization_candidate_index}.bin")).unlink(
|
||||
missing_ok=True,
|
||||
)
|
||||
pathlib.Path(
|
||||
get_run_tmp_file(f"test_return_values_{optimization_candidate_index}.sqlite"),
|
||||
).unlink(
|
||||
Path(get_run_tmp_file(f"test_return_values_{optimization_candidate_index}.sqlite")).unlink(
|
||||
missing_ok=True,
|
||||
)
|
||||
if not equal_results:
|
||||
|
|
@ -1142,7 +1134,7 @@ class Optimizer:
|
|||
source_code_being_tested: str,
|
||||
function_to_optimize: FunctionToOptimize,
|
||||
helper_function_names: list[str],
|
||||
module_path: str,
|
||||
module_path: Path,
|
||||
function_trace_id: str,
|
||||
) -> GeneratedTestsList | None:
|
||||
futures = [
|
||||
|
|
|
|||
|
|
@ -9,6 +9,8 @@
|
|||
# Licensed under the Apache License, Version 2.0 (the "License").
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib.machinery
|
||||
import io
|
||||
import json
|
||||
|
|
@ -23,20 +25,21 @@ import time
|
|||
from collections import defaultdict
|
||||
from copy import copy
|
||||
from io import StringIO
|
||||
from pathlib import Path
|
||||
from types import FrameType
|
||||
from typing import Any, List, Optional
|
||||
from typing import Any, ClassVar, List, Optional
|
||||
|
||||
import dill
|
||||
import isort
|
||||
|
||||
from codeflash.cli_cmds.cli import project_root_from_module_root
|
||||
from codeflash.cli_cmds.console import console
|
||||
from codeflash.code_utils.code_utils import module_name_from_file_path
|
||||
from codeflash.code_utils.config_parser import parse_config_file
|
||||
from codeflash.discovery.functions_to_optimize import filter_files_optimized
|
||||
from codeflash.tracing.replay_test import create_trace_replay_test
|
||||
from codeflash.tracing.tracing_utils import FunctionModules
|
||||
from codeflash.verification.verification_utils import get_test_file_path
|
||||
from codeflash.cli_cmds.console import console
|
||||
|
||||
|
||||
class Tracer:
|
||||
|
|
@ -49,7 +52,7 @@ class Tracer:
|
|||
output: str = "codeflash.trace",
|
||||
functions: Optional[List[str]] = None,
|
||||
disable: bool = False,
|
||||
config_file_path: Optional[str] = None,
|
||||
config_file_path: Path | None = None,
|
||||
max_function_count: int = 256,
|
||||
timeout: Optional[int] = None, # seconds
|
||||
) -> None:
|
||||
|
|
@ -77,13 +80,14 @@ class Tracer:
|
|||
self.disable = True
|
||||
return
|
||||
self.con = None
|
||||
self.output_file = os.path.abspath(output)
|
||||
self.output_file = Path(output).resolve()
|
||||
self.functions = functions
|
||||
self.function_modules: List[FunctionModules] = []
|
||||
self.function_count = defaultdict(int)
|
||||
self.current_file_path = Path(__file__).resolve()
|
||||
self.ignored_qualified_functions = {
|
||||
f"{os.path.realpath(__file__)}:Tracer:__exit__",
|
||||
f"{os.path.realpath(__file__)}:Tracer:__enter__",
|
||||
f"{self.current_file_path}:Tracer:__exit__",
|
||||
f"{self.current_file_path}:Tracer:__enter__",
|
||||
}
|
||||
self.max_function_count = max_function_count
|
||||
self.config, found_config_path = parse_config_file(config_file_path)
|
||||
|
|
@ -99,13 +103,9 @@ class Tracer:
|
|||
"<lambda>",
|
||||
"<module>",
|
||||
}
|
||||
self.file_being_called_from: str = str(
|
||||
os.path.basename(
|
||||
os.path.realpath(sys._getframe().f_back.f_code.co_filename),
|
||||
).replace(
|
||||
".",
|
||||
"_",
|
||||
),
|
||||
|
||||
self.file_being_called_from: str = str(Path(sys._getframe().f_back.f_code.co_filename).name).replace(
|
||||
".", "_"
|
||||
)
|
||||
|
||||
assert timeout is None or timeout > 0, "Timeout should be greater than 0"
|
||||
|
|
@ -175,7 +175,7 @@ class Tracer:
|
|||
cur.execute(
|
||||
"INSERT INTO pstats VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||
(
|
||||
os.path.realpath(func[0]),
|
||||
Path(func[0]).resolve(),
|
||||
func[1],
|
||||
func[2],
|
||||
func[3],
|
||||
|
|
@ -245,7 +245,7 @@ class Tracer:
|
|||
|
||||
if code.co_name in self.ignored_functions:
|
||||
return
|
||||
if not os.path.exists(file_name):
|
||||
if not Path(file_name).exists():
|
||||
return
|
||||
if self.functions:
|
||||
if code.co_name not in self.functions:
|
||||
|
|
@ -264,7 +264,7 @@ class Tracer:
|
|||
except:
|
||||
# someone can override the getattr method and raise an exception. I'm looking at you wrapt
|
||||
return
|
||||
file_name = os.path.realpath(file_name)
|
||||
file_name = Path(file_name).resolve()
|
||||
function_qualified_name = f"{file_name}:{(class_name + ':' if class_name else '')}{code.co_name}"
|
||||
if function_qualified_name in self.ignored_qualified_functions:
|
||||
return
|
||||
|
|
@ -311,7 +311,7 @@ class Tracer:
|
|||
# We do not pickle self for __init__ to avoid recursion errors, and instead instantiate its class
|
||||
# directly with the rest of the arguments in the replay tests. We copy the arguments to avoid memory
|
||||
# leaks, bad references or side-effects when unpickling.
|
||||
arguments = {k: v for k, v in arguments.items()}
|
||||
arguments = dict(arguments.items())
|
||||
if class_name and code.co_name == "__init__":
|
||||
del arguments["self"]
|
||||
local_vars = pickle.dumps(
|
||||
|
|
@ -463,7 +463,7 @@ class Tracer:
|
|||
|
||||
return 1
|
||||
|
||||
dispatch = {
|
||||
dispatch: ClassVar[dict[str, callable]] = {
|
||||
"call": trace_dispatch_call,
|
||||
"exception": trace_dispatch_exception,
|
||||
"return": trace_dispatch_return,
|
||||
|
|
@ -648,7 +648,7 @@ def main():
|
|||
# The script that we're profiling may chdir, so capture the absolute path
|
||||
# to the output file at startup.
|
||||
if args.outfile is not None:
|
||||
args.outfile = os.path.abspath(args.outfile)
|
||||
args.outfile = Path(args.outfile).resolve()
|
||||
|
||||
if len(unknown_args) > 0:
|
||||
if args.module:
|
||||
|
|
@ -661,7 +661,7 @@ def main():
|
|||
}
|
||||
else:
|
||||
progname = unknown_args[0]
|
||||
sys.path.insert(0, os.path.dirname(progname))
|
||||
sys.path.insert(0, str(Path(progname).parent))
|
||||
with io.open_code(progname) as fp:
|
||||
code = compile(fp.read(), progname, "exec")
|
||||
spec = importlib.machinery.ModuleSpec(name="__main__", loader=None, origin=progname)
|
||||
|
|
|
|||
|
|
@ -1,4 +1,6 @@
|
|||
from codeflash.code_utils.code_extractor import add_needed_imports_from_module
|
||||
from pathlib import Path
|
||||
|
||||
from codeflash.code_utils.code_extractor import add_needed_imports_from_module
|
||||
|
||||
|
||||
def test_add_needed_imports_from_module0() -> None:
|
||||
|
|
@ -46,9 +48,9 @@ class Source:
|
|||
expected = """def heyjude() -> None:
|
||||
print("Hey Jude, don't make it bad")
|
||||
"""
|
||||
src_path = "/home/roger/repos/codeflash/cli/codeflash/optimization/function_context.py"
|
||||
dst_path = "/home/roger/repos/codeflash/cli/codeflash/optimization/function_context.py"
|
||||
project_root = "/home/roger/repos/codeflash"
|
||||
src_path = Path("/home/roger/repos/codeflash/cli/codeflash/optimization/function_context.py")
|
||||
dst_path = Path("/home/roger/repos/codeflash/cli/codeflash/optimization/function_context.py")
|
||||
project_root = Path("/home/roger/repos/codeflash")
|
||||
new_module = add_needed_imports_from_module(
|
||||
src_module,
|
||||
dst_module,
|
||||
|
|
@ -120,9 +122,9 @@ def belongs_to_function(name: Name, function_name: str) -> bool:
|
|||
# The name is defined inside the function or is the function itself
|
||||
return f".{function_name}." in subname or f".{function_name}" == subname
|
||||
'''
|
||||
src_path = "/home/roger/repos/codeflash/cli/codeflash/optimization/function_context.py"
|
||||
dst_path = "/home/roger/repos/codeflash/cli/codeflash/optimization/function_context.py"
|
||||
project_root = "/home/roger/repos/codeflash"
|
||||
src_path = Path("/home/roger/repos/codeflash/cli/codeflash/optimization/function_context.py")
|
||||
dst_path = Path("/home/roger/repos/codeflash/cli/codeflash/optimization/function_context.py")
|
||||
project_root = Path("/home/roger/repos/codeflash")
|
||||
new_module = add_needed_imports_from_module(
|
||||
src_module,
|
||||
dst_module,
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ class JediDefinition:
|
|||
|
||||
@dataclasses.dataclass
|
||||
class FakeFunctionSource:
|
||||
file_path: str
|
||||
file_path: Path
|
||||
qualified_name: str
|
||||
fully_qualified_name: str
|
||||
only_function_name: str
|
||||
|
|
@ -81,11 +81,11 @@ print("Hello world")
|
|||
source_code=original_code,
|
||||
function_names=[function_name],
|
||||
optimized_code=optim_code,
|
||||
file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()),
|
||||
module_abspath=str(Path(__file__).resolve()),
|
||||
file_path_of_module_with_function_to_optimize=Path(__file__).resolve(),
|
||||
module_abspath=Path(__file__).resolve(),
|
||||
preexisting_objects=preexisting_objects,
|
||||
contextual_functions=contextual_functions,
|
||||
project_root_path=str(Path(__file__).resolve().parent.resolve()),
|
||||
project_root_path=Path(__file__).resolve().parent.resolve(),
|
||||
)
|
||||
assert new_code == expected
|
||||
|
||||
|
|
@ -145,11 +145,11 @@ print("Hello world")
|
|||
source_code=original_code,
|
||||
function_names=[function_name],
|
||||
optimized_code=optim_code,
|
||||
file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()),
|
||||
module_abspath=str(Path(__file__).resolve()),
|
||||
file_path_of_module_with_function_to_optimize=Path(__file__).resolve(),
|
||||
module_abspath=Path(__file__).resolve(),
|
||||
preexisting_objects=preexisting_objects,
|
||||
contextual_functions=contextual_functions,
|
||||
project_root_path=str(Path(__file__).resolve().parent.resolve()),
|
||||
project_root_path=Path(__file__).resolve().parent.resolve(),
|
||||
)
|
||||
assert new_code == expected
|
||||
|
||||
|
|
@ -206,11 +206,11 @@ print("Salut monde")
|
|||
source_code=original_code,
|
||||
function_names=function_names,
|
||||
optimized_code=optim_code,
|
||||
file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()),
|
||||
module_abspath=str(Path(__file__).resolve()),
|
||||
file_path_of_module_with_function_to_optimize=Path(__file__).resolve(),
|
||||
module_abspath=Path(__file__).resolve(),
|
||||
preexisting_objects=preexisting_objects,
|
||||
contextual_functions=contextual_functions,
|
||||
project_root_path=str(Path(__file__).resolve().parent.resolve()),
|
||||
project_root_path=Path(__file__).resolve().parent.resolve(),
|
||||
)
|
||||
assert new_code == expected
|
||||
|
||||
|
|
@ -270,11 +270,11 @@ print("Salut monde")
|
|||
source_code=original_code,
|
||||
function_names=function_names,
|
||||
optimized_code=optim_code,
|
||||
file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()),
|
||||
module_abspath=str(Path(__file__).resolve()),
|
||||
file_path_of_module_with_function_to_optimize=Path(__file__).resolve(),
|
||||
module_abspath=Path(__file__).resolve(),
|
||||
preexisting_objects=preexisting_objects,
|
||||
contextual_functions=contextual_functions,
|
||||
project_root_path=str(Path(__file__).resolve().parent.resolve()),
|
||||
project_root_path=Path(__file__).resolve().parent.resolve(),
|
||||
)
|
||||
assert new_code == expected
|
||||
|
||||
|
|
@ -326,11 +326,11 @@ def supersort(doink):
|
|||
source_code=original_code,
|
||||
function_names=function_names,
|
||||
optimized_code=optim_code,
|
||||
file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()),
|
||||
module_abspath=str(Path(__file__).resolve()),
|
||||
file_path_of_module_with_function_to_optimize=Path(__file__).resolve(),
|
||||
module_abspath=Path(__file__).resolve(),
|
||||
preexisting_objects=preexisting_objects,
|
||||
contextual_functions=contextual_functions,
|
||||
project_root_path=str(Path(__file__).resolve().parent.resolve()),
|
||||
project_root_path=Path(__file__).resolve().parent.resolve(),
|
||||
)
|
||||
assert new_code == expected
|
||||
|
||||
|
|
@ -402,11 +402,11 @@ print("Not cool")
|
|||
source_code=original_code_main,
|
||||
function_names=["other_function"],
|
||||
optimized_code=optim_code,
|
||||
file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()),
|
||||
module_abspath=str(Path(__file__).resolve()),
|
||||
file_path_of_module_with_function_to_optimize=Path(__file__).resolve(),
|
||||
module_abspath=Path(__file__).resolve(),
|
||||
preexisting_objects=[("other_function", []), ("yet_another_function", []), ("blob", [])],
|
||||
contextual_functions=set(),
|
||||
project_root_path=str(Path(__file__).resolve().parent.resolve()),
|
||||
project_root_path=Path(__file__).resolve().parent.resolve(),
|
||||
)
|
||||
assert new_main_code == expected_main
|
||||
|
||||
|
|
@ -414,11 +414,11 @@ print("Not cool")
|
|||
source_code=original_code_helper,
|
||||
function_names=["blob"],
|
||||
optimized_code=optim_code,
|
||||
file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()),
|
||||
module_abspath=str(Path(__file__).resolve()),
|
||||
file_path_of_module_with_function_to_optimize=Path(__file__).resolve(),
|
||||
module_abspath=Path(__file__).resolve(),
|
||||
preexisting_objects=[],
|
||||
contextual_functions=set(),
|
||||
project_root_path=str(Path(__file__).resolve().parent.resolve()),
|
||||
project_root_path=Path(__file__).resolve().parent.resolve(),
|
||||
)
|
||||
assert new_helper_code == expected_helper
|
||||
|
||||
|
|
@ -619,11 +619,11 @@ class CacheConfig(BaseConfig):
|
|||
source_code=original_code,
|
||||
function_names=function_names,
|
||||
optimized_code=optim_code,
|
||||
file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()),
|
||||
module_abspath=str(Path(__file__).resolve()),
|
||||
file_path_of_module_with_function_to_optimize=Path(__file__).resolve(),
|
||||
module_abspath=Path(__file__).resolve(),
|
||||
preexisting_objects=preexisting_objects,
|
||||
contextual_functions=contextual_functions,
|
||||
project_root_path=str(Path(__file__).resolve().parent.resolve()),
|
||||
project_root_path=Path(__file__).resolve().parent.resolve(),
|
||||
)
|
||||
assert new_code == expected
|
||||
|
||||
|
|
@ -694,11 +694,11 @@ def test_test_libcst_code_replacement8() -> None:
|
|||
source_code=original_code,
|
||||
function_names=function_names,
|
||||
optimized_code=optim_code,
|
||||
file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()),
|
||||
module_abspath=str(Path(__file__).resolve()),
|
||||
file_path_of_module_with_function_to_optimize=Path(__file__).resolve(),
|
||||
module_abspath=Path(__file__).resolve(),
|
||||
preexisting_objects=preexisting_objects,
|
||||
contextual_functions=contextual_functions,
|
||||
project_root_path=str(Path(__file__).resolve().parent.resolve()),
|
||||
project_root_path=Path(__file__).resolve().parent.resolve(),
|
||||
)
|
||||
assert new_code == expected
|
||||
|
||||
|
|
@ -757,11 +757,11 @@ print("Hello world")
|
|||
source_code=original_code,
|
||||
function_names=[function_name],
|
||||
optimized_code=optim_code,
|
||||
file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()),
|
||||
module_abspath=str(Path(__file__).resolve()),
|
||||
file_path_of_module_with_function_to_optimize=Path(__file__).resolve(),
|
||||
module_abspath=Path(__file__).resolve(),
|
||||
preexisting_objects=preexisting_objects,
|
||||
contextual_functions=contextual_functions,
|
||||
project_root_path=str(Path(__file__).resolve().parent.resolve()),
|
||||
project_root_path=Path(__file__).resolve().parent.resolve(),
|
||||
)
|
||||
assert new_code == expected
|
||||
|
||||
|
|
@ -817,17 +817,16 @@ class MainClass:
|
|||
)
|
||||
func_top_optimize = FunctionToOptimize(
|
||||
function_name="main_method",
|
||||
file_path=str(file_path),
|
||||
file_path=file_path,
|
||||
parents=[FunctionParent("MainClass", "ClassDef")],
|
||||
)
|
||||
with open(file_path) as f:
|
||||
original_code = f.read()
|
||||
code_context = opt.get_code_optimization_context(
|
||||
function_to_optimize=func_top_optimize,
|
||||
project_root=str(file_path.parent),
|
||||
original_source_code=original_code,
|
||||
).unwrap()
|
||||
assert code_context.code_to_optimize_with_helpers == get_code_output
|
||||
original_code = file_path.read_text()
|
||||
code_context = opt.get_code_optimization_context(
|
||||
function_to_optimize=func_top_optimize,
|
||||
project_root=file_path.parent,
|
||||
original_source_code=original_code,
|
||||
).unwrap()
|
||||
assert code_context.code_to_optimize_with_helpers == get_code_output
|
||||
|
||||
|
||||
def test_code_replacement11() -> None:
|
||||
|
|
@ -946,11 +945,11 @@ def test_test_libcst_code_replacement13() -> None:
|
|||
source_code=original_code,
|
||||
function_names=function_names,
|
||||
optimized_code=optim_code,
|
||||
file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()),
|
||||
module_abspath=str(Path(__file__).resolve()),
|
||||
file_path_of_module_with_function_to_optimize=Path(__file__).resolve(),
|
||||
module_abspath=Path(__file__).resolve(),
|
||||
preexisting_objects=preexisting_objects,
|
||||
contextual_functions=contextual_functions,
|
||||
project_root_path=str(Path(__file__).resolve().parent.resolve()),
|
||||
project_root_path=Path(__file__).resolve().parent.resolve(),
|
||||
)
|
||||
assert new_code == original_code
|
||||
|
||||
|
|
@ -1143,7 +1142,9 @@ class TestResults(BaseModel):
|
|||
|
||||
helper_functions = [
|
||||
FakeFunctionSource(
|
||||
file_path="/Users/saurabh/Library/CloudStorage/Dropbox/codeflash/cli/codeflash/verification/test_results.py",
|
||||
file_path=Path(
|
||||
"/Users/saurabh/Library/CloudStorage/Dropbox/codeflash/cli/codeflash/verification/test_results.py"
|
||||
),
|
||||
qualified_name="TestType",
|
||||
fully_qualified_name="codeflash.verification.test_results.TestType",
|
||||
only_function_name="TestType",
|
||||
|
|
@ -1156,11 +1157,11 @@ class TestResults(BaseModel):
|
|||
source_code=original_code,
|
||||
function_names=["TestResults.get_test_pass_fail_report_by_type"],
|
||||
optimized_code=optim_code,
|
||||
file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()),
|
||||
module_abspath=str(Path(__file__).resolve()),
|
||||
file_path_of_module_with_function_to_optimize=Path(__file__).resolve(),
|
||||
module_abspath=Path(__file__).resolve(),
|
||||
preexisting_objects=preexisting_objects,
|
||||
contextual_functions=contextual_functions,
|
||||
project_root_path=str(Path(__file__).parent.resolve()),
|
||||
project_root_path=Path(__file__).parent.resolve(),
|
||||
)
|
||||
|
||||
helper_functions_by_module_abspath = defaultdict(set)
|
||||
|
|
@ -1177,11 +1178,11 @@ class TestResults(BaseModel):
|
|||
source_code=new_code,
|
||||
function_names=list(qualified_names),
|
||||
optimized_code=optim_code,
|
||||
file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()),
|
||||
file_path_of_module_with_function_to_optimize=Path(__file__).resolve(),
|
||||
module_abspath=module_abspath,
|
||||
preexisting_objects=preexisting_objects,
|
||||
contextual_functions=contextual_functions,
|
||||
project_root_path=str(Path(__file__).parent.resolve()),
|
||||
project_root_path=Path(__file__).parent.resolve(),
|
||||
)
|
||||
|
||||
assert (
|
||||
|
|
@ -1361,12 +1362,16 @@ def cosine_similarity_top_k(
|
|||
|
||||
return ret_idxs, scores
|
||||
'''
|
||||
preexisting_objects = [("cosine_similarity_top_k", []), ("Matrix", []), ("cosine_similarity", [])]
|
||||
preexisting_objects: list[tuple[str, list[FunctionParent]]] = [
|
||||
("cosine_similarity_top_k", []),
|
||||
("Matrix", []),
|
||||
("cosine_similarity", []),
|
||||
]
|
||||
|
||||
contextual_functions = set()
|
||||
contextual_functions: set[tuple[str, str]] = set()
|
||||
helper_functions = [
|
||||
FakeFunctionSource(
|
||||
file_path=str((Path(__file__).parent / "code_to_optimize" / "math_utils.py").resolve()),
|
||||
file_path=(Path(__file__).parent / "code_to_optimize" / "math_utils.py").resolve(),
|
||||
qualified_name="Matrix",
|
||||
fully_qualified_name="code_to_optimize.math_utils.Matrix",
|
||||
only_function_name="Matrix",
|
||||
|
|
@ -1374,7 +1379,7 @@ def cosine_similarity_top_k(
|
|||
jedi_definition=JediDefinition(type="class"),
|
||||
),
|
||||
FakeFunctionSource(
|
||||
file_path=str((Path(__file__).parent / "code_to_optimize" / "math_utils.py").resolve()),
|
||||
file_path=(Path(__file__).parent / "code_to_optimize" / "math_utils.py").resolve(),
|
||||
qualified_name="cosine_similarity",
|
||||
fully_qualified_name="code_to_optimize.math_utils.cosine_similarity",
|
||||
only_function_name="cosine_similarity",
|
||||
|
|
@ -1387,11 +1392,11 @@ def cosine_similarity_top_k(
|
|||
source_code=original_code,
|
||||
function_names=["cosine_similarity_top_k"],
|
||||
optimized_code=optim_code,
|
||||
file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()),
|
||||
module_abspath=str((Path(__file__).parent / "code_to_optimize").resolve()),
|
||||
file_path_of_module_with_function_to_optimize=Path(__file__).resolve(),
|
||||
module_abspath=(Path(__file__).parent / "code_to_optimize").resolve(),
|
||||
preexisting_objects=preexisting_objects,
|
||||
contextual_functions=contextual_functions,
|
||||
project_root_path=str(Path(__file__).parent.parent.resolve()),
|
||||
project_root_path=Path(__file__).parent.parent.resolve(),
|
||||
)
|
||||
assert (
|
||||
new_code
|
||||
|
|
@ -1452,11 +1457,11 @@ def cosine_similarity_top_k(
|
|||
source_code=new_code,
|
||||
function_names=list(qualified_names),
|
||||
optimized_code=optim_code,
|
||||
file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()),
|
||||
file_path_of_module_with_function_to_optimize=Path(__file__).resolve(),
|
||||
module_abspath=module_abspath,
|
||||
preexisting_objects=preexisting_objects,
|
||||
contextual_functions=contextual_functions,
|
||||
project_root_path=str(Path(__file__).parent.parent.resolve()),
|
||||
project_root_path=Path(__file__).parent.parent.resolve(),
|
||||
)
|
||||
|
||||
assert (
|
||||
|
|
|
|||
92
tests/test_code_utils.py
Normal file
92
tests/test_code_utils.py
Normal file
|
|
@ -0,0 +1,92 @@
|
|||
import ast
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from codeflash.code_utils.code_utils import get_imports_from_file, module_name_from_file_path
|
||||
|
||||
|
||||
# tests for module_name_from_file_path
|
||||
def test_module_name_from_file_path() -> None:
|
||||
project_root_path = Path("/Users/codeflashuser/PycharmProjects/codeflash")
|
||||
file_path = project_root_path / "cli/codeflash/code_utils/code_utils.py"
|
||||
|
||||
module_name = module_name_from_file_path(file_path, project_root_path)
|
||||
assert module_name == "cli.codeflash.code_utils.code_utils"
|
||||
|
||||
|
||||
def test_module_name_from_file_path_with_subdirectory() -> None:
|
||||
project_root_path = Path("/Users/codeflashuser/PycharmProjects/codeflash")
|
||||
file_path = project_root_path / "cli/codeflash/code_utils/subdir/code_utils.py"
|
||||
|
||||
module_name = module_name_from_file_path(file_path, project_root_path)
|
||||
assert module_name == "cli.codeflash.code_utils.subdir.code_utils"
|
||||
|
||||
|
||||
def test_module_name_from_file_path_with_different_root() -> None:
|
||||
project_root_path = Path("/Users/codeflashuser/PycharmProjects")
|
||||
file_path = project_root_path / "codeflash/cli/codeflash/code_utils/code_utils.py"
|
||||
|
||||
module_name = module_name_from_file_path(file_path, project_root_path)
|
||||
assert module_name == "codeflash.cli.codeflash.code_utils.code_utils"
|
||||
|
||||
|
||||
def test_module_name_from_file_path_with_root_as_file() -> None:
|
||||
project_root_path = Path("/Users/codeflashuser/PycharmProjects/codeflash/cli/codeflash/code_utils")
|
||||
file_path = project_root_path / "code_utils.py"
|
||||
|
||||
module_name = module_name_from_file_path(file_path, project_root_path)
|
||||
assert module_name == "code_utils"
|
||||
|
||||
|
||||
# def test_get_imports_from_file_with_file_path(tmp_path: Path):
|
||||
# test_file = tmp_path / "test_file.py"
|
||||
# test_file.write_text("import os\nfrom sys import path\n")
|
||||
|
||||
# imports = get_imports_from_file(file_path=test_file)
|
||||
# assert len(imports) == 2
|
||||
# assert isinstance(imports[0], ast.Import)
|
||||
# assert isinstance(imports[1], ast.ImportFrom)
|
||||
# assert imports[0].names[0].name == "os"
|
||||
# assert imports[1].module == "sys"
|
||||
# assert imports[1].names[0].name == "path"
|
||||
|
||||
|
||||
# def test_get_imports_from_file_with_file_string():
|
||||
# file_string = "import os\nfrom sys import path\n"
|
||||
|
||||
# imports = get_imports_from_file(file_string=file_string)
|
||||
# assert len(imports) == 2
|
||||
# assert isinstance(imports[0], ast.Import)
|
||||
# assert isinstance(imports[1], ast.ImportFrom)
|
||||
# assert imports[0].names[0].name == "os"
|
||||
# assert imports[1].module == "sys"
|
||||
# assert imports[1].names[0].name == "path"
|
||||
|
||||
|
||||
# def test_get_imports_from_file_with_file_ast():
|
||||
# file_string = "import os\nfrom sys import path\n"
|
||||
# file_ast = ast.parse(file_string)
|
||||
|
||||
# imports = get_imports_from_file(file_ast=file_ast)
|
||||
# assert len(imports) == 2
|
||||
# assert isinstance(imports[0], ast.Import)
|
||||
# assert isinstance(imports[1], ast.ImportFrom)
|
||||
# assert imports[0].names[0].name == "os"
|
||||
# assert imports[1].module == "sys"
|
||||
# assert imports[1].names[0].name == "path"
|
||||
|
||||
|
||||
# def test_get_imports_from_file_with_syntax_error(caplog):
|
||||
# file_string = "import os\nfrom sys import path\ninvalid syntax"
|
||||
|
||||
# imports = get_imports_from_file(file_string=file_string)
|
||||
# assert len(imports) == 0
|
||||
# assert "Syntax error in code" in caplog.text
|
||||
|
||||
|
||||
# def test_get_imports_from_file_with_no_input():
|
||||
# with pytest.raises(
|
||||
# AssertionError, match="Must provide exactly one of file_path, file_string, or file_ast"
|
||||
# ):
|
||||
# get_imports_from_file()
|
||||
|
|
@ -1,7 +1,9 @@
|
|||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from codeflash.code_utils.config_parser import parse_config_file
|
||||
from codeflash.code_utils.formatter import format_code, sort_imports
|
||||
|
||||
|
|
@ -34,12 +36,13 @@ def test_sort_imports_without_formatting():
|
|||
with tempfile.NamedTemporaryFile() as tmp:
|
||||
tmp.write(b"import sys\nimport unittest\nimport os\n")
|
||||
tmp.flush()
|
||||
tmp_path = tmp.name
|
||||
tmp_path = Path(tmp.name)
|
||||
|
||||
new_code = format_code(
|
||||
formatter_cmds=["disabled"],
|
||||
path=tmp_path,
|
||||
)
|
||||
assert new_code is not None
|
||||
new_code = sort_imports(new_code)
|
||||
assert new_code == "import os\nimport sys\nimport unittest\n"
|
||||
|
||||
|
|
@ -108,7 +111,7 @@ ignore-paths = []
|
|||
with tempfile.NamedTemporaryFile(suffix=".toml", delete=False) as tmp:
|
||||
tmp.write(config_data.encode())
|
||||
tmp.flush()
|
||||
tmp_path = tmp.name
|
||||
tmp_path = Path(tmp.name)
|
||||
|
||||
try:
|
||||
config, _ = parse_config_file(tmp_path)
|
||||
|
|
@ -135,7 +138,7 @@ def foo():
|
|||
|
||||
actual = format_code(
|
||||
formatter_cmds=["black $file"],
|
||||
path=tmp_path,
|
||||
path=Path(tmp_path),
|
||||
)
|
||||
assert actual == expected
|
||||
|
||||
|
|
@ -160,7 +163,7 @@ def foo():
|
|||
|
||||
actual = format_code(
|
||||
formatter_cmds=["black $file"],
|
||||
path=tmp_path,
|
||||
path=Path(tmp_path),
|
||||
)
|
||||
assert actual == expected
|
||||
|
||||
|
|
@ -189,6 +192,6 @@ def foo():
|
|||
|
||||
actual = format_code(
|
||||
formatter_cmds=["ruff check --exit-zero --fix $file", "ruff format $file"],
|
||||
path=tmp_path,
|
||||
path=Path(tmp_path),
|
||||
)
|
||||
assert actual == expected
|
||||
|
|
|
|||
Loading…
Reference in a new issue