This commit is contained in:
Kevin Turcios 2024-10-12 17:29:15 -05:00
parent 27b9fd8a4d
commit b45cc87270
19 changed files with 473 additions and 391 deletions

3
.vscode/settings.json vendored Normal file
View file

@ -0,0 +1,3 @@
{
"python.analysis.typeCheckingMode": "basic"
}

View file

@ -1,13 +1,14 @@
import logging
import os
import sys
from argparse import SUPPRESS, ArgumentParser, Namespace
from pathlib import Path
import git
from codeflash.cli_cmds import logging_config
from codeflash.cli_cmds.cli_common import apologize_and_exit
from codeflash.cli_cmds.cmd_init import init_codeflash, install_github_actions
from codeflash.cli_cmds.console import logger
from codeflash.code_utils import env_utils
from codeflash.code_utils.config_parser import parse_config_file
from codeflash.code_utils.git_utils import (
@ -17,7 +18,6 @@ from codeflash.code_utils.git_utils import (
get_repo_owner_and_name,
)
from codeflash.code_utils.github_utils import get_github_secrets_page_url, require_github_app_or_exit
from codeflash.cli_cmds.console import logger
from codeflash.version import __version__ as version
@ -104,18 +104,18 @@ def process_and_validate_cmd_args(args: Namespace) -> Namespace:
logger.error("If you specify a --function, you must specify the --file it is in")
sys.exit(1)
if args.file:
if not os.path.exists(args.file):
if not Path(args.file).exists():
logger.error(f"File {args.file} does not exist")
sys.exit(1)
args.file = os.path.realpath(args.file)
args.file = Path(args.file).resolve()
if not args.no_pr:
owner, repo = get_repo_owner_and_name()
require_github_app_or_exit(owner, repo)
if args.replay_test:
if not os.path.isfile(args.replay_test):
if not Path(args.replay_test).is_file():
logger.error(f"Replay test file {args.replay_test} does not exist")
sys.exit(1)
args.replay_test = os.path.realpath(args.replay_test)
args.replay_test = Path(args.replay_test).resolve()
return args
@ -142,11 +142,11 @@ def process_pyproject_config(args: Namespace) -> Namespace:
or not hasattr(args, key.replace("-", "_"))
):
setattr(args, key.replace("-", "_"), pyproject_config[key])
assert args.module_root is not None and os.path.isdir(
args.module_root,
assert (
args.module_root is not None and Path(args.module_root).is_dir()
), f"--module-root {args.module_root} must be a valid directory"
assert args.tests_root is not None and os.path.isdir(
args.tests_root,
assert (
args.tests_root is not None and Path(args.tests_root).is_dir()
), f"--tests-root {args.tests_root} must be a valid directory"
assert not (env_utils.get_pr_number() is not None and not env_utils.ensure_codeflash_api_key()), (
@ -160,24 +160,23 @@ def process_pyproject_config(args: Namespace) -> Namespace:
if hasattr(args, "ignore_paths") and args.ignore_paths is not None:
normalized_ignore_paths = []
for path in args.ignore_paths:
assert os.path.exists(
path,
), f"ignore-paths config must be a valid path. Path {path} does not exist"
normalized_ignore_paths.append(os.path.realpath(path))
path_obj = Path(path)
assert path_obj.exists(), f"ignore-paths config must be a valid path. Path {path} does not exist"
normalized_ignore_paths.append(path_obj.resolve())
args.ignore_paths = normalized_ignore_paths
# Project root path is one level above the specified directory, because that's where the module can be imported from
args.module_root = os.path.realpath(args.module_root)
args.module_root = Path(args.module_root).resolve()
# If module-root is "." then all imports are relative to it.
# in this case, the ".." becomes outside project scope, causing issues with un-importable paths
args.project_root = project_root_from_module_root(args.module_root, pyproject_file_path)
args.tests_root = os.path.realpath(args.tests_root)
args.tests_root = Path(args.tests_root).resolve()
return handle_optimize_all_arg_parsing(args)
def project_root_from_module_root(module_root: str, pyproject_file_path: str) -> str:
if os.path.dirname(pyproject_file_path) == module_root:
def project_root_from_module_root(module_root: Path, pyproject_file_path: Path) -> Path:
if pyproject_file_path.parent == module_root:
return module_root
return os.path.realpath(os.path.join(module_root, ".."))
return module_root.parent.resolve()
def handle_optimize_all_arg_parsing(args: Namespace) -> Namespace:
@ -203,5 +202,5 @@ def handle_optimize_all_arg_parsing(args: Namespace) -> Namespace:
# The default behavior of --all is to optimize everything in args.module_root
args.all = args.module_root
else:
args.all = os.path.realpath(args.all)
args.all = Path(args.all).resolve()
return args

View file

@ -2,11 +2,11 @@ from __future__ import annotations
import ast
import os
import pathlib
import re
import subprocess
import sys
from argparse import Namespace
from pathlib import Path
from typing import Optional
import click
@ -107,7 +107,7 @@ def ask_run_end_to_end_test(args: Namespace) -> None:
def collect_setup_info() -> SetupInfo:
curdir = os.getcwd()
curdir = Path.cwd()
# Check if the cwd is writable
if not os.access(curdir, os.W_OK):
click.echo(
@ -170,20 +170,22 @@ def collect_setup_info() -> SetupInfo:
)
if tests_root_answer == create_for_me_option:
tests_root = os.path.join(curdir, default_tests_subdir)
os.mkdir(tests_root)
click.echo(f"✅ Created directory {tests_root}{os.pathsep}{LF}")
tests_root = Path(curdir) / default_tests_subdir
tests_root.mkdir()
click.echo(f"✅ Created directory {tests_root}{os.path.sep}{LF}")
elif tests_root_answer == custom_dir_option:
custom_tests_root_answer = inquirer_wrapper_path(
"path",
message=f"Enter the path to your tests directory inside {os.path.abspath('.') + os.path.sep} ",
message=f"Enter the path to your tests directory inside {Path(curdir).resolve()}{os.path.sep} ",
path_type=inquirer.Path.DIRECTORY,
exists=True,
)
tests_root = custom_tests_root_answer["path"] if custom_tests_root_answer else apologize_and_exit()
tests_root = (
Path(custom_tests_root_answer["path"]) if custom_tests_root_answer else apologize_and_exit()
)
else:
tests_root = tests_root_answer
tests_root = os.path.relpath(tests_root, curdir)
tests_root = Path(tests_root_answer)
tests_root = tests_root.relative_to(curdir)
ph("cli-tests-root-provided")
# Autodiscover test framework
@ -224,7 +226,7 @@ def collect_setup_info() -> SetupInfo:
)
def detect_test_framework(curdir, tests_root) -> Optional[str]:
def detect_test_framework(curdir: Path, tests_root: Path) -> Optional[str]:
test_framework = None
pytest_files = ["pytest.ini", "pyproject.toml", "tox.ini", "setup.cfg"]
pytest_config_patterns = {
@ -234,9 +236,9 @@ def detect_test_framework(curdir, tests_root) -> Optional[str]:
"setup.cfg": "[tool:pytest]",
}
for pytest_file in pytest_files:
file_path = os.path.join(curdir, pytest_file)
if os.path.exists(file_path):
with open(file_path, encoding="utf8") as file:
file_path = curdir / pytest_file
if file_path.exists():
with file_path.open(encoding="utf8") as file:
contents = file.read()
if pytest_config_patterns[pytest_file] in contents:
test_framework = "pytest"
@ -244,9 +246,9 @@ def detect_test_framework(curdir, tests_root) -> Optional[str]:
test_framework = "pytest"
else:
# Check if any python files contain a class that inherits from unittest.TestCase
for filename in os.listdir(tests_root):
if filename.endswith(".py"):
with open(os.path.join(tests_root, filename), encoding="utf8") as file:
for filename in tests_root.iterdir():
if filename.suffix == ".py":
with filename.open(encoding="utf8") as file:
contents = file.read()
try:
node = ast.parse(contents)
@ -271,14 +273,13 @@ def detect_test_framework(curdir, tests_root) -> Optional[str]:
def check_for_toml_or_setup_file() -> Optional[str]:
click.echo()
click.echo("Checking for pyproject.toml or setup.py ...\r", nl=False)
curdir = os.getcwd()
pyproject_toml_path = os.path.join(curdir, "pyproject.toml")
setup_py_path = os.path.join(curdir, "setup.py")
curdir = Path.cwd()
pyproject_toml_path = curdir / "pyproject.toml"
setup_py_path = curdir / "setup.py"
project_name = None
if os.path.exists(pyproject_toml_path):
if pyproject_toml_path.exists():
try:
with open(pyproject_toml_path, encoding="utf8") as f:
pyproject_toml_content = f.read()
pyproject_toml_content = pyproject_toml_path.read_text(encoding="utf8")
project_name = tomlkit.parse(pyproject_toml_content)["tool"]["poetry"]["name"]
click.echo(f"✅ I found a pyproject.toml for your project {project_name}.")
ph("cli-pyproject-toml-found-name")
@ -286,9 +287,8 @@ def check_for_toml_or_setup_file() -> Optional[str]:
click.echo("✅ I found a pyproject.toml for your project.")
ph("cli-pyproject-toml-found")
else:
if os.path.exists(setup_py_path):
with open(setup_py_path, encoding="utf8") as f:
setup_py_content = f.read()
if setup_py_path.exists():
setup_py_content = setup_py_path.read_text(encoding="utf8")
project_name_match = re.search(
r"setup\s*\([^)]*?name\s*=\s*['\"](.*?)['\"]",
setup_py_content,
@ -321,11 +321,10 @@ def check_for_toml_or_setup_file() -> Optional[str]:
new_pyproject_toml = tomlkit.document()
new_pyproject_toml["tool"] = {"codeflash": {}}
try:
with open(pyproject_toml_path, "w", encoding="utf8") as pyproject_file:
pyproject_file.write(tomlkit.dumps(new_pyproject_toml))
pyproject_toml_path.write_text(tomlkit.dumps(new_pyproject_toml), encoding="utf8")
# Check if the pyproject.toml file was created
if os.path.exists(pyproject_toml_path):
if pyproject_toml_path.exists():
click.echo(
f"✅ Created a pyproject.toml file at {pyproject_toml_path}",
)
@ -356,9 +355,9 @@ def install_github_actions() -> None:
owner, repo_name = get_repo_owner_and_name(repo)
require_github_app_or_exit(owner, repo_name)
git_root = repo.git.rev_parse("--show-toplevel")
workflows_path = os.path.join(git_root, ".github", "workflows")
optimize_yaml_path = os.path.join(workflows_path, "codeflash-optimize.yaml")
git_root = Path(repo.git.rev_parse("--show-toplevel"))
workflows_path = git_root / ".github" / "workflows"
optimize_yaml_path = workflows_path / "codeflash-optimize.yaml"
confirm_creation_yes = inquirer_wrapper(
inquirer.confirm,
@ -373,7 +372,7 @@ def install_github_actions() -> None:
click.echo("⏩️ Exiting workflow creation.")
ph("cli-github-workflow-skipped")
apologize_and_exit()
os.makedirs(workflows_path, exist_ok=True)
workflows_path.mkdir(parents=True, exist_ok=True)
from importlib.resources import files
py_version = sys.version_info
@ -387,13 +386,13 @@ def install_github_actions() -> None:
"{{ python_version }}",
python_version_string,
)
with open(optimize_yaml_path, "w", encoding="utf8") as optimize_yml_file:
with optimize_yaml_path.open("w", encoding="utf8") as optimize_yml_file:
optimize_yml_file.write(optimize_yml_content)
click.echo(f"✅ Created {optimize_yaml_path}{LF}")
click.prompt(
f"Next, you'll need to add your CODEFLASH_API_KEY as a secret to your GitHub repo.{LF}"
+ f"Press Enter to open your repo's secrets page at {get_github_secrets_page_url(repo)} ...{LF}"
+ f"Then, click 'New repository secret' to add your api key with the variable name CODEFLASH_API_KEY.{LF}",
f"Press Enter to open your repo's secrets page at {get_github_secrets_page_url(repo)} ...{LF}"
f"Then, click 'New repository secret' to add your api key with the variable name CODEFLASH_API_KEY.{LF}",
default="",
type=click.STRING,
prompt_suffix="",
@ -402,7 +401,7 @@ def install_github_actions() -> None:
click.launch(get_github_secrets_page_url(repo))
click.echo(
"🐙 I opened your Github secrets page! Note: if you see a 404, you probably don't have access to this "
+ "repo's secrets; ask a repo admin to add it for you, or (not super recommended) you can temporarily "
"repo's secrets; ask a repo admin to add it for you, or (not super recommended) you can temporarily "
f"hard-code your api key into the workflow file.{LF}",
)
click.pause()
@ -410,7 +409,7 @@ def install_github_actions() -> None:
click.prompt(
f"Finally, for the workflow to work, you'll need to edit the workflow file to install the right "
f"Python version and any project dependencies.{LF}"
+ f"Press Enter to open {optimize_yaml_path} in your editor.{LF}",
f"Press Enter to open {optimize_yaml_path} in your editor.{LF}",
default="",
type=click.STRING,
prompt_suffix="",
@ -419,7 +418,7 @@ def install_github_actions() -> None:
click.launch(optimize_yaml_path)
click.echo(
"📝 I opened the workflow file in your editor! You'll need to edit the steps that install the right Python "
+ f"version and any project dependencies. See the comments in the file for more details.{LF}",
f"version and any project dependencies. See the comments in the file for more details.{LF}",
)
click.pause()
click.echo()
@ -434,9 +433,9 @@ def install_github_actions() -> None:
# Create or update the pyproject.toml file with the Codeflash dependency & configuration
def configure_pyproject_toml(setup_info: SetupInfo) -> None:
toml_path = os.path.join(os.getcwd(), "pyproject.toml")
toml_path = Path.cwd() / "pyproject.toml"
try:
with open(toml_path, encoding="utf8") as pyproject_file:
with toml_path.open(encoding="utf8") as pyproject_file:
pyproject_data = tomlkit.parse(pyproject_file.read())
except FileNotFoundError:
click.echo(
@ -470,7 +469,7 @@ def configure_pyproject_toml(setup_info: SetupInfo) -> None:
pyproject_data["tool"] = tool_section
click.echo("Writing Codeflash configuration ...\r", nl=False)
with open(toml_path, "w", encoding="utf8") as pyproject_file:
with toml_path.open("w", encoding="utf8") as pyproject_file:
pyproject_file.write(tomlkit.dumps(pyproject_data))
click.echo(f"✅ Added Codeflash configuration to {toml_path}")
click.echo()
@ -488,8 +487,8 @@ def install_github_app() -> None:
else:
click.prompt(
f"Finally, you'll need to install the Codeflash GitHub app by choosing the repository you want to install Codeflash on.{LF}"
+ f"I will attempt to open the github app page - https://github.com/apps/codeflash-ai/installations/select_target {LF}"
+ f"Press Enter to open the page to let you install the app ...{LF}",
f"I will attempt to open the github app page - https://github.com/apps/codeflash-ai/installations/select_target {LF}"
f"Press Enter to open the page to let you install the app ...{LF}",
default="",
type=click.STRING,
prompt_suffix="",
@ -601,7 +600,7 @@ def enter_api_key_and_save_to_rc() -> None:
os.environ["CODEFLASH_API_KEY"] = api_key
def create_bubble_sort_file_and_test(args: Namespace) -> None:
def create_bubble_sort_file_and_test(args: Namespace) -> tuple[str, str]:
bubble_sort_content = """def sorter(arr):
for i in range(len(arr)):
for j in range(len(arr) - 1):
@ -613,10 +612,10 @@ def create_bubble_sort_file_and_test(args: Namespace) -> None:
"""
if args.test_framework == "unittest":
bubble_sort_test_content = f"""import unittest
from {os.path.basename(args.module_root)}.bubble_sort import sorter
from {Path(args.module_root).name}.bubble_sort import sorter
class TestBubbleSort(unittest.TestCase):
def test_sort(self):
class TestBubbleSort(unittest.TestCase):
def test_sort(self):
input = [5, 4, 3, 2, 1, 0]
output = sorter(input)
self.assertEqual(output, [0, 1, 2, 3, 4, 5])
@ -628,9 +627,9 @@ class TestBubbleSort(unittest.TestCase):
input = list(reversed(range(100)))
output = sorter(input)
self.assertEqual(output, list(range(100)))
"""
"""
elif args.test_framework == "pytest":
bubble_sort_test_content = f"""from {os.path.basename(args.module_root)}.bubble_sort import sorter
bubble_sort_test_content = f"""from {Path(args.module_root).name}.bubble_sort import sorter
def test_sort():
input = [5, 4, 3, 2, 1, 0]
@ -648,15 +647,17 @@ def test_sort():
else:
click.echo(f"❌ Unsupported test framework: {args.test_framework}")
apologize_and_exit()
bubble_sort_path = os.path.join(args.module_root, "bubble_sort.py")
with open(bubble_sort_path, "w", encoding="utf8") as bubble_sort_file:
bubble_sort_file.write(bubble_sort_content)
bubble_sort_test_path = os.path.join(args.tests_root, "test_bubble_sort.py")
with open(bubble_sort_test_path, "w", encoding="utf8") as bubble_sort_test_file:
bubble_sort_test_file.write(bubble_sort_test_content)
bubble_sort_path = Path(args.module_root) / "bubble_sort.py"
bubble_sort_path.write_text(bubble_sort_content, encoding="utf8")
bubble_sort_test_path = Path(args.tests_root) / "test_bubble_sort.py"
bubble_sort_test_path.write_text(bubble_sort_test_content, encoding="utf8")
click.echo(f"✅ Created {bubble_sort_path}")
click.echo(f"✅ Created {bubble_sort_test_path}")
return bubble_sort_path, bubble_sort_test_path
return str(bubble_sort_path), str(bubble_sort_test_path)
def run_end_to_end_test(args: Namespace, bubble_sort_path: str, bubble_sort_test_path: str) -> None:

View file

@ -1,6 +1,7 @@
from __future__ import annotations
import ast
from pathlib import Path
from typing import TYPE_CHECKING
import libcst as cst
@ -9,8 +10,8 @@ from libcst.codemod import CodemodContext
from libcst.codemod.visitors import AddImportsVisitor, GatherImportsVisitor, RemoveImportsVisitor
from libcst.helpers import calculate_module_and_package
from codeflash.discovery.functions_to_optimize import FunctionParent
from codeflash.cli_cmds.console import logger
from codeflash.discovery.functions_to_optimize import FunctionParent
if TYPE_CHECKING:
from libcst.helpers import ModuleNameAndPackage
@ -43,9 +44,9 @@ def delete___future___aliased_imports(module_code: str) -> str:
def add_needed_imports_from_module(
src_module_code: str,
dst_module_code: str,
src_path: str,
dst_path: str,
project_root: str,
src_path: Path,
dst_path: Path,
project_root: Path,
helper_functions: list[FunctionSource] | None = None,
) -> str:
"""Add all needed and used source module code imports to the destination module code, and return it."""
@ -57,13 +58,13 @@ def add_needed_imports_from_module(
dst_module_and_package: ModuleNameAndPackage = calculate_module_and_package(project_root, dst_path)
dst_context: CodemodContext = CodemodContext(
filename=src_path,
filename=src_path.name,
full_module_name=dst_module_and_package.name,
full_package_name=dst_module_and_package.package,
)
gatherer: GatherImportsVisitor = GatherImportsVisitor(
CodemodContext(
filename=src_path,
filename=src_path.name,
full_module_name=src_module_and_package.name,
full_package_name=src_module_and_package.package,
),
@ -129,7 +130,7 @@ def get_code(
):
return None, set()
file_path: str = functions_to_optimize[0].file_path
file_path: Path = functions_to_optimize[0].file_path
class_skeleton: set[tuple[int, int | None]] = set()
contextual_dunder_methods: set[tuple[str, str]] = set()
target_code: str = ""

View file

@ -1,7 +1,7 @@
from __future__ import annotations
import ast
from typing import IO
from typing import IO, TYPE_CHECKING
import libcst as cst
from libcst import FunctionDef
@ -9,6 +9,9 @@ from libcst import FunctionDef
from codeflash.code_utils.code_extractor import add_needed_imports_from_module
from codeflash.discovery.functions_to_optimize import FunctionParent
if TYPE_CHECKING:
from pathlib import Path
class OptimFunctionCollector(cst.CSTVisitor):
METADATA_DEPENDENCIES = (cst.metadata.ParentNodeProvider,)
@ -200,11 +203,11 @@ def replace_functions_and_add_imports(
source_code: str,
function_names: list[str],
optimized_code: str,
file_path_of_module_with_function_to_optimize: str,
module_abspath: str,
file_path_of_module_with_function_to_optimize: Path,
module_abspath: Path,
preexisting_objects: list[tuple[str, list[FunctionParent]]],
contextual_functions: set[tuple[str, str]],
project_root_path: str,
project_root_path: Path,
) -> str:
return add_needed_imports_from_module(
optimized_code,
@ -224,11 +227,11 @@ def replace_functions_and_add_imports(
def replace_function_definitions_in_module(
function_names: list[str],
optimized_code: str,
file_path_of_module_with_function_to_optimize: str,
module_abspath: str,
file_path_of_module_with_function_to_optimize: Path,
module_abspath: Path,
preexisting_objects: list[tuple[str, list[FunctionParent]]],
contextual_functions: set[tuple[str, str]],
project_root_path: str,
project_root_path: Path,
) -> bool:
""":param function_names: List of qualified (not fully qualified) function names (function_name or
class_name.method_name).
@ -241,8 +244,7 @@ def replace_function_definitions_in_module(
:return:
"""
file: IO[str]
with open(module_abspath, encoding="utf8") as file:
source_code: str = file.read()
source_code: str = module_abspath.read_text(encoding="utf8")
new_code: str = replace_functions_and_add_imports(
source_code,
function_names,
@ -255,8 +257,7 @@ def replace_function_definitions_in_module(
)
if is_zero_diff(source_code, new_code):
return False
with open(module_abspath, "w", encoding="utf8") as file:
file.write(new_code)
module_abspath.write_text(new_code, encoding="utf8")
return True

View file

@ -1,27 +1,26 @@
from __future__ import annotations
import ast
from codeflash.cli_cmds.console import logger
import os
import site
from tempfile import TemporaryDirectory
def module_name_from_file_path(file_path: str, project_root_path: str) -> str:
relative_path = os.path.relpath(file_path, project_root_path)
module_path = relative_path.replace(os.sep, ".")
if module_path.lower().endswith(".py"):
module_path = module_path[:-3]
return module_path
from codeflash.cli_cmds.console import logger
from pathlib import Path
def file_path_from_module_name(module_name: str, project_root_path: str) -> str:
"""Get file path from module path"""
return os.path.join(project_root_path, module_name.replace(".", os.sep) + ".py")
def module_name_from_file_path(file_path: Path, project_root_path: Path) -> str:
relative_path = file_path.relative_to(project_root_path)
return relative_path.with_suffix("").as_posix().replace("/", ".")
def file_path_from_module_name(module_name: str, project_root_path: Path) -> Path:
"""Get file path from module path."""
return project_root_path / (module_name.replace(".", os.sep) + ".py")
def get_imports_from_file(
file_path: str | None = None,
file_path: Path | None = None,
file_string: str | None = None,
file_ast: ast.AST | None = None,
) -> list[ast.Import | ast.ImportFrom]:
@ -29,19 +28,18 @@ def get_imports_from_file(
sum([file_path is not None, file_string is not None, file_ast is not None]) == 1
), "Must provide exactly one of file_path, file_string, or file_ast"
if file_path:
with open(file_path, encoding="utf8") as file:
with file_path.open(encoding="utf8") as file:
file_string = file.read()
if file_ast is None:
if file_string is None:
logger.error("file_string cannot be None when file_ast is not provided")
return []
try:
file_ast = ast.parse(file_string)
except SyntaxError as e:
logger.exception(f"Syntax error in code: {e}")
return []
imports = []
for node in ast.walk(file_ast):
if isinstance(node, (ast.Import, ast.ImportFrom)):
imports.append(node)
return imports
return [node for node in ast.walk(file_ast) if isinstance(node, (ast.Import, ast.ImportFrom))]
def get_all_function_names(code: str) -> tuple[bool, list[str]]:
@ -51,34 +49,27 @@ def get_all_function_names(code: str) -> tuple[bool, list[str]]:
logger.exception(f"Syntax error in code: {e}")
return False, []
function_names = []
for node in ast.walk(module):
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
function_names.append(node.name)
function_names = [
node.name for node in ast.walk(module) if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
]
return True, function_names
def get_run_tmp_file(file_path: str) -> str:
def get_run_tmp_file(file_path: Path) -> Path:
if not hasattr(get_run_tmp_file, "tmpdir"):
get_run_tmp_file.tmpdir = TemporaryDirectory(prefix="codeflash_")
return os.path.join(get_run_tmp_file.tmpdir.name, file_path)
return Path(get_run_tmp_file.tmpdir.name) / file_path
def path_belongs_to_site_packages(file_path: str) -> bool:
site_packages = site.getsitepackages()
for site_package_path in site_packages:
if file_path.startswith(site_package_path + os.sep):
return True
return False
def path_belongs_to_site_packages(file_path: Path) -> bool:
site_packages = [Path(p) for p in site.getsitepackages()]
return any(file_path.resolve().is_relative_to(site_package_path) for site_package_path in site_packages)
def is_class_defined_in_file(class_name: str, file_path: str) -> bool:
if not os.path.exists(file_path):
def is_class_defined_in_file(class_name: str, file_path: Path) -> bool:
if not file_path.exists():
return False
with open(file_path) as file:
with file_path.open(encoding="utf8") as file:
source = file.read()
tree = ast.parse(source)
for node in ast.walk(tree):
if isinstance(node, ast.ClassDef) and node.name == class_name:
return True
return False
return any(isinstance(node, ast.ClassDef) and node.name == class_name for node in ast.walk(tree))

View file

@ -1,56 +1,53 @@
import os
from __future__ import annotations
from pathlib import Path
from typing import Any
import tomlkit
def find_pyproject_toml(config_file=None):
def find_pyproject_toml(config_file: Path | None = None) -> Path:
# Find the pyproject.toml file at the root of the project
if config_file is not None:
if not config_file.lower().endswith(".toml"):
raise ValueError(
f"Config file {config_file} is not a valid toml file. Please recheck the path to pyproject.toml",
)
if not os.path.exists(config_file):
raise ValueError(
f"Config file {config_file} does not exist. Please recheck the path to pyproject.toml",
)
config_file = Path(config_file)
if config_file.suffix.lower() != ".toml":
msg = f"Config file {config_file} is not a valid toml file. Please recheck the path to pyproject.toml"
raise ValueError(msg)
if not config_file.exists():
msg = f"Config file {config_file} does not exist. Please recheck the path to pyproject.toml"
raise ValueError(msg)
return config_file
else:
dir_path = os.getcwd()
dir_path = Path.cwd()
while os.path.dirname(dir_path) != dir_path:
config_file = os.path.join(dir_path, "pyproject.toml")
if os.path.exists(config_file):
while dir_path != dir_path.parent:
config_file = dir_path / "pyproject.toml"
if config_file.exists():
return config_file
# Search for pyproject.toml in the parent directories
dir_path = os.path.dirname(dir_path)
raise ValueError(
f"Could not find pyproject.toml in the current directory {os.getcwd()} or any of the parent directories. Please create it by running `poetry init`, or pass the path to pyproject.toml with the --config-file argument.",
)
dir_path = dir_path.parent
msg = f"Could not find pyproject.toml in the current directory {Path.cwd()} or any of the parent directories. Please create it by running `poetry init`, or pass the path to pyproject.toml with the --config-file argument."
raise ValueError(msg)
def parse_config_file(config_file_path: str = None) -> tuple[dict[str, Any], str]:
def parse_config_file(config_file_path: Path | None = None) -> tuple[dict[str, Any], Path]:
config_file_path = find_pyproject_toml(config_file_path)
try:
with open(config_file_path, "rb") as f:
with config_file_path.open("rb") as f:
data = tomlkit.parse(f.read())
except tomlkit.exceptions.ParseError as e:
raise ValueError(
f"Error while parsing the config file {config_file_path}. Please recheck the file for syntax errors. Error: {e}",
)
msg = f"Error while parsing the config file {config_file_path}. Please recheck the file for syntax errors. Error: {e}"
raise ValueError(msg) from e
try:
tool = data["tool"]
assert isinstance(tool, dict)
config = tool["codeflash"]
except tomlkit.exceptions.NonExistentKey:
raise ValueError(
f"Could not find the 'codeflash' block in the config file {config_file_path}. "
f"Please run 'codeflash init' to create the config file.",
)
except tomlkit.exceptions.NonExistentKey as e:
msg = f"Could not find the 'codeflash' block in the config file {config_file_path}. Please run 'codeflash init' to create the config file."
raise ValueError(msg) from e
assert isinstance(config, dict)
# default values:
@ -79,9 +76,7 @@ def parse_config_file(config_file_path: str = None) -> tuple[dict[str, Any], str
config[key] = bool_keys[key]
for key in path_keys:
if key in config:
config[key] = os.path.realpath(
os.path.join(os.path.dirname(config_file_path), config[key]),
)
config[key] = str((Path(config_file_path).parent / Path(config[key])).resolve())
for key in list_str_keys:
if key in config:
config[key] = [str(cmd) for cmd in config[key]]
@ -90,10 +85,7 @@ def parse_config_file(config_file_path: str = None) -> tuple[dict[str, Any], str
for key in path_list_keys:
if key in config:
config[key] = [
os.path.realpath(os.path.join(os.path.dirname(config_file_path), path))
for path in config[key]
]
config[key] = [(Path(config_file_path).parent / path).resolve() for path in config[key]]
else: # Default to empty list
config[key] = []

View file

@ -1,28 +1,31 @@
from codeflash.cli_cmds.console import logger
import os.path
from __future__ import annotations
import os
import shlex
import subprocess
from pathlib import Path
import isort
from codeflash.cli_cmds.console import logger
def format_code(
formatter_cmds: list[str],
path: str,
) -> str:
path: Path,
) -> str | None:
# TODO: Only allow a particular whitelist of formatters here to prevent arbitrary code execution
if not os.path.exists(path):
if not path.exists():
logger.error(f"File {path} does not exist. Cannot format the file.")
return None
if formatter_cmds[0].lower() == "disabled":
with open(path, encoding="utf8") as f:
new_code = f.read()
new_code = path.read_text(encoding="utf8")
return new_code
file_token = "$file"
for command in formatter_cmds:
formatter_cmd_list = shlex.split(command, posix=os.name != "nt")
formatter_cmd_list = [path if chunk == file_token else chunk for chunk in formatter_cmd_list]
formatter_cmd_list = [str(path) if chunk == file_token else chunk for chunk in formatter_cmd_list]
logger.info(f"Formatting code with {' '.join(formatter_cmd_list)} ...")
try:
@ -40,8 +43,7 @@ def format_code(
else:
logger.error(f"Failed to format code with {' '.join(formatter_cmd_list)}")
with open(path, encoding="utf8") as f:
new_code = f.read()
new_code = path.read_text(encoding="utf8")
return new_code

View file

@ -1,9 +1,9 @@
from __future__ import annotations
from codeflash.cli_cmds.console import logger
import os
import sys
import time
from io import StringIO
from pathlib import Path
from typing import Optional
import git
@ -12,10 +12,11 @@ from git import Repo
from unidiff import PatchSet
from codeflash.cli_cmds.cli_common import inquirer_wrapper
from codeflash.cli_cmds.console import logger
def get_git_diff(
repo_directory: str = os.getcwd(),
repo_directory: Path = Path.cwd(),
uncommitted_changes: bool = False,
) -> dict[str, list[int]]:
repository = git.Repo(repo_directory, search_parent_directories=True)
@ -37,11 +38,12 @@ def get_git_diff(
patch_set = PatchSet(StringIO(uni_diff_text))
change_list: dict[str, list[int]] = {} # list of changes
for patched_file in patch_set:
file_path: str = patched_file.path # file name
if not file_path.endswith(".py"):
file_path: Path = Path(patched_file.path)
if file_path.suffix != ".py":
continue
file_path = os.path.join(repository.working_dir, file_path)
logger.debug("file name :" + file_path)
file_path = Path(repository.working_dir) / file_path
logger.debug(f"file name: {file_path}")
add_line_no: list[int] = [
line.target_line_no
for hunk in patched_file
@ -49,15 +51,16 @@ def get_git_diff(
if line.is_added and line.value.strip() != ""
] # the row numbers of added lines
logger.debug("added lines : " + str(add_line_no))
logger.debug(f"added lines: {add_line_no}")
del_line_no: list[int] = [
line.source_line_no
for hunk in patched_file
for line in hunk
if line.is_removed and line.value.strip() != ""
] # the row number of added liens
] # the row numbers of deleted lines
logger.debug("deleted lines : " + str(del_line_no))
logger.debug(f"deleted lines: {del_line_no}")
change_list[file_path] = add_line_no
return change_list

View file

@ -1,15 +1,16 @@
from codeflash.cli_cmds.console import logger
import os
import re
import shlex
import unittest
from collections import defaultdict
from multiprocessing import Process, Queue
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import jedi
from pydantic.dataclasses import dataclass
from codeflash.cli_cmds.console import logger
from codeflash.code_utils.code_utils import module_name_from_file_path
from codeflash.verification.test_results import TestType
from codeflash.verification.verification_utils import TestConfig
@ -139,7 +140,7 @@ def discover_tests_unittest(
cfg: TestConfig,
discover_only_these_tests: Optional[List[str]] = None,
) -> Dict[str, List[FunctionCalledInTest]]:
tests_root = cfg.tests_root
tests_root = Path(cfg.tests_root)
loader = unittest.TestLoader()
tests = loader.discover(str(tests_root))
file_to_test_map = defaultdict(list)
@ -151,18 +152,18 @@ def discover_tests_unittest(
_test.__class__.__qualname__,
)
_test_module_path = _test_module.replace(".", os.sep)
_test_module_path = os.path.normpath(os.path.join(str(tests_root), _test_module_path) + ".py")
if not os.path.exists(_test_module_path) or (
discover_only_these_tests and _test_module_path not in discover_only_these_tests
_test_module_path = Path(_test_module.replace(".", os.sep)).with_suffix(".py")
_test_module_path = tests_root / _test_module_path
if not _test_module_path.exists() or (
discover_only_these_tests and str(_test_module_path) not in discover_only_these_tests
):
return None
if "__replay_test" in _test_module_path:
if "__replay_test" in str(_test_module_path):
test_type = TestType.REPLAY_TEST
else:
test_type = TestType.EXISTING_UNIT_TEST
return TestsInFile(
test_file=_test_module_path,
test_file=str(_test_module_path),
test_suite=_test_suite_name,
test_function=_test_function,
test_type=test_type,
@ -186,11 +187,11 @@ def discover_tests_unittest(
continue
details = get_test_details(test_2)
if details is not None:
file_to_test_map[details.test_file].append(details)
file_to_test_map[str(details.test_file)].append(details)
else:
details = get_test_details(test)
if details is not None:
file_to_test_map[details.test_file].append(details)
file_to_test_map[str(details.test_file)].append(details)
return process_test_files(file_to_test_map, cfg)

View file

@ -128,7 +128,7 @@ class FunctionToOptimize:
"""
function_name: str
file_path: str
file_path: Path
parents: list[FunctionParent] # list[ClassDef | FunctionDef | AsyncFunctionDef]
starting_line: Optional[int] = None
ending_line: Optional[int] = None
@ -150,20 +150,20 @@ class FunctionToOptimize:
def qualified_name(self) -> str:
return self.function_name if self.parents == [] else f"{self.parents[0].name}.{self.function_name}"
def qualified_name_with_modules_from_root(self, project_root_path: str) -> str:
def qualified_name_with_modules_from_root(self, project_root_path: Path) -> str:
return f"{module_name_from_file_path(self.file_path, project_root_path)}.{self.qualified_name}"
def get_functions_to_optimize(
optimize_all: str | None,
replay_test: str | None,
file: str | None,
file: Path | None,
only_get_this_function: str | None,
test_cfg: TestConfig,
ignore_paths: list[str],
project_root: str,
module_root: str,
) -> tuple[dict[str, list[FunctionToOptimize]], int]:
ignore_paths: list[Path],
project_root: Path,
module_root: Path,
) -> tuple[dict[Path, list[FunctionToOptimize]], int]:
assert (
sum(
[ # Ensure only one of the options is provided
@ -191,9 +191,8 @@ def get_functions_to_optimize(
if only_get_this_function is not None:
split_function = only_get_this_function.split(".")
if len(split_function) > 2:
raise ValueError(
"Function name should be in the format 'function_name' or 'class_name.function_name'",
)
msg = "Function name should be in the format 'function_name' or 'class_name.function_name'"
raise ValueError(msg)
if len(split_function) == 2:
class_name, only_function_name = split_function
else:
@ -206,10 +205,8 @@ def get_functions_to_optimize(
):
found_function = fn
if found_function is None:
raise ValueError(
f"Function {only_function_name} not found in file {file} or"
f" the function does not have a 'return' statement.",
)
msg = f"Function {only_function_name} not found in file {file} or the function does not have a 'return' statement."
raise ValueError(msg)
functions[file] = [found_function]
else:
logger.info("Finding all functions modified in the current git diff ...")
@ -229,64 +226,60 @@ def get_functions_to_optimize(
def get_functions_within_git_diff() -> dict[str, list[FunctionToOptimize]]:
modified_lines: dict[str, list[int]] = get_git_diff(uncommitted_changes=False)
modified_functions: dict[str, list[FunctionToOptimize]] = {}
for path in modified_lines:
if not os.path.exists(path):
for path_str in modified_lines:
path = Path(path_str)
if not path.exists():
continue
with open(path, encoding="utf8") as f:
with path.open(encoding="utf8") as f:
file_content = f.read()
try:
wrapper = cst.metadata.MetadataWrapper(cst.parse_module(file_content))
except Exception as e:
logger.exception(e)
continue
function_lines = FunctionVisitor(file_path=path)
function_lines = FunctionVisitor(file_path=str(path))
wrapper.visit(function_lines)
modified_functions[path] = [
modified_functions[str(path)] = [
function_to_optimize
for function_to_optimize in function_lines.functions
if (start_line := function_to_optimize.starting_line) is not None
and (end_line := function_to_optimize.ending_line) is not None
and any(start_line <= line <= end_line for line in modified_lines[path])
and any(start_line <= line <= end_line for line in modified_lines[path_str])
]
return modified_functions
def get_all_files_and_functions(module_root_path: str) -> dict[str, list[FunctionToOptimize]]:
def get_all_files_and_functions(module_root_path: Path) -> dict[str, list[FunctionToOptimize]]:
functions: dict[str, list[FunctionToOptimize]] = {}
for root, dirs, files in os.walk(module_root_path):
for file in files:
if not file.endswith(".py"):
continue
file_path = os.path.join(root, file)
# Find all the functions in the file
functions.update(find_all_functions_in_file(file_path))
module_root_path = Path(module_root_path)
for file_path in module_root_path.rglob("*.py"):
# Find all the functions in the file
functions.update(find_all_functions_in_file(str(file_path)))
# Randomize the order of the files to optimize to avoid optimizing the same file in the same order every time.
# Helpful if an optimize-all run is stuck and we restart it.
files_list = list(functions.items())
random.shuffle(files_list)
functions_shuffled = dict(files_list)
return functions_shuffled
return dict(files_list)
def find_all_functions_in_file(file_path: str) -> dict[str, list[FunctionToOptimize]]:
def find_all_functions_in_file(file_path: Path) -> dict[str, list[FunctionToOptimize]]:
functions: dict[str, list[FunctionToOptimize]] = {}
with open(file_path, encoding="utf8") as f:
with file_path.open(encoding="utf8") as f:
try:
ast_module = ast.parse(f.read())
except Exception as e:
logger.exception(e)
return functions
function_name_visitor = FunctionWithReturnStatement(file_path)
function_name_visitor = FunctionWithReturnStatement(str(file_path))
function_name_visitor.visit(ast_module)
functions[file_path] = function_name_visitor.functions
functions[str(file_path)] = function_name_visitor.functions
return functions
def get_all_replay_test_functions(
replay_test: str,
test_cfg: TestConfig,
project_root_path: str,
project_root_path: Path,
) -> dict[str, list[FunctionToOptimize]]:
function_tests = discover_unit_tests(test_cfg, discover_only_these_tests=[replay_test])
# Get the absolute file paths for each function, excluding class name if present
@ -303,7 +296,7 @@ def get_all_replay_test_functions(
if module_path_parts
and is_class_defined_in_file(
module_path_parts[-1],
os.path.join(project_root_path, *module_path_parts[:-1]) + ".py",
str(Path(project_root_path, *module_path_parts[:-1])) + ".py",
)
else None
)
@ -314,7 +307,7 @@ def get_all_replay_test_functions(
else:
function = function_name
file_path_parts = module_path_parts
file_path = os.path.join(project_root_path, *file_path_parts) + ".py"
file_path = Path(project_root_path, *file_path_parts).with_suffix(".py")
file_to_functions_map[file_path].append((function, function_name, class_name))
for file_path, functions in file_to_functions_map.items():
all_valid_functions: dict[str, list[FunctionToOptimize]] = find_all_functions_in_file(
@ -349,8 +342,7 @@ def ignored_submodule_paths(module_root: str) -> list[str]:
if is_git_repo(module_root):
git_repo = git.Repo(module_root, search_parent_directories=True)
return [
os.path.realpath(os.path.join(git_repo.working_tree_dir, submodule.path))
for submodule in git_repo.submodules
Path(git_repo.working_tree_dir, submodule.path).resolve() for submodule in git_repo.submodules
]
return []
@ -358,7 +350,7 @@ def ignored_submodule_paths(module_root: str) -> list[str]:
class TopLevelFunctionOrMethodVisitor(ast.NodeVisitor):
def __init__(
self,
file_name: str,
file_name: Path,
function_or_method_name: str,
class_name: str | None = None,
line_no: int | None = None,
@ -418,7 +410,7 @@ class TopLevelFunctionOrMethodVisitor(ast.NodeVisitor):
def inspect_top_level_functions_or_methods(
file_name: str,
file_name: Path,
function_or_method_name: str,
class_name: str | None = None,
line_no: int | None = None,
@ -449,11 +441,11 @@ def inspect_top_level_functions_or_methods(
def filter_functions(
modified_functions: dict[str, list[FunctionToOptimize]],
tests_root: str,
ignore_paths: list[str],
ignore_paths: list[Path],
project_root: str,
module_root: str,
disable_logs: bool = False,
) -> tuple[dict[str, list[FunctionToOptimize]], int]:
) -> tuple[dict[Path, list[FunctionToOptimize]], int]:
blocklist_funcs = get_blacklisted_functions()
# Remove any function that we don't want to optimize
@ -522,37 +514,33 @@ def filter_functions(
def filter_files_optimized(
file_path: str,
tests_root: str,
ignore_paths: list[str],
module_root: str,
file_path: Path,
tests_root: Path,
ignore_paths: list[Path],
module_root: Path,
) -> bool:
"""Optimized version of the filter_functions function above.
Takes in file paths and returns the count of files that are to be optimized.
"""
submodule_paths = None
if file_path.startswith(tests_root + os.sep):
if file_path.is_relative_to(tests_root):
return False
if file_path in ignore_paths or any(
file_path.startswith(ignore_path + os.sep) for ignore_path in ignore_paths
file_path.is_relative_to(ignore_path) for ignore_path in ignore_paths
):
return False
if path_belongs_to_site_packages(file_path):
return False
if not file_path.startswith(module_root + os.sep):
if not file_path.is_relative_to(module_root):
return False
if submodule_paths is None:
submodule_paths = ignored_submodule_paths(module_root)
if file_path in submodule_paths or any(
file_path.startswith(submodule_path + os.sep) for submodule_path in submodule_paths
):
return False
return True
return not (
file_path in submodule_paths
or any(file_path.is_relative_to(submodule_path) for submodule_path in submodule_paths)
)
def function_has_return_statement(function_node: Union[FunctionDef, AsyncFunctionDef]) -> bool:
for node in ast.walk(function_node):
if isinstance(node, ast.Return):
return True
return False
def function_has_return_statement(function_node: FunctionDef | AsyncFunctionDef) -> bool:
return any(isinstance(node, ast.Return) for node in ast.walk(function_node))

View file

@ -1,5 +1,6 @@
from __future__ import annotations
from pathlib import Path
from typing import Any, Generator, Iterator, Optional
from jedi.api.classes import Name
@ -17,7 +18,7 @@ from codeflash.verification.test_results import TestResults, TestType
@dataclass(frozen=True, config={"arbitrary_types_allowed": True})
class FunctionSource:
file_path: str
file_path: Path
qualified_name: str
fully_qualified_name: str
only_function_name: str
@ -55,7 +56,7 @@ class GeneratedTestsList(BaseModel):
class TestFile(BaseModel):
instrumented_file_path: str
instrumented_file_path: Path
original_file_path: Optional[str] = None
original_source: Optional[str] = None
test_type: TestType

View file

@ -4,16 +4,20 @@ import ast
import os
import re
from collections import defaultdict
from typing import TYPE_CHECKING
import jedi
import tiktoken
from jedi.api.classes import Name
from codeflash.cli_cmds.console import logger
from codeflash.code_utils.code_extractor import get_code
from codeflash.code_utils.code_utils import module_name_from_file_path, path_belongs_to_site_packages
from codeflash.discovery.functions_to_optimize import FunctionParent, FunctionToOptimize
from codeflash.models.models import FunctionSource
from codeflash.cli_cmds.console import logger
if TYPE_CHECKING:
from pathlib import Path
def belongs_to_class(name: Name, class_name: str) -> bool:
@ -36,12 +40,11 @@ def belongs_to_function(name: Name, function_name: str) -> bool:
def get_type_annotation_context(
function: FunctionToOptimize,
jedi_script: jedi.Script,
project_root_path: str,
project_root_path: Path,
) -> list[FunctionSource]:
function_name: str = function.function_name
file_path: str = function.file_path
with open(file_path, encoding="utf8") as file:
file_contents: str = file.read()
file_path: Path = function.file_path
file_contents: str = file_path.read_text(encoding="utf8")
try:
module: ast.Module = ast.parse(file_contents)
except SyntaxError as e:
@ -72,14 +75,15 @@ def get_type_annotation_context(
logger.exception(f"Error while getting definition: {ex}")
definition = []
if definition: # TODO can be multiple definitions
definition_path = str(definition[0].module_path)
definition_path = definition[0].module_path
# The definition is part of this project and not defined within the original function
if (
definition_path.startswith(project_root_path + os.sep)
str(definition_path).startswith(str(project_root_path) + os.sep)
and definition[0].full_name
and not path_belongs_to_site_packages(definition_path)
and not belongs_to_function(definition[0], function_name)
):
assert definition_path is not None
source_code = get_code(
[
FunctionToOptimize(
@ -164,7 +168,7 @@ def get_type_annotation_context(
def get_function_variables_definitions(
function_to_optimize: FunctionToOptimize,
project_root_path: str,
project_root_path: Path,
) -> tuple[list[FunctionSource], set[tuple[str, str]]]:
function_name = function_to_optimize.function_name
file_path = function_to_optimize.file_path
@ -203,10 +207,11 @@ def get_function_variables_definitions(
if definitions:
# TODO: there can be multiple definitions, see how to handle such cases
definition = definitions[0]
definition_path = str(definition.module_path)
definition_path = definition.module_path
assert definition_path is not None
# The definition is part of this project and not defined within the original function
if (
definition_path.startswith(project_root_path + os.sep)
str(definition_path).startswith(str(project_root_path) + os.sep)
and not path_belongs_to_site_packages(definition_path)
and definition.full_name
and not belongs_to_function(definition, function_name)
@ -255,7 +260,7 @@ def get_function_variables_definitions(
for source in sources:
if (fully_qualified_name := source.fully_qualified_name) not in existing_fully_qualified_names:
if not source.qualified_name.count("."):
no_parent_sources[source.file_path][source.qualified_name].add(source)
no_parent_sources[str(source.file_path)][source.qualified_name].add(source)
else:
parent_sources.add(source)
existing_fully_qualified_names.add(fully_qualified_name)
@ -263,7 +268,7 @@ def get_function_variables_definitions(
source
for source in parent_sources
if source.file_path not in no_parent_sources
or source.qualified_name.rpartition(".")[0] not in no_parent_sources[source.file_path]
or source.qualified_name.rpartition(".")[0] not in no_parent_sources[str(source.file_path)]
]
deduped_no_parent_sources = [
source
@ -279,7 +284,7 @@ MAX_PROMPT_TOKENS = 4096 # 128000 # gpt-4-128k
def get_constrained_function_context_and_helper_functions(
function_to_optimize: FunctionToOptimize,
project_root_path: str,
project_root_path: Path,
code_to_optimize: str,
max_tokens: int = MAX_PROMPT_TOKENS,
) -> tuple[str, list[FunctionSource], set[tuple[str, str]]]:

View file

@ -2,11 +2,11 @@ from __future__ import annotations
import concurrent.futures
import os
import pathlib
import subprocess
import time
import uuid
from collections import defaultdict
from pathlib import Path
from typing import TYPE_CHECKING
import isort
@ -623,7 +623,7 @@ class Optimizer:
def replace_function_and_helpers_with_optimized_code(
self,
code_context: CodeOptimizationContext,
function_to_optimize_file_path: str,
function_to_optimize_file_path: Path,
optimized_code: str,
qualified_function_name: str,
) -> bool:
@ -660,7 +660,7 @@ class Optimizer:
def get_code_optimization_context(
self,
function_to_optimize: FunctionToOptimize,
project_root: str,
project_root: Path,
original_source_code: str,
) -> Result[CodeOptimizationContext, str]:
code_to_optimize, contextual_dunder_methods = extract_code(
@ -687,7 +687,7 @@ class Optimizer:
optimizable_methods = [
FunctionToOptimize(
df.qualified_name.split(".")[-1],
df.file_path,
Path(df.file_path),
[FunctionParent(df.qualified_name.split(".")[0], "ClassDef")],
None,
None,
@ -730,18 +730,14 @@ class Optimizer:
@staticmethod
def cleanup_leftover_test_return_values() -> None:
# remove leftovers from previous run
pathlib.Path(get_run_tmp_file("test_return_values_0.bin")).unlink(
missing_ok=True,
)
pathlib.Path(get_run_tmp_file("test_return_values_0.sqlite")).unlink(
missing_ok=True,
)
get_run_tmp_file(Path("test_return_values_0.bin")).unlink(missing_ok=True)
get_run_tmp_file(Path("test_return_values_0.sqlite")).unlink(missing_ok=True)
def instrument_existing_tests(
self,
function_to_optimize: FunctionToOptimize,
function_to_tests: dict[str, list[TestsInFile]],
) -> set[str]:
) -> set[Path]:
relevant_test_files_count = 0
unique_instrumented_test_files = set()
@ -767,10 +763,8 @@ class Optimizer:
)
if not success:
continue
new_test_path = (
f"{os.path.splitext(test_file)[0]}__perfinstrumented{os.path.splitext(test_file)[1]}"
)
with pathlib.Path(new_test_path).open("w", encoding="utf8") as f:
new_test_path = Path(test_file).with_suffix(f"__perfinstrumented{Path(test_file).suffix}")
with new_test_path.open("w", encoding="utf8") as f:
f.write(injected_test)
unique_instrumented_test_files.add(new_test_path)
if not self.test_files.get_by_original_file_path(test_file):
@ -793,7 +787,7 @@ class Optimizer:
code_to_optimize_with_helpers: str,
function_to_optimize: FunctionToOptimize,
helper_functions: list[FunctionSource],
module_path: str,
module_path: Path,
function_trace_id: str,
run_experiment: bool = False,
) -> Result[tuple[GeneratedTestsList, OptimizationSet], str]:
@ -995,12 +989,12 @@ class Optimizer:
first_test_types = []
first_test_functions = []
pathlib.Path(
get_run_tmp_file(f"test_return_values_{optimization_candidate_index}.bin"),
).unlink(missing_ok=True)
pathlib.Path(
get_run_tmp_file(f"test_return_values_{optimization_candidate_index}.sqlite"),
).unlink(missing_ok=True)
Path(get_run_tmp_file(f"test_return_values_{optimization_candidate_index}.bin")).unlink(
missing_ok=True,
)
Path(get_run_tmp_file(f"test_return_values_{optimization_candidate_index}.sqlite")).unlink(
missing_ok=True,
)
for test_file in instrumented_unittests_created_for_function:
relevant_tests_in_file = [
@ -1069,12 +1063,10 @@ class Optimizer:
)
if best_runtime_until_now is None or total_candidate_timing < best_runtime_until_now:
best_test_results = candidate_results
pathlib.Path(get_run_tmp_file(f"test_return_values_{optimization_candidate_index}.bin")).unlink(
Path(get_run_tmp_file(f"test_return_values_{optimization_candidate_index}.bin")).unlink(
missing_ok=True,
)
pathlib.Path(
get_run_tmp_file(f"test_return_values_{optimization_candidate_index}.sqlite"),
).unlink(
Path(get_run_tmp_file(f"test_return_values_{optimization_candidate_index}.sqlite")).unlink(
missing_ok=True,
)
if not equal_results:
@ -1142,7 +1134,7 @@ class Optimizer:
source_code_being_tested: str,
function_to_optimize: FunctionToOptimize,
helper_function_names: list[str],
module_path: str,
module_path: Path,
function_trace_id: str,
) -> GeneratedTestsList | None:
futures = [

View file

@ -9,6 +9,8 @@
# Licensed under the Apache License, Version 2.0 (the "License").
# http://www.apache.org/licenses/LICENSE-2.0
#
from __future__ import annotations
import importlib.machinery
import io
import json
@ -23,20 +25,21 @@ import time
from collections import defaultdict
from copy import copy
from io import StringIO
from pathlib import Path
from types import FrameType
from typing import Any, List, Optional
from typing import Any, ClassVar, List, Optional
import dill
import isort
from codeflash.cli_cmds.cli import project_root_from_module_root
from codeflash.cli_cmds.console import console
from codeflash.code_utils.code_utils import module_name_from_file_path
from codeflash.code_utils.config_parser import parse_config_file
from codeflash.discovery.functions_to_optimize import filter_files_optimized
from codeflash.tracing.replay_test import create_trace_replay_test
from codeflash.tracing.tracing_utils import FunctionModules
from codeflash.verification.verification_utils import get_test_file_path
from codeflash.cli_cmds.console import console
class Tracer:
@ -49,7 +52,7 @@ class Tracer:
output: str = "codeflash.trace",
functions: Optional[List[str]] = None,
disable: bool = False,
config_file_path: Optional[str] = None,
config_file_path: Path | None = None,
max_function_count: int = 256,
timeout: Optional[int] = None, # seconds
) -> None:
@ -77,13 +80,14 @@ class Tracer:
self.disable = True
return
self.con = None
self.output_file = os.path.abspath(output)
self.output_file = Path(output).resolve()
self.functions = functions
self.function_modules: List[FunctionModules] = []
self.function_count = defaultdict(int)
self.current_file_path = Path(__file__).resolve()
self.ignored_qualified_functions = {
f"{os.path.realpath(__file__)}:Tracer:__exit__",
f"{os.path.realpath(__file__)}:Tracer:__enter__",
f"{self.current_file_path}:Tracer:__exit__",
f"{self.current_file_path}:Tracer:__enter__",
}
self.max_function_count = max_function_count
self.config, found_config_path = parse_config_file(config_file_path)
@ -99,13 +103,9 @@ class Tracer:
"<lambda>",
"<module>",
}
self.file_being_called_from: str = str(
os.path.basename(
os.path.realpath(sys._getframe().f_back.f_code.co_filename),
).replace(
".",
"_",
),
self.file_being_called_from: str = str(Path(sys._getframe().f_back.f_code.co_filename).name).replace(
".", "_"
)
assert timeout is None or timeout > 0, "Timeout should be greater than 0"
@ -175,7 +175,7 @@ class Tracer:
cur.execute(
"INSERT INTO pstats VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
(
os.path.realpath(func[0]),
Path(func[0]).resolve(),
func[1],
func[2],
func[3],
@ -245,7 +245,7 @@ class Tracer:
if code.co_name in self.ignored_functions:
return
if not os.path.exists(file_name):
if not Path(file_name).exists():
return
if self.functions:
if code.co_name not in self.functions:
@ -264,7 +264,7 @@ class Tracer:
except:
# someone can override the getattr method and raise an exception. I'm looking at you wrapt
return
file_name = os.path.realpath(file_name)
file_name = Path(file_name).resolve()
function_qualified_name = f"{file_name}:{(class_name + ':' if class_name else '')}{code.co_name}"
if function_qualified_name in self.ignored_qualified_functions:
return
@ -311,7 +311,7 @@ class Tracer:
# We do not pickle self for __init__ to avoid recursion errors, and instead instantiate its class
# directly with the rest of the arguments in the replay tests. We copy the arguments to avoid memory
# leaks, bad references or side-effects when unpickling.
arguments = {k: v for k, v in arguments.items()}
arguments = dict(arguments.items())
if class_name and code.co_name == "__init__":
del arguments["self"]
local_vars = pickle.dumps(
@ -463,7 +463,7 @@ class Tracer:
return 1
dispatch = {
dispatch: ClassVar[dict[str, callable]] = {
"call": trace_dispatch_call,
"exception": trace_dispatch_exception,
"return": trace_dispatch_return,
@ -648,7 +648,7 @@ def main():
# The script that we're profiling may chdir, so capture the absolute path
# to the output file at startup.
if args.outfile is not None:
args.outfile = os.path.abspath(args.outfile)
args.outfile = Path(args.outfile).resolve()
if len(unknown_args) > 0:
if args.module:
@ -661,7 +661,7 @@ def main():
}
else:
progname = unknown_args[0]
sys.path.insert(0, os.path.dirname(progname))
sys.path.insert(0, str(Path(progname).parent))
with io.open_code(progname) as fp:
code = compile(fp.read(), progname, "exec")
spec = importlib.machinery.ModuleSpec(name="__main__", loader=None, origin=progname)

View file

@ -1,4 +1,6 @@
from codeflash.code_utils.code_extractor import add_needed_imports_from_module
from pathlib import Path
from codeflash.code_utils.code_extractor import add_needed_imports_from_module
def test_add_needed_imports_from_module0() -> None:
@ -46,9 +48,9 @@ class Source:
expected = """def heyjude() -> None:
print("Hey Jude, don't make it bad")
"""
src_path = "/home/roger/repos/codeflash/cli/codeflash/optimization/function_context.py"
dst_path = "/home/roger/repos/codeflash/cli/codeflash/optimization/function_context.py"
project_root = "/home/roger/repos/codeflash"
src_path = Path("/home/roger/repos/codeflash/cli/codeflash/optimization/function_context.py")
dst_path = Path("/home/roger/repos/codeflash/cli/codeflash/optimization/function_context.py")
project_root = Path("/home/roger/repos/codeflash")
new_module = add_needed_imports_from_module(
src_module,
dst_module,
@ -120,9 +122,9 @@ def belongs_to_function(name: Name, function_name: str) -> bool:
# The name is defined inside the function or is the function itself
return f".{function_name}." in subname or f".{function_name}" == subname
'''
src_path = "/home/roger/repos/codeflash/cli/codeflash/optimization/function_context.py"
dst_path = "/home/roger/repos/codeflash/cli/codeflash/optimization/function_context.py"
project_root = "/home/roger/repos/codeflash"
src_path = Path("/home/roger/repos/codeflash/cli/codeflash/optimization/function_context.py")
dst_path = Path("/home/roger/repos/codeflash/cli/codeflash/optimization/function_context.py")
project_root = Path("/home/roger/repos/codeflash")
new_module = add_needed_imports_from_module(
src_module,
dst_module,

View file

@ -25,7 +25,7 @@ class JediDefinition:
@dataclasses.dataclass
class FakeFunctionSource:
file_path: str
file_path: Path
qualified_name: str
fully_qualified_name: str
only_function_name: str
@ -81,11 +81,11 @@ print("Hello world")
source_code=original_code,
function_names=[function_name],
optimized_code=optim_code,
file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()),
module_abspath=str(Path(__file__).resolve()),
file_path_of_module_with_function_to_optimize=Path(__file__).resolve(),
module_abspath=Path(__file__).resolve(),
preexisting_objects=preexisting_objects,
contextual_functions=contextual_functions,
project_root_path=str(Path(__file__).resolve().parent.resolve()),
project_root_path=Path(__file__).resolve().parent.resolve(),
)
assert new_code == expected
@ -145,11 +145,11 @@ print("Hello world")
source_code=original_code,
function_names=[function_name],
optimized_code=optim_code,
file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()),
module_abspath=str(Path(__file__).resolve()),
file_path_of_module_with_function_to_optimize=Path(__file__).resolve(),
module_abspath=Path(__file__).resolve(),
preexisting_objects=preexisting_objects,
contextual_functions=contextual_functions,
project_root_path=str(Path(__file__).resolve().parent.resolve()),
project_root_path=Path(__file__).resolve().parent.resolve(),
)
assert new_code == expected
@ -206,11 +206,11 @@ print("Salut monde")
source_code=original_code,
function_names=function_names,
optimized_code=optim_code,
file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()),
module_abspath=str(Path(__file__).resolve()),
file_path_of_module_with_function_to_optimize=Path(__file__).resolve(),
module_abspath=Path(__file__).resolve(),
preexisting_objects=preexisting_objects,
contextual_functions=contextual_functions,
project_root_path=str(Path(__file__).resolve().parent.resolve()),
project_root_path=Path(__file__).resolve().parent.resolve(),
)
assert new_code == expected
@ -270,11 +270,11 @@ print("Salut monde")
source_code=original_code,
function_names=function_names,
optimized_code=optim_code,
file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()),
module_abspath=str(Path(__file__).resolve()),
file_path_of_module_with_function_to_optimize=Path(__file__).resolve(),
module_abspath=Path(__file__).resolve(),
preexisting_objects=preexisting_objects,
contextual_functions=contextual_functions,
project_root_path=str(Path(__file__).resolve().parent.resolve()),
project_root_path=Path(__file__).resolve().parent.resolve(),
)
assert new_code == expected
@ -326,11 +326,11 @@ def supersort(doink):
source_code=original_code,
function_names=function_names,
optimized_code=optim_code,
file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()),
module_abspath=str(Path(__file__).resolve()),
file_path_of_module_with_function_to_optimize=Path(__file__).resolve(),
module_abspath=Path(__file__).resolve(),
preexisting_objects=preexisting_objects,
contextual_functions=contextual_functions,
project_root_path=str(Path(__file__).resolve().parent.resolve()),
project_root_path=Path(__file__).resolve().parent.resolve(),
)
assert new_code == expected
@ -402,11 +402,11 @@ print("Not cool")
source_code=original_code_main,
function_names=["other_function"],
optimized_code=optim_code,
file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()),
module_abspath=str(Path(__file__).resolve()),
file_path_of_module_with_function_to_optimize=Path(__file__).resolve(),
module_abspath=Path(__file__).resolve(),
preexisting_objects=[("other_function", []), ("yet_another_function", []), ("blob", [])],
contextual_functions=set(),
project_root_path=str(Path(__file__).resolve().parent.resolve()),
project_root_path=Path(__file__).resolve().parent.resolve(),
)
assert new_main_code == expected_main
@ -414,11 +414,11 @@ print("Not cool")
source_code=original_code_helper,
function_names=["blob"],
optimized_code=optim_code,
file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()),
module_abspath=str(Path(__file__).resolve()),
file_path_of_module_with_function_to_optimize=Path(__file__).resolve(),
module_abspath=Path(__file__).resolve(),
preexisting_objects=[],
contextual_functions=set(),
project_root_path=str(Path(__file__).resolve().parent.resolve()),
project_root_path=Path(__file__).resolve().parent.resolve(),
)
assert new_helper_code == expected_helper
@ -619,11 +619,11 @@ class CacheConfig(BaseConfig):
source_code=original_code,
function_names=function_names,
optimized_code=optim_code,
file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()),
module_abspath=str(Path(__file__).resolve()),
file_path_of_module_with_function_to_optimize=Path(__file__).resolve(),
module_abspath=Path(__file__).resolve(),
preexisting_objects=preexisting_objects,
contextual_functions=contextual_functions,
project_root_path=str(Path(__file__).resolve().parent.resolve()),
project_root_path=Path(__file__).resolve().parent.resolve(),
)
assert new_code == expected
@ -694,11 +694,11 @@ def test_test_libcst_code_replacement8() -> None:
source_code=original_code,
function_names=function_names,
optimized_code=optim_code,
file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()),
module_abspath=str(Path(__file__).resolve()),
file_path_of_module_with_function_to_optimize=Path(__file__).resolve(),
module_abspath=Path(__file__).resolve(),
preexisting_objects=preexisting_objects,
contextual_functions=contextual_functions,
project_root_path=str(Path(__file__).resolve().parent.resolve()),
project_root_path=Path(__file__).resolve().parent.resolve(),
)
assert new_code == expected
@ -757,11 +757,11 @@ print("Hello world")
source_code=original_code,
function_names=[function_name],
optimized_code=optim_code,
file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()),
module_abspath=str(Path(__file__).resolve()),
file_path_of_module_with_function_to_optimize=Path(__file__).resolve(),
module_abspath=Path(__file__).resolve(),
preexisting_objects=preexisting_objects,
contextual_functions=contextual_functions,
project_root_path=str(Path(__file__).resolve().parent.resolve()),
project_root_path=Path(__file__).resolve().parent.resolve(),
)
assert new_code == expected
@ -817,17 +817,16 @@ class MainClass:
)
func_top_optimize = FunctionToOptimize(
function_name="main_method",
file_path=str(file_path),
file_path=file_path,
parents=[FunctionParent("MainClass", "ClassDef")],
)
with open(file_path) as f:
original_code = f.read()
code_context = opt.get_code_optimization_context(
function_to_optimize=func_top_optimize,
project_root=str(file_path.parent),
original_source_code=original_code,
).unwrap()
assert code_context.code_to_optimize_with_helpers == get_code_output
original_code = file_path.read_text()
code_context = opt.get_code_optimization_context(
function_to_optimize=func_top_optimize,
project_root=file_path.parent,
original_source_code=original_code,
).unwrap()
assert code_context.code_to_optimize_with_helpers == get_code_output
def test_code_replacement11() -> None:
@ -946,11 +945,11 @@ def test_test_libcst_code_replacement13() -> None:
source_code=original_code,
function_names=function_names,
optimized_code=optim_code,
file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()),
module_abspath=str(Path(__file__).resolve()),
file_path_of_module_with_function_to_optimize=Path(__file__).resolve(),
module_abspath=Path(__file__).resolve(),
preexisting_objects=preexisting_objects,
contextual_functions=contextual_functions,
project_root_path=str(Path(__file__).resolve().parent.resolve()),
project_root_path=Path(__file__).resolve().parent.resolve(),
)
assert new_code == original_code
@ -1143,7 +1142,9 @@ class TestResults(BaseModel):
helper_functions = [
FakeFunctionSource(
file_path="/Users/saurabh/Library/CloudStorage/Dropbox/codeflash/cli/codeflash/verification/test_results.py",
file_path=Path(
"/Users/saurabh/Library/CloudStorage/Dropbox/codeflash/cli/codeflash/verification/test_results.py"
),
qualified_name="TestType",
fully_qualified_name="codeflash.verification.test_results.TestType",
only_function_name="TestType",
@ -1156,11 +1157,11 @@ class TestResults(BaseModel):
source_code=original_code,
function_names=["TestResults.get_test_pass_fail_report_by_type"],
optimized_code=optim_code,
file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()),
module_abspath=str(Path(__file__).resolve()),
file_path_of_module_with_function_to_optimize=Path(__file__).resolve(),
module_abspath=Path(__file__).resolve(),
preexisting_objects=preexisting_objects,
contextual_functions=contextual_functions,
project_root_path=str(Path(__file__).parent.resolve()),
project_root_path=Path(__file__).parent.resolve(),
)
helper_functions_by_module_abspath = defaultdict(set)
@ -1177,11 +1178,11 @@ class TestResults(BaseModel):
source_code=new_code,
function_names=list(qualified_names),
optimized_code=optim_code,
file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()),
file_path_of_module_with_function_to_optimize=Path(__file__).resolve(),
module_abspath=module_abspath,
preexisting_objects=preexisting_objects,
contextual_functions=contextual_functions,
project_root_path=str(Path(__file__).parent.resolve()),
project_root_path=Path(__file__).parent.resolve(),
)
assert (
@ -1361,12 +1362,16 @@ def cosine_similarity_top_k(
return ret_idxs, scores
'''
preexisting_objects = [("cosine_similarity_top_k", []), ("Matrix", []), ("cosine_similarity", [])]
preexisting_objects: list[tuple[str, list[FunctionParent]]] = [
("cosine_similarity_top_k", []),
("Matrix", []),
("cosine_similarity", []),
]
contextual_functions = set()
contextual_functions: set[tuple[str, str]] = set()
helper_functions = [
FakeFunctionSource(
file_path=str((Path(__file__).parent / "code_to_optimize" / "math_utils.py").resolve()),
file_path=(Path(__file__).parent / "code_to_optimize" / "math_utils.py").resolve(),
qualified_name="Matrix",
fully_qualified_name="code_to_optimize.math_utils.Matrix",
only_function_name="Matrix",
@ -1374,7 +1379,7 @@ def cosine_similarity_top_k(
jedi_definition=JediDefinition(type="class"),
),
FakeFunctionSource(
file_path=str((Path(__file__).parent / "code_to_optimize" / "math_utils.py").resolve()),
file_path=(Path(__file__).parent / "code_to_optimize" / "math_utils.py").resolve(),
qualified_name="cosine_similarity",
fully_qualified_name="code_to_optimize.math_utils.cosine_similarity",
only_function_name="cosine_similarity",
@ -1387,11 +1392,11 @@ def cosine_similarity_top_k(
source_code=original_code,
function_names=["cosine_similarity_top_k"],
optimized_code=optim_code,
file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()),
module_abspath=str((Path(__file__).parent / "code_to_optimize").resolve()),
file_path_of_module_with_function_to_optimize=Path(__file__).resolve(),
module_abspath=(Path(__file__).parent / "code_to_optimize").resolve(),
preexisting_objects=preexisting_objects,
contextual_functions=contextual_functions,
project_root_path=str(Path(__file__).parent.parent.resolve()),
project_root_path=Path(__file__).parent.parent.resolve(),
)
assert (
new_code
@ -1452,11 +1457,11 @@ def cosine_similarity_top_k(
source_code=new_code,
function_names=list(qualified_names),
optimized_code=optim_code,
file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()),
file_path_of_module_with_function_to_optimize=Path(__file__).resolve(),
module_abspath=module_abspath,
preexisting_objects=preexisting_objects,
contextual_functions=contextual_functions,
project_root_path=str(Path(__file__).parent.parent.resolve()),
project_root_path=Path(__file__).parent.parent.resolve(),
)
assert (

View file

@ -0,0 +1,92 @@
import ast
from pathlib import Path
import pytest
from codeflash.code_utils.code_utils import get_imports_from_file, module_name_from_file_path
# tests for module_name_from_file_path
def test_module_name_from_file_path() -> None:
    """A file nested under package directories maps to the dotted module path."""
    root = Path("/Users/codeflashuser/PycharmProjects/codeflash")
    target = root / "cli/codeflash/code_utils/code_utils.py"
    result = module_name_from_file_path(target, root)
    assert result == "cli.codeflash.code_utils.code_utils"
def test_module_name_from_file_path_with_subdirectory() -> None:
    """An extra subdirectory level appears as one more dotted component."""
    root = Path("/Users/codeflashuser/PycharmProjects/codeflash")
    target = root / "cli/codeflash/code_utils/subdir/code_utils.py"
    result = module_name_from_file_path(target, root)
    assert result == "cli.codeflash.code_utils.subdir.code_utils"
def test_module_name_from_file_path_with_different_root() -> None:
    """Moving the project root up one level shifts the whole dotted prefix."""
    root = Path("/Users/codeflashuser/PycharmProjects")
    target = root / "codeflash/cli/codeflash/code_utils/code_utils.py"
    result = module_name_from_file_path(target, root)
    assert result == "codeflash.cli.codeflash.code_utils.code_utils"
def test_module_name_from_file_path_with_root_as_file() -> None:
    """A file sitting directly in the project root yields a bare module name."""
    root = Path("/Users/codeflashuser/PycharmProjects/codeflash/cli/codeflash/code_utils")
    target = root / "code_utils.py"
    result = module_name_from_file_path(target, root)
    assert result == "code_utils"
# def test_get_imports_from_file_with_file_path(tmp_path: Path):
# test_file = tmp_path / "test_file.py"
# test_file.write_text("import os\nfrom sys import path\n")
# imports = get_imports_from_file(file_path=test_file)
# assert len(imports) == 2
# assert isinstance(imports[0], ast.Import)
# assert isinstance(imports[1], ast.ImportFrom)
# assert imports[0].names[0].name == "os"
# assert imports[1].module == "sys"
# assert imports[1].names[0].name == "path"
# def test_get_imports_from_file_with_file_string():
# file_string = "import os\nfrom sys import path\n"
# imports = get_imports_from_file(file_string=file_string)
# assert len(imports) == 2
# assert isinstance(imports[0], ast.Import)
# assert isinstance(imports[1], ast.ImportFrom)
# assert imports[0].names[0].name == "os"
# assert imports[1].module == "sys"
# assert imports[1].names[0].name == "path"
# def test_get_imports_from_file_with_file_ast():
# file_string = "import os\nfrom sys import path\n"
# file_ast = ast.parse(file_string)
# imports = get_imports_from_file(file_ast=file_ast)
# assert len(imports) == 2
# assert isinstance(imports[0], ast.Import)
# assert isinstance(imports[1], ast.ImportFrom)
# assert imports[0].names[0].name == "os"
# assert imports[1].module == "sys"
# assert imports[1].names[0].name == "path"
# def test_get_imports_from_file_with_syntax_error(caplog):
# file_string = "import os\nfrom sys import path\ninvalid syntax"
# imports = get_imports_from_file(file_string=file_string)
# assert len(imports) == 0
# assert "Syntax error in code" in caplog.text
# def test_get_imports_from_file_with_no_input():
# with pytest.raises(
# AssertionError, match="Must provide exactly one of file_path, file_string, or file_ast"
# ):
# get_imports_from_file()

View file

@ -1,7 +1,9 @@
import os
import tempfile
from pathlib import Path
import pytest
from codeflash.code_utils.config_parser import parse_config_file
from codeflash.code_utils.formatter import format_code, sort_imports
@ -34,12 +36,13 @@ def test_sort_imports_without_formatting():
with tempfile.NamedTemporaryFile() as tmp:
tmp.write(b"import sys\nimport unittest\nimport os\n")
tmp.flush()
tmp_path = tmp.name
tmp_path = Path(tmp.name)
new_code = format_code(
formatter_cmds=["disabled"],
path=tmp_path,
)
assert new_code is not None
new_code = sort_imports(new_code)
assert new_code == "import os\nimport sys\nimport unittest\n"
@ -108,7 +111,7 @@ ignore-paths = []
with tempfile.NamedTemporaryFile(suffix=".toml", delete=False) as tmp:
tmp.write(config_data.encode())
tmp.flush()
tmp_path = tmp.name
tmp_path = Path(tmp.name)
try:
config, _ = parse_config_file(tmp_path)
@ -135,7 +138,7 @@ def foo():
actual = format_code(
formatter_cmds=["black $file"],
path=tmp_path,
path=Path(tmp_path),
)
assert actual == expected
@ -160,7 +163,7 @@ def foo():
actual = format_code(
formatter_cmds=["black $file"],
path=tmp_path,
path=Path(tmp_path),
)
assert actual == expected
@ -189,6 +192,6 @@ def foo():
actual = format_code(
formatter_cmds=["ruff check --exit-zero --fix $file", "ruff format $file"],
path=tmp_path,
path=Path(tmp_path),
)
assert actual == expected