From 5cd94cdf642dffdfa0c4ce321fe97a9e9e7f16a6 Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Sat, 12 Oct 2024 17:29:15 -0500 Subject: [PATCH] round 1 --- codeflash/cli_cmds/cli.py | 39 +++--- codeflash/cli_cmds/cmd_init.py | 113 ++++++++--------- codeflash/code_utils/code_extractor.py | 15 +-- codeflash/code_utils/code_replacer.py | 23 ++-- codeflash/code_utils/code_utils.py | 65 +++++----- codeflash/code_utils/config_parser.py | 64 +++++----- codeflash/code_utils/formatter.py | 22 ++-- codeflash/code_utils/git_utils.py | 23 ++-- codeflash/discovery/discover_unit_tests.py | 21 ++-- codeflash/discovery/functions_to_optimize.py | 112 ++++++++--------- codeflash/models/models.py | 5 +- codeflash/optimization/function_context.py | 31 +++-- codeflash/optimization/optimizer.py | 46 +++---- codeflash/tracer.py | 40 +++--- tests/test_add_needed_imports_from_module.py | 16 +-- tests/test_code_replacement.py | 121 ++++++++++--------- tests/test_code_utils.py | 92 ++++++++++++++ tests/test_formatter.py | 13 +- 18 files changed, 470 insertions(+), 391 deletions(-) create mode 100644 tests/test_code_utils.py diff --git a/codeflash/cli_cmds/cli.py b/codeflash/cli_cmds/cli.py index 44360dc5d..d01f240b1 100644 --- a/codeflash/cli_cmds/cli.py +++ b/codeflash/cli_cmds/cli.py @@ -1,13 +1,14 @@ import logging -import os import sys from argparse import SUPPRESS, ArgumentParser, Namespace +from pathlib import Path import git from codeflash.cli_cmds import logging_config from codeflash.cli_cmds.cli_common import apologize_and_exit from codeflash.cli_cmds.cmd_init import init_codeflash, install_github_actions +from codeflash.cli_cmds.console import logger from codeflash.code_utils import env_utils from codeflash.code_utils.config_parser import parse_config_file from codeflash.code_utils.git_utils import ( @@ -17,7 +18,6 @@ from codeflash.code_utils.git_utils import ( get_repo_owner_and_name, ) from codeflash.code_utils.github_utils import get_github_secrets_page_url, require_github_app_or_exit -from codeflash.cli_cmds.console import logger from codeflash.version import __version__ as version @@ -104,18 +104,18 @@ def process_and_validate_cmd_args(args: Namespace) -> Namespace: logger.error("If you specify a --function, you must specify the --file it is in") sys.exit(1) if args.file: - if not os.path.exists(args.file): + if not Path(args.file).exists(): logger.error(f"File {args.file} does not exist") sys.exit(1) - args.file = os.path.realpath(args.file) + args.file = Path(args.file).resolve() if not args.no_pr: owner, repo = get_repo_owner_and_name() require_github_app_or_exit(owner, repo) if args.replay_test: - if not os.path.isfile(args.replay_test): + if not Path(args.replay_test).is_file(): logger.error(f"Replay test file {args.replay_test} does not exist") sys.exit(1) - args.replay_test = os.path.realpath(args.replay_test) + args.replay_test = Path(args.replay_test).resolve() return args @@ -142,11 +142,11 @@ def process_pyproject_config(args: Namespace) -> Namespace: or not hasattr(args, key.replace("-", "_")) ): setattr(args, key.replace("-", "_"), pyproject_config[key]) - assert args.module_root is not None and os.path.isdir( - args.module_root, + assert ( + args.module_root is not None and Path(args.module_root).is_dir() ), f"--module-root {args.module_root} must be a valid directory" - assert args.tests_root is not None and os.path.isdir( - args.tests_root, + assert ( + args.tests_root is not None and Path(args.tests_root).is_dir() ), f"--tests-root {args.tests_root} must be a valid directory" assert not (env_utils.get_pr_number() is not None and not env_utils.ensure_codeflash_api_key()), ( @@ -160,24 +160,23 @@ def process_pyproject_config(args: Namespace) -> Namespace: if hasattr(args, "ignore_paths") and args.ignore_paths is not None: normalized_ignore_paths = [] for path in args.ignore_paths: - assert os.path.exists( - path, - ), f"ignore-paths config must be a valid path. Path {path} does not exist" - normalized_ignore_paths.append(os.path.realpath(path)) + path_obj = Path(path) + assert path_obj.exists(), f"ignore-paths config must be a valid path. Path {path} does not exist" + normalized_ignore_paths.append(path_obj.resolve()) args.ignore_paths = normalized_ignore_paths # Project root path is one level above the specified directory, because that's where the module can be imported from - args.module_root = os.path.realpath(args.module_root) + args.module_root = Path(args.module_root).resolve() # If module-root is "." then all imports are relatives to it. # in this case, the ".." becomes outside project scope, causing issues with un-importable paths args.project_root = project_root_from_module_root(args.module_root, pyproject_file_path) - args.tests_root = os.path.realpath(args.tests_root) + args.tests_root = Path(args.tests_root).resolve() return handle_optimize_all_arg_parsing(args) -def project_root_from_module_root(module_root: str, pyproject_file_path: str) -> str: - if os.path.dirname(pyproject_file_path) == module_root: +def project_root_from_module_root(module_root: Path, pyproject_file_path: Path) -> Path: + if pyproject_file_path.parent == module_root: return module_root - return os.path.realpath(os.path.join(module_root, "..")) + return module_root.parent.resolve() def handle_optimize_all_arg_parsing(args: Namespace) -> Namespace: @@ -203,5 +202,5 @@ def handle_optimize_all_arg_parsing(args: Namespace) -> Namespace: # The default behavior of --all is to optimize everything in args.module_root args.all = args.module_root else: - args.all = os.path.realpath(args.all) + args.all = Path(args.all).resolve() return args diff --git a/codeflash/cli_cmds/cmd_init.py b/codeflash/cli_cmds/cmd_init.py index 194a6d5ff..81345d5f7 100644 --- a/codeflash/cli_cmds/cmd_init.py +++ b/codeflash/cli_cmds/cmd_init.py @@ -2,11 +2,11 @@ from __future__ import annotations import ast import os -import pathlib import re import subprocess import sys from argparse import Namespace +from pathlib import Path from typing import Optional import click @@ -107,7 +107,7 @@ def ask_run_end_to_end_test(args: Namespace) -> None: def collect_setup_info() -> SetupInfo: - curdir = os.getcwd() + curdir = Path.cwd() # Check if the cwd is writable if not os.access(curdir, os.W_OK): click.echo( @@ -170,20 +170,22 @@ def collect_setup_info() -> SetupInfo: ) if tests_root_answer == create_for_me_option: - tests_root = os.path.join(curdir, default_tests_subdir) - os.mkdir(tests_root) - click.echo(f"✅ Created directory {tests_root}{os.pathsep}{LF}") + tests_root = Path(curdir) / default_tests_subdir + tests_root.mkdir() + click.echo(f"✅ Created directory {tests_root}{os.path.sep}{LF}") elif tests_root_answer == custom_dir_option: custom_tests_root_answer = inquirer_wrapper_path( "path", - message=f"Enter the path to your tests directory inside {os.path.abspath('.') + os.path.sep} ", + message=f"Enter the path to your tests directory inside {Path(curdir).resolve()}{os.path.sep} ", path_type=inquirer.Path.DIRECTORY, exists=True, ) - tests_root = custom_tests_root_answer["path"] if custom_tests_root_answer else apologize_and_exit() + tests_root = ( + Path(custom_tests_root_answer["path"]) if custom_tests_root_answer else apologize_and_exit() + ) else: - tests_root = tests_root_answer - tests_root = os.path.relpath(tests_root, curdir) + tests_root = Path(tests_root_answer) + tests_root = tests_root.relative_to(curdir) ph("cli-tests-root-provided") # Autodiscover test framework @@ -224,7 +226,7 @@ def collect_setup_info() -> SetupInfo: ) -def detect_test_framework(curdir, tests_root) -> Optional[str]: +def detect_test_framework(curdir: Path, tests_root: Path) -> Optional[str]: test_framework = None pytest_files = ["pytest.ini", "pyproject.toml", "tox.ini", "setup.cfg"] pytest_config_patterns = { @@ -234,9 +236,9 @@ def detect_test_framework(curdir, tests_root) -> Optional[str]: "setup.cfg": "[tool:pytest]", } for pytest_file in pytest_files: - file_path = os.path.join(curdir, pytest_file) - if os.path.exists(file_path): - with open(file_path, encoding="utf8") as file: + file_path = curdir / pytest_file + if file_path.exists(): + with file_path.open(encoding="utf8") as file: contents = file.read() if pytest_config_patterns[pytest_file] in contents: test_framework = "pytest" @@ -244,9 +246,9 @@ def detect_test_framework(curdir, tests_root) -> Optional[str]: test_framework = "pytest" else: # Check if any python files contain a class that inherits from unittest.TestCase - for filename in os.listdir(tests_root): - if filename.endswith(".py"): - with open(os.path.join(tests_root, filename), encoding="utf8") as file: + for filename in tests_root.iterdir(): + if filename.suffix == ".py": + with filename.open(encoding="utf8") as file: contents = file.read() try: node = ast.parse(contents) @@ -271,14 +273,13 @@ def detect_test_framework(curdir, tests_root) -> Optional[str]: def check_for_toml_or_setup_file() -> Optional[str]: click.echo() click.echo("Checking for pyproject.toml or setup.py ...\r", nl=False) - curdir = os.getcwd() - pyproject_toml_path = os.path.join(curdir, "pyproject.toml") - setup_py_path = os.path.join(curdir, "setup.py") + curdir = Path.cwd() + pyproject_toml_path = curdir / "pyproject.toml" + setup_py_path = curdir / "setup.py" project_name = None - if os.path.exists(pyproject_toml_path): + if pyproject_toml_path.exists(): try: - with open(pyproject_toml_path, encoding="utf8") as f: - pyproject_toml_content = f.read() + pyproject_toml_content = pyproject_toml_path.read_text(encoding="utf8") project_name = tomlkit.parse(pyproject_toml_content)["tool"]["poetry"]["name"] click.echo(f"✅ I found a pyproject.toml for your project {project_name}.") ph("cli-pyproject-toml-found-name") @@ -286,9 +287,8 @@ def check_for_toml_or_setup_file() -> Optional[str]: click.echo("✅ I found a pyproject.toml for your project.") ph("cli-pyproject-toml-found") else: - if os.path.exists(setup_py_path): - with open(setup_py_path, encoding="utf8") as f: - setup_py_content = f.read() + if setup_py_path.exists(): + setup_py_content = setup_py_path.read_text(encoding="utf8") project_name_match = re.search( r"setup\s*\([^)]*?name\s*=\s*['\"](.*?)['\"]", setup_py_content, @@ -321,11 +321,10 @@ def check_for_toml_or_setup_file() -> Optional[str]: new_pyproject_toml = tomlkit.document() new_pyproject_toml["tool"] = {"codeflash": {}} try: - with open(pyproject_toml_path, "w", encoding="utf8") as pyproject_file: - pyproject_file.write(tomlkit.dumps(new_pyproject_toml)) + pyproject_toml_path.write_text(tomlkit.dumps(new_pyproject_toml), encoding="utf8") # Check if the pyproject.toml file was created - if os.path.exists(pyproject_toml_path): + if pyproject_toml_path.exists(): click.echo( f"✅ Created a pyproject.toml file at {pyproject_toml_path}", ) @@ -356,9 +355,9 @@ def install_github_actions() -> None: owner, repo_name = get_repo_owner_and_name(repo) require_github_app_or_exit(owner, repo_name) - git_root = repo.git.rev_parse("--show-toplevel") - workflows_path = os.path.join(git_root, ".github", "workflows") - optimize_yaml_path = os.path.join(workflows_path, "codeflash-optimize.yaml") + git_root = Path(repo.git.rev_parse("--show-toplevel")) + workflows_path = git_root / ".github" / "workflows" + optimize_yaml_path = workflows_path / "codeflash-optimize.yaml" confirm_creation_yes = inquirer_wrapper( inquirer.confirm, @@ -373,7 +372,7 @@ def install_github_actions() -> None: click.echo("⏩️ Exiting workflow creation.") ph("cli-github-workflow-skipped") apologize_and_exit() - os.makedirs(workflows_path, exist_ok=True) + workflows_path.mkdir(parents=True, exist_ok=True) from importlib.resources import files py_version = sys.version_info @@ -387,13 +386,13 @@ def install_github_actions() -> None: "{{ python_version }}", python_version_string, ) - with open(optimize_yaml_path, "w", encoding="utf8") as optimize_yml_file: + with optimize_yaml_path.open("w", encoding="utf8") as optimize_yml_file: optimize_yml_file.write(optimize_yml_content) click.echo(f"✅ Created {optimize_yaml_path}{LF}") click.prompt( f"Next, you'll need to add your CODEFLASH_API_KEY as a secret to your GitHub repo.{LF}" - + f"Press Enter to open your repo's secrets page at {get_github_secrets_page_url(repo)} ...{LF}" - + f"Then, click 'New repository secret' to add your api key with the variable name CODEFLASH_API_KEY.{LF}", + f"Press Enter to open your repo's secrets page at {get_github_secrets_page_url(repo)} ...{LF}" + f"Then, click 'New repository secret' to add your api key with the variable name CODEFLASH_API_KEY.{LF}", default="", type=click.STRING, prompt_suffix="", @@ -402,7 +401,7 @@ def install_github_actions() -> None: click.launch(get_github_secrets_page_url(repo)) click.echo( "🐙 I opened your Github secrets page! Note: if you see a 404, you probably don't have access to this " - + "repo's secrets; ask a repo admin to add it for you, or (not super recommended) you can temporarily " + "repo's secrets; ask a repo admin to add it for you, or (not super recommended) you can temporarily " f"hard-code your api key into the workflow file.{LF}", ) click.pause() @@ -410,7 +409,7 @@ def install_github_actions() -> None: click.prompt( f"Finally, for the workflow to work, you'll need to edit the workflow file to install the right " f"Python version and any project dependencies.{LF}" - + f"Press Enter to open {optimize_yaml_path} in your editor.{LF}", + f"Press Enter to open {optimize_yaml_path} in your editor.{LF}", default="", type=click.STRING, prompt_suffix="", @@ -419,7 +418,7 @@ def install_github_actions() -> None: click.launch(optimize_yaml_path) click.echo( "📝 I opened the workflow file in your editor! You'll need to edit the steps that install the right Python " - + f"version and any project dependencies. See the comments in the file for more details.{LF}", + f"version and any project dependencies. See the comments in the file for more details.{LF}", ) click.pause() click.echo() @@ -434,9 +433,9 @@ def install_github_actions() -> None: # Create or update the pyproject.toml file with the Codeflash dependency & configuration def configure_pyproject_toml(setup_info: SetupInfo) -> None: - toml_path = os.path.join(os.getcwd(), "pyproject.toml") + toml_path = Path.cwd() / "pyproject.toml" try: - with open(toml_path, encoding="utf8") as pyproject_file: + with toml_path.open(encoding="utf8") as pyproject_file: pyproject_data = tomlkit.parse(pyproject_file.read()) except FileNotFoundError: click.echo( @@ -470,7 +469,7 @@ def configure_pyproject_toml(setup_info: SetupInfo) -> None: pyproject_data["tool"] = tool_section click.echo("Writing Codeflash configuration ...\r", nl=False) - with open(toml_path, "w", encoding="utf8") as pyproject_file: + with toml_path.open("w", encoding="utf8") as pyproject_file: pyproject_file.write(tomlkit.dumps(pyproject_data)) click.echo(f"✅ Added Codeflash configuration to {toml_path}") click.echo() @@ -488,8 +487,8 @@ def install_github_app() -> None: else: click.prompt( f"Finally, you'll need install the Codeflash GitHub app by choosing the repository you want to install Codeflash on.{LF}" - + f"I will attempt to open the github app page - https://github.com/apps/codeflash-ai/installations/select_target {LF}" - + f"Press Enter to open the page to let you install the app ...{LF}", + f"I will attempt to open the github app page - https://github.com/apps/codeflash-ai/installations/select_target {LF}" + f"Press Enter to open the page to let you install the app ...{LF}", default="", type=click.STRING, prompt_suffix="", @@ -601,7 +600,7 @@ def enter_api_key_and_save_to_rc() -> None: os.environ["CODEFLASH_API_KEY"] = api_key -def create_bubble_sort_file_and_test(args: Namespace) -> None: +def create_bubble_sort_file_and_test(args: Namespace) -> tuple[str, str]: bubble_sort_content = """def sorter(arr): for i in range(len(arr)): for j in range(len(arr) - 1): @@ -613,10 +612,10 @@ def create_bubble_sort_file_and_test(args: Namespace) -> None: """ if args.test_framework == "unittest": bubble_sort_test_content = f"""import unittest -from {os.path.basename(args.module_root)}.bubble_sort import sorter + from {Path(args.module_root).name}.bubble_sort import sorter -class TestBubbleSort(unittest.TestCase): - def test_sort(self): + class TestBubbleSort(unittest.TestCase): + def test_sort(self): input = [5, 4, 3, 2, 1, 0] output = sorter(input) self.assertEqual(output, [0, 1, 2, 3, 4, 5]) @@ -628,9 +627,9 @@ class TestBubbleSort(unittest.TestCase): input = list(reversed(range(100))) output = sorter(input) self.assertEqual(output, list(range(100))) -""" + """ elif args.test_framework == "pytest": - bubble_sort_test_content = f"""from {os.path.basename(args.module_root)}.bubble_sort import sorter + bubble_sort_test_content = f"""from {Path(args.module_root).name}.bubble_sort import sorter def test_sort(): input = [5, 4, 3, 2, 1, 0] @@ -648,15 +647,17 @@ def test_sort(): else: click.echo(f"❌ Unsupported test framework: {args.test_framework}") apologize_and_exit() - bubble_sort_path = os.path.join(args.module_root, "bubble_sort.py") - with open(bubble_sort_path, "w", encoding="utf8") as bubble_sort_file: - bubble_sort_file.write(bubble_sort_content) - bubble_sort_test_path = os.path.join(args.tests_root, "test_bubble_sort.py") - with open(bubble_sort_test_path, "w", encoding="utf8") as bubble_sort_test_file: - bubble_sort_test_file.write(bubble_sort_test_content) + + bubble_sort_path = Path(args.module_root) / "bubble_sort.py" + bubble_sort_path.write_text(bubble_sort_content, encoding="utf8") + + bubble_sort_test_path = Path(args.tests_root) / "test_bubble_sort.py" + bubble_sort_test_path.write_text(bubble_sort_test_content, encoding="utf8") + click.echo(f"✅ Created {bubble_sort_path}") click.echo(f"✅ Created {bubble_sort_test_path}") - return bubble_sort_path, bubble_sort_test_path + + return str(bubble_sort_path), str(bubble_sort_test_path) def run_end_to_end_test(args: Namespace, bubble_sort_path: str, bubble_sort_test_path: str) -> None: diff --git a/codeflash/code_utils/code_extractor.py b/codeflash/code_utils/code_extractor.py index 1b5951209..b66d48e44 100644 --- a/codeflash/code_utils/code_extractor.py +++ b/codeflash/code_utils/code_extractor.py @@ -1,6 +1,7 @@ from __future__ import annotations import ast +from pathlib import Path from typing import TYPE_CHECKING import libcst as cst @@ -9,8 +10,8 @@ from libcst.codemod import CodemodContext from libcst.codemod.visitors import AddImportsVisitor, GatherImportsVisitor, RemoveImportsVisitor from libcst.helpers import calculate_module_and_package -from codeflash.discovery.functions_to_optimize import FunctionParent from codeflash.cli_cmds.console import logger +from codeflash.discovery.functions_to_optimize import FunctionParent if TYPE_CHECKING: from libcst.helpers import ModuleNameAndPackage @@ -43,9 +44,9 @@ def delete___future___aliased_imports(module_code: str) -> str: def add_needed_imports_from_module( src_module_code: str, dst_module_code: str, - src_path: str, - dst_path: str, - project_root: str, + src_path: Path, + dst_path: Path, + project_root: Path, helper_functions: list[FunctionSource] | None = None, ) -> str: """Add all needed and used source module code imports to the destination module code, and return it.""" @@ -57,13 +58,13 @@ def add_needed_imports_from_module( dst_module_and_package: ModuleNameAndPackage = calculate_module_and_package(project_root, dst_path) dst_context: CodemodContext = CodemodContext( - filename=src_path, + filename=src_path.name, full_module_name=dst_module_and_package.name, full_package_name=dst_module_and_package.package, ) gatherer: GatherImportsVisitor = GatherImportsVisitor( CodemodContext( - filename=src_path, + filename=src_path.name, full_module_name=src_module_and_package.name, full_package_name=src_module_and_package.package, ), @@ -129,7 +130,7 @@ def get_code( ): return None, set() - file_path: str = functions_to_optimize[0].file_path + file_path: Path = functions_to_optimize[0].file_path class_skeleton: set[tuple[int, int | None]] = set() contextual_dunder_methods: set[tuple[str, str]] = set() target_code: str = "" diff --git a/codeflash/code_utils/code_replacer.py b/codeflash/code_utils/code_replacer.py index 8694ff11d..dad68de5f 100644 --- a/codeflash/code_utils/code_replacer.py +++ b/codeflash/code_utils/code_replacer.py @@ -1,7 +1,7 @@ from __future__ import annotations import ast -from typing import IO +from typing import IO, TYPE_CHECKING import libcst as cst from libcst import FunctionDef @@ -9,6 +9,9 @@ from libcst import FunctionDef from codeflash.code_utils.code_extractor import add_needed_imports_from_module from codeflash.discovery.functions_to_optimize import FunctionParent +if TYPE_CHECKING: + from pathlib import Path + class OptimFunctionCollector(cst.CSTVisitor): METADATA_DEPENDENCIES = (cst.metadata.ParentNodeProvider,) @@ -200,11 +203,11 @@ def replace_functions_and_add_imports( source_code: str, function_names: list[str], optimized_code: str, - file_path_of_module_with_function_to_optimize: str, - module_abspath: str, + file_path_of_module_with_function_to_optimize: Path, + module_abspath: Path, preexisting_objects: list[tuple[str, list[FunctionParent]]], contextual_functions: set[tuple[str, str]], - project_root_path: str, + project_root_path: Path, ) -> str: return add_needed_imports_from_module( optimized_code, @@ -224,11 +227,11 @@ def replace_functions_and_add_imports( def replace_function_definitions_in_module( function_names: list[str], optimized_code: str, - file_path_of_module_with_function_to_optimize: str, - module_abspath: str, + file_path_of_module_with_function_to_optimize: Path, + module_abspath: Path, preexisting_objects: list[tuple[str, list[FunctionParent]]], contextual_functions: set[tuple[str, str]], - project_root_path: str, + project_root_path: Path, ) -> bool: """:param function_names: List of qualified (not fully qualified) function names (function_name or class_name.method_name). @@ -241,8 +244,7 @@ def replace_function_definitions_in_module( :return: """ file: IO[str] - with open(module_abspath, encoding="utf8") as file: - source_code: str = file.read() + source_code: str = module_abspath.read_text(encoding="utf8") new_code: str = replace_functions_and_add_imports( source_code, function_names, @@ -255,8 +257,7 @@ def replace_function_definitions_in_module( ) if is_zero_diff(source_code, new_code): return False - with open(module_abspath, "w", encoding="utf8") as file: - file.write(new_code) + module_abspath.write_text(new_code, encoding="utf8") return True diff --git a/codeflash/code_utils/code_utils.py b/codeflash/code_utils/code_utils.py index ca411375e..02939fe41 100644 --- a/codeflash/code_utils/code_utils.py +++ b/codeflash/code_utils/code_utils.py @@ -1,27 +1,26 @@ from __future__ import annotations import ast -from codeflash.cli_cmds.console import logger import os import site from tempfile import TemporaryDirectory - -def module_name_from_file_path(file_path: str, project_root_path: str) -> str: - relative_path = os.path.relpath(file_path, project_root_path) - module_path = relative_path.replace(os.sep, ".") - if module_path.lower().endswith(".py"): - module_path = module_path[:-3] - return module_path +from codeflash.cli_cmds.console import logger +from pathlib import Path -def file_path_from_module_name(module_name: str, project_root_path: str) -> str: - """Get file path from module path""" - return os.path.join(project_root_path, module_name.replace(".", os.sep) + ".py") +def module_name_from_file_path(file_path: Path, project_root_path: Path) -> str: + relative_path = file_path.relative_to(project_root_path) + return relative_path.with_suffix("").as_posix().replace("/", ".") + + +def file_path_from_module_name(module_name: str, project_root_path: Path) -> Path: + """Get file path from module path.""" + return project_root_path / (module_name.replace(".", os.sep) + ".py") def get_imports_from_file( - file_path: str | None = None, + file_path: Path | None = None, file_string: str | None = None, file_ast: ast.AST | None = None, ) -> list[ast.Import | ast.ImportFrom]: @@ -29,19 +28,18 @@ def get_imports_from_file( sum([file_path is not None, file_string is not None, file_ast is not None]) == 1 ), "Must provide exactly one of file_path, file_string, or file_ast" if file_path: - with open(file_path, encoding="utf8") as file: + with file_path.open(encoding="utf8") as file: file_string = file.read() if file_ast is None: + if file_string is None: + logger.error("file_string cannot be None when file_ast is not provided") + return [] try: file_ast = ast.parse(file_string) except SyntaxError as e: logger.exception(f"Syntax error in code: {e}") return [] - imports = [] - for node in ast.walk(file_ast): - if isinstance(node, (ast.Import, ast.ImportFrom)): - imports.append(node) - return imports + return [node for node in ast.walk(file_ast) if isinstance(node, (ast.Import, ast.ImportFrom))] def get_all_function_names(code: str) -> tuple[bool, list[str]]: @@ -51,34 +49,27 @@ def get_all_function_names(code: str) -> tuple[bool, list[str]]: logger.exception(f"Syntax error in code: {e}") return False, [] - function_names = [] - for node in ast.walk(module): - if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): - function_names.append(node.name) + function_names = [ + node.name for node in ast.walk(module) if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) + ] return True, function_names -def get_run_tmp_file(file_path: str) -> str: +def get_run_tmp_file(file_path: Path) -> Path: if not hasattr(get_run_tmp_file, "tmpdir"): get_run_tmp_file.tmpdir = TemporaryDirectory(prefix="codeflash_") - return os.path.join(get_run_tmp_file.tmpdir.name, file_path) + return Path(get_run_tmp_file.tmpdir.name) / file_path -def path_belongs_to_site_packages(file_path: str) -> bool: - site_packages = site.getsitepackages() - for site_package_path in site_packages: - if file_path.startswith(site_package_path + os.sep): - return True - return False +def path_belongs_to_site_packages(file_path: Path) -> bool: + site_packages = [Path(p) for p in site.getsitepackages()] + return any(file_path.resolve().is_relative_to(site_package_path) for site_package_path in site_packages) -def is_class_defined_in_file(class_name: str, file_path: str) -> bool: - if not os.path.exists(file_path): +def is_class_defined_in_file(class_name: str, file_path: Path) -> bool: + if not file_path.exists(): return False - with open(file_path) as file: + with file_path.open(encoding="utf8") as file: source = file.read() tree = ast.parse(source) - for node in ast.walk(tree): - if isinstance(node, ast.ClassDef) and node.name == class_name: - return True - return False + return any(isinstance(node, ast.ClassDef) and node.name == class_name for node in ast.walk(tree)) diff --git a/codeflash/code_utils/config_parser.py b/codeflash/code_utils/config_parser.py index 24a8f0b81..cb34ffe68 100644 --- a/codeflash/code_utils/config_parser.py +++ b/codeflash/code_utils/config_parser.py @@ -1,56 +1,53 @@ -import os +from __future__ import annotations + +from pathlib import Path from typing import Any import tomlkit -def find_pyproject_toml(config_file=None): +def find_pyproject_toml(config_file: Path | None = None) -> Path: # Find the pyproject.toml file on the root of the project if config_file is not None: - if not config_file.lower().endswith(".toml"): - raise ValueError( - f"Config file {config_file} is not a valid toml file. Please recheck the path to pyproject.toml", - ) - if not os.path.exists(config_file): - raise ValueError( - f"Config file {config_file} does not exist. Please recheck the path to pyproject.toml", - ) + config_file = Path(config_file) + if config_file.suffix.lower() != ".toml": + msg = f"Config file {config_file} is not a valid toml file. Please recheck the path to pyproject.toml" + raise ValueError(msg) + if not config_file.exists(): + msg = f"Config file {config_file} does not exist. Please recheck the path to pyproject.toml" + raise ValueError(msg) return config_file - else: - dir_path = os.getcwd() + dir_path = Path.cwd() - while os.path.dirname(dir_path) != dir_path: - config_file = os.path.join(dir_path, "pyproject.toml") - if os.path.exists(config_file): + while dir_path != dir_path.parent: + config_file = dir_path / "pyproject.toml" + if config_file.exists(): return config_file # Search for pyproject.toml in the parent directories - dir_path = os.path.dirname(dir_path) - raise ValueError( - f"Could not find pyproject.toml in the current directory {os.getcwd()} or any of the parent directories. Please create it by running `poetry init`, or pass the path to pyproject.toml with the --config-file argument.", - ) + dir_path = dir_path.parent + msg = f"Could not find pyproject.toml in the current directory {Path.cwd()} or any of the parent directories. Please create it by running `poetry init`, or pass the path to pyproject.toml with the --config-file argument." + + raise ValueError(msg) -def parse_config_file(config_file_path: str = None) -> tuple[dict[str, Any], str]: +def parse_config_file(config_file_path: Path | None = None) -> tuple[dict[str, Any], Path]: config_file_path = find_pyproject_toml(config_file_path) try: - with open(config_file_path, "rb") as f: + with config_file_path.open("rb") as f: data = tomlkit.parse(f.read()) except tomlkit.exceptions.ParseError as e: - raise ValueError( - f"Error while parsing the config file {config_file_path}. Please recheck the file for syntax errors. Error: {e}", - ) + msg = f"Error while parsing the config file {config_file_path}. Please recheck the file for syntax errors. Error: {e}" + raise ValueError(msg) from e try: tool = data["tool"] assert isinstance(tool, dict) config = tool["codeflash"] - except tomlkit.exceptions.NonExistentKey: - raise ValueError( - f"Could not find the 'codeflash' block in the config file {config_file_path}. " - f"Please run 'codeflash init' to create the config file.", - ) + except tomlkit.exceptions.NonExistentKey as e: + msg = f"Could not find the 'codeflash' block in the config file {config_file_path}. Please run 'codeflash init' to create the config file." + raise ValueError(msg) from e assert isinstance(config, dict) # default values: @@ -79,9 +76,7 @@ def parse_config_file(config_file_path: str = None) -> tuple[dict[str, Any], str config[key] = bool_keys[key] for key in path_keys: if key in config: - config[key] = os.path.realpath( - os.path.join(os.path.dirname(config_file_path), config[key]), - ) + config[key] = str((Path(config_file_path).parent / Path(config[key])).resolve()) for key in list_str_keys: if key in config: config[key] = [str(cmd) for cmd in config[key]] @@ -90,10 +85,7 @@ def parse_config_file(config_file_path: str = None) -> tuple[dict[str, Any], str for key in path_list_keys: if key in config: - config[key] = [ - os.path.realpath(os.path.join(os.path.dirname(config_file_path), path)) - for path in config[key] - ] + config[key] = [(Path(config_file_path).parent / path).resolve() for path in config[key]] else: # Default to empty list config[key] = [] diff --git a/codeflash/code_utils/formatter.py b/codeflash/code_utils/formatter.py index b241bb6db..c58fe070d 100644 --- a/codeflash/code_utils/formatter.py +++ b/codeflash/code_utils/formatter.py @@ -1,28 +1,31 @@ -from codeflash.cli_cmds.console import logger -import os.path +from __future__ import annotations + +import os import shlex import subprocess +from pathlib import Path import isort +from codeflash.cli_cmds.console import logger + def format_code( formatter_cmds: list[str], - path: str, -) -> str: + path: Path, +) -> str | None: # TODO: Only allow a particular whitelist of formatters here to prevent arbitrary code execution - if not os.path.exists(path): + if not path.exists(): logger.error(f"File {path} does not exist. Cannot format the file.") return None if formatter_cmds[0].lower() == "disabled": - with open(path, encoding="utf8") as f: - new_code = f.read() + new_code = path.read_text(encoding="utf8") return new_code file_token = "$file" for command in formatter_cmds: formatter_cmd_list = shlex.split(command, posix=os.name != "nt") - formatter_cmd_list = [path if chunk == file_token else chunk for chunk in formatter_cmd_list] + formatter_cmd_list = [str(path) if chunk == file_token else chunk for chunk in formatter_cmd_list] logger.info(f"Formatting code with {' '.join(formatter_cmd_list)} ...") try: @@ -40,8 +43,7 @@ def format_code( else: logger.error(f"Failed to format code with {' '.join(formatter_cmd_list)}") - with open(path, encoding="utf8") as f: - new_code = f.read() + new_code = path.read_text(encoding="utf8") return new_code diff --git a/codeflash/code_utils/git_utils.py b/codeflash/code_utils/git_utils.py index ece8277d1..6c3a432bf 100644 --- a/codeflash/code_utils/git_utils.py +++ b/codeflash/code_utils/git_utils.py @@ -1,9 +1,9 @@ from __future__ import annotations -from codeflash.cli_cmds.console import logger -import os + import sys import time from io import StringIO +from pathlib import Path from typing import Optional import git @@ -12,10 +12,11 @@ from git import Repo from unidiff import PatchSet from codeflash.cli_cmds.cli_common import inquirer_wrapper +from codeflash.cli_cmds.console import logger def get_git_diff( - repo_directory: str = os.getcwd(), + repo_directory: Path = Path.cwd(), uncommitted_changes: bool = False, ) -> dict[str, list[int]]: repository = git.Repo(repo_directory, search_parent_directories=True) @@ -37,11 +38,12 @@ def get_git_diff( patch_set = PatchSet(StringIO(uni_diff_text)) change_list: dict[str, list[int]] = {} # list of changes for patched_file in patch_set: - file_path: str = patched_file.path # file name - if not file_path.endswith(".py"): + file_path: Path = Path(patched_file.path) + if file_path.suffix != ".py": continue - file_path = os.path.join(repository.working_dir, file_path) - logger.debug("file name :" + file_path) + file_path = Path(repository.working_dir) / file_path + logger.debug(f"file name: {file_path}") + add_line_no: list[int] = [ line.target_line_no for hunk in patched_file @@ -49,15 +51,16 @@ def get_git_diff( if line.is_added and line.value.strip() != "" ] # the row number of deleted lines - logger.debug("added lines : " + str(add_line_no)) + logger.debug(f"added lines: {add_line_no}") + del_line_no: list[int] = [ line.source_line_no for hunk in patched_file for line in hunk if line.is_removed and line.value.strip() != "" - ] # the row number of added liens + ] # the row number of added lines - logger.debug("deleted lines : " + str(del_line_no)) + logger.debug(f"deleted lines: {del_line_no}") change_list[file_path] = add_line_no return change_list diff --git a/codeflash/discovery/discover_unit_tests.py b/codeflash/discovery/discover_unit_tests.py index 5be146243..d75001f1f 100644 --- a/codeflash/discovery/discover_unit_tests.py +++ b/codeflash/discovery/discover_unit_tests.py @@ -1,15 +1,16 @@ -from codeflash.cli_cmds.console import logger import os import re import shlex import unittest from collections import defaultdict from multiprocessing import Process, Queue +from pathlib import Path from typing import Dict, List, Optional, Tuple import jedi from pydantic.dataclasses import dataclass +from codeflash.cli_cmds.console import logger from codeflash.code_utils.code_utils import module_name_from_file_path from codeflash.verification.test_results import TestType from codeflash.verification.verification_utils import TestConfig @@ -139,7 +140,7 @@ def discover_tests_unittest( cfg: TestConfig, discover_only_these_tests: Optional[List[str]] = None, ) -> Dict[str, List[FunctionCalledInTest]]: - tests_root = cfg.tests_root + tests_root = Path(cfg.tests_root) loader = unittest.TestLoader() tests = loader.discover(str(tests_root)) file_to_test_map = defaultdict(list) @@ -151,18 +152,18 @@ def discover_tests_unittest( _test.__class__.__qualname__, ) - _test_module_path = _test_module.replace(".", os.sep) - _test_module_path = os.path.normpath(os.path.join(str(tests_root), _test_module_path) + ".py") - if not os.path.exists(_test_module_path) or ( - discover_only_these_tests and _test_module_path not in discover_only_these_tests + _test_module_path = Path(_test_module.replace(".", os.sep)).with_suffix(".py") + _test_module_path = tests_root / _test_module_path + if not _test_module_path.exists() or ( + discover_only_these_tests and str(_test_module_path) not in discover_only_these_tests ): return None - if "__replay_test" in _test_module_path: + if "__replay_test" in str(_test_module_path): test_type = TestType.REPLAY_TEST else: test_type = TestType.EXISTING_UNIT_TEST return TestsInFile( - test_file=_test_module_path, + test_file=str(_test_module_path), test_suite=_test_suite_name, test_function=_test_function, test_type=test_type, @@ -186,11 +187,11 @@ def discover_tests_unittest( continue details = get_test_details(test_2) if details is not None: - file_to_test_map[details.test_file].append(details) + file_to_test_map[str(details.test_file)].append(details) else: details = get_test_details(test) if details is not None: - file_to_test_map[details.test_file].append(details) + file_to_test_map[str(details.test_file)].append(details) return process_test_files(file_to_test_map, cfg) diff --git a/codeflash/discovery/functions_to_optimize.py b/codeflash/discovery/functions_to_optimize.py index 3346d1c86..670b125f6 100644 --- a/codeflash/discovery/functions_to_optimize.py +++ b/codeflash/discovery/functions_to_optimize.py @@ -128,7 +128,7 @@ class FunctionToOptimize: """ function_name: str - file_path: str + file_path: Path parents: list[FunctionParent] # list[ClassDef | FunctionDef | AsyncFunctionDef] starting_line: Optional[int] = None ending_line: Optional[int] = None @@ -150,20 +150,20 @@ class FunctionToOptimize: def qualified_name(self) -> str: return self.function_name if self.parents == [] else f"{self.parents[0].name}.{self.function_name}" - def qualified_name_with_modules_from_root(self, project_root_path: str) -> str: + def qualified_name_with_modules_from_root(self, project_root_path: Path) -> str: return f"{module_name_from_file_path(self.file_path, project_root_path)}.{self.qualified_name}" def get_functions_to_optimize( optimize_all: str | None, replay_test: str | None, - file: str | None, + file: Path | None, only_get_this_function: str | None, test_cfg: TestConfig, - ignore_paths: list[str], - project_root: str, - module_root: str, -) -> tuple[dict[str, list[FunctionToOptimize]], int]: + ignore_paths: list[Path], + project_root: Path, + module_root: Path, +) -> tuple[dict[Path, list[FunctionToOptimize]], int]: assert ( sum( [ # Ensure only one of the options is provided @@ -191,9 +191,8 @@ def get_functions_to_optimize( if only_get_this_function is not None: split_function = only_get_this_function.split(".") if len(split_function) > 2: - raise ValueError( - "Function name should be in the format 'function_name' or 'class_name.function_name'", - ) + msg = "Function name should be in the format 'function_name' or 'class_name.function_name'" + raise ValueError(msg) if len(split_function) == 2: class_name, only_function_name = split_function else: @@ -206,10 +205,8 @@ def get_functions_to_optimize( ): found_function = fn if found_function is None: - raise ValueError( - f"Function {only_function_name} not found in file {file} or" - f" the function does not have a 'return' statement.", - ) + msg = f"Function {only_function_name} not found in file {file} or the function does not have a 'return' statement." + raise ValueError(msg) functions[file] = [found_function] else: logger.info("Finding all functions modified in the current git diff ...") @@ -229,64 +226,60 @@ def get_functions_to_optimize( def get_functions_within_git_diff() -> dict[str, list[FunctionToOptimize]]: modified_lines: dict[str, list[int]] = get_git_diff(uncommitted_changes=False) modified_functions: dict[str, list[FunctionToOptimize]] = {} - for path in modified_lines: - if not os.path.exists(path): + for path_str in modified_lines: + path = Path(path_str) + if not path.exists(): continue - with open(path, encoding="utf8") as f: + with path.open(encoding="utf8") as f: file_content = f.read() try: wrapper = cst.metadata.MetadataWrapper(cst.parse_module(file_content)) except Exception as e: logger.exception(e) continue - function_lines = FunctionVisitor(file_path=path) + function_lines = FunctionVisitor(file_path=str(path)) wrapper.visit(function_lines) - modified_functions[path] = [ + modified_functions[str(path)] = [ function_to_optimize for function_to_optimize in function_lines.functions if (start_line := function_to_optimize.starting_line) is not None and (end_line := function_to_optimize.ending_line) is not None - and any(start_line <= line <= end_line for line in modified_lines[path]) + and any(start_line <= line <= end_line for line in modified_lines[path_str]) ] return modified_functions -def get_all_files_and_functions(module_root_path: str) -> dict[str, list[FunctionToOptimize]]: +def get_all_files_and_functions(module_root_path: Path) -> dict[str, list[FunctionToOptimize]]: functions: dict[str, list[FunctionToOptimize]] = {} - for root, dirs, files in os.walk(module_root_path): - for file in files: - if not file.endswith(".py"): - continue - file_path = os.path.join(root, file) - - # Find all the functions in the file - functions.update(find_all_functions_in_file(file_path)) + module_root_path = Path(module_root_path) + for file_path in module_root_path.rglob("*.py"): + # Find all the functions in the file + functions.update(find_all_functions_in_file(str(file_path))) # Randomize the order of the files to optimize to avoid optimizing the same file in the same order every time. # Helpful if an optimize-all run is stuck and we restart it. files_list = list(functions.items()) random.shuffle(files_list) - functions_shuffled = dict(files_list) - return functions_shuffled + return dict(files_list) -def find_all_functions_in_file(file_path: str) -> dict[str, list[FunctionToOptimize]]: +def find_all_functions_in_file(file_path: Path) -> dict[str, list[FunctionToOptimize]]: functions: dict[str, list[FunctionToOptimize]] = {} - with open(file_path, encoding="utf8") as f: + with file_path.open(encoding="utf8") as f: try: ast_module = ast.parse(f.read()) except Exception as e: logger.exception(e) return functions - function_name_visitor = FunctionWithReturnStatement(file_path) + function_name_visitor = FunctionWithReturnStatement(str(file_path)) function_name_visitor.visit(ast_module) - functions[file_path] = function_name_visitor.functions + functions[str(file_path)] = function_name_visitor.functions return functions def get_all_replay_test_functions( replay_test: str, test_cfg: TestConfig, - project_root_path: str, + project_root_path: Path, ) -> dict[str, list[FunctionToOptimize]]: function_tests = discover_unit_tests(test_cfg, discover_only_these_tests=[replay_test]) # Get the absolute file paths for each function, excluding class name if present @@ -303,7 +296,7 @@ def get_all_replay_test_functions( if module_path_parts and is_class_defined_in_file( module_path_parts[-1], - os.path.join(project_root_path, *module_path_parts[:-1]) + ".py", + str(Path(project_root_path, *module_path_parts[:-1])) + ".py", ) else None ) @@ -314,7 +307,7 @@ def get_all_replay_test_functions( else: function = function_name file_path_parts = module_path_parts - file_path = os.path.join(project_root_path, *file_path_parts) + ".py" + file_path = Path(project_root_path, *file_path_parts).with_suffix(".py") file_to_functions_map[file_path].append((function, function_name, class_name)) for file_path, functions in file_to_functions_map.items(): all_valid_functions: dict[str, list[FunctionToOptimize]] = find_all_functions_in_file( @@ -349,8 +342,7 @@ def ignored_submodule_paths(module_root: str) -> list[str]: if is_git_repo(module_root): git_repo = git.Repo(module_root, search_parent_directories=True) return [ - os.path.realpath(os.path.join(git_repo.working_tree_dir, submodule.path)) - for submodule in git_repo.submodules + Path(git_repo.working_tree_dir, submodule.path).resolve() for submodule in git_repo.submodules ] return [] @@ -358,7 +350,7 @@ def ignored_submodule_paths(module_root: str) -> list[str]: class TopLevelFunctionOrMethodVisitor(ast.NodeVisitor): def __init__( self, - file_name: str, + file_name: Path, function_or_method_name: str, class_name: str | None = None, line_no: int | None = None, @@ -418,7 +410,7 @@ class TopLevelFunctionOrMethodVisitor(ast.NodeVisitor): def inspect_top_level_functions_or_methods( - file_name: str, + file_name: Path, function_or_method_name: str, class_name: str | None = None, line_no: int | None = None, @@ -449,11 +441,11 @@ def inspect_top_level_functions_or_methods( def filter_functions( modified_functions: dict[str, list[FunctionToOptimize]], tests_root: str, - ignore_paths: list[str], + ignore_paths: list[Path], project_root: str, module_root: str, disable_logs: bool = False, -) -> tuple[dict[str, list[FunctionToOptimize]], int]: +) -> tuple[dict[Path, list[FunctionToOptimize]], int]: blocklist_funcs = get_blacklisted_functions() # Remove any function that we don't want to optimize @@ -522,37 +514,33 @@ def filter_functions( def filter_files_optimized( - file_path: str, - tests_root: str, - ignore_paths: list[str], - module_root: str, + file_path: Path, + tests_root: Path, + ignore_paths: list[Path], + module_root: Path, ) -> bool: """Optimized version of the filter_functions function above. + Takes in file paths and returns the count of files that are to be optimized. """ submodule_paths = None - if file_path.startswith(tests_root + os.sep): + if file_path.is_relative_to(tests_root): return False if file_path in ignore_paths or any( - file_path.startswith(ignore_path + os.sep) for ignore_path in ignore_paths + file_path.is_relative_to(ignore_path) for ignore_path in ignore_paths ): return False if path_belongs_to_site_packages(file_path): return False - if not file_path.startswith(module_root + os.sep): + if not file_path.is_relative_to(module_root): return False if submodule_paths is None: submodule_paths = ignored_submodule_paths(module_root) - if file_path in submodule_paths or any( - file_path.startswith(submodule_path + os.sep) for submodule_path in submodule_paths - ): - return False - - return True + return not ( + file_path in submodule_paths + or any(file_path.is_relative_to(submodule_path) for submodule_path in submodule_paths) + ) -def function_has_return_statement(function_node: Union[FunctionDef, AsyncFunctionDef]) -> bool: - for node in ast.walk(function_node): - if isinstance(node, ast.Return): - return True - return False +def function_has_return_statement(function_node: FunctionDef | AsyncFunctionDef) -> bool: + return any(isinstance(node, ast.Return) for node in ast.walk(function_node)) diff --git a/codeflash/models/models.py b/codeflash/models/models.py index a77a78552..9e1716d05 100644 --- a/codeflash/models/models.py +++ b/codeflash/models/models.py @@ -1,5 +1,6 @@ from __future__ import annotations +from pathlib import Path from typing import Any, Generator, Iterator, Optional from jedi.api.classes import Name @@ -17,7 +18,7 @@ from codeflash.verification.test_results import TestResults, TestType @dataclass(frozen=True, config={"arbitrary_types_allowed": True}) class FunctionSource: - file_path: str + file_path: Path qualified_name: str fully_qualified_name: str only_function_name: str @@ -55,7 +56,7 @@ class GeneratedTestsList(BaseModel): class TestFile(BaseModel): - instrumented_file_path: str + instrumented_file_path: Path original_file_path: Optional[str] = None original_source: Optional[str] = None test_type: TestType diff --git a/codeflash/optimization/function_context.py b/codeflash/optimization/function_context.py index b0814b1f3..e1faa83e8 100644 --- a/codeflash/optimization/function_context.py +++ b/codeflash/optimization/function_context.py @@ -4,16 +4,20 @@ import ast import os import re from collections import defaultdict +from typing import TYPE_CHECKING import jedi import tiktoken from jedi.api.classes import Name +from codeflash.cli_cmds.console import logger from codeflash.code_utils.code_extractor import get_code from codeflash.code_utils.code_utils import module_name_from_file_path, path_belongs_to_site_packages from codeflash.discovery.functions_to_optimize import FunctionParent, FunctionToOptimize from codeflash.models.models import FunctionSource -from codeflash.cli_cmds.console import logger + +if TYPE_CHECKING: + from pathlib import Path def belongs_to_class(name: Name, class_name: str) -> bool: @@ -36,12 +40,11 @@ def belongs_to_function(name: Name, function_name: str) -> bool: def get_type_annotation_context( function: FunctionToOptimize, jedi_script: jedi.Script, - project_root_path: str, + project_root_path: Path, ) -> list[FunctionSource]: function_name: str = function.function_name - file_path: str = function.file_path - with open(file_path, encoding="utf8") as file: - file_contents: str = file.read() + file_path: Path = function.file_path + file_contents: str = file_path.read_text(encoding="utf8") try: module: ast.Module = ast.parse(file_contents) except SyntaxError as e: @@ -72,14 +75,15 @@ def get_type_annotation_context( logger.exception(f"Error while getting definition: {ex}") definition = [] if definition: # TODO can be multiple definitions - definition_path = str(definition[0].module_path) + definition_path = definition[0].module_path # The definition is part of this project and not defined within the original function if ( - definition_path.startswith(project_root_path + os.sep) + str(definition_path).startswith(str(project_root_path) + os.sep) and definition[0].full_name and not path_belongs_to_site_packages(definition_path) and not belongs_to_function(definition[0], function_name) ): + assert definition_path is not None source_code = get_code( [ FunctionToOptimize( @@ -164,7 +168,7 @@ def get_type_annotation_context( def get_function_variables_definitions( function_to_optimize: FunctionToOptimize, - project_root_path: str, + project_root_path: Path, ) -> tuple[list[FunctionSource], set[tuple[str, str]]]: function_name = function_to_optimize.function_name file_path = function_to_optimize.file_path @@ -203,10 +207,11 @@ def get_function_variables_definitions( if definitions: # TODO: there can be multiple definitions, see how to handle such cases definition = definitions[0] - definition_path = str(definition.module_path) + definition_path = definition.module_path + assert definition_path is not None # The definition is part of this project and not defined within the original function if ( - definition_path.startswith(project_root_path + os.sep) + str(definition_path).startswith(str(project_root_path) + os.sep) and not path_belongs_to_site_packages(definition_path) and definition.full_name and not belongs_to_function(definition, function_name) @@ -255,7 +260,7 @@ def get_function_variables_definitions( for source in sources: if (fully_qualified_name := source.fully_qualified_name) not in existing_fully_qualified_names: if not source.qualified_name.count("."): - no_parent_sources[source.file_path][source.qualified_name].add(source) + no_parent_sources[str(source.file_path)][source.qualified_name].add(source) else: parent_sources.add(source) existing_fully_qualified_names.add(fully_qualified_name) @@ -263,7 +268,7 @@ def get_function_variables_definitions( source for source in parent_sources if source.file_path not in no_parent_sources - or source.qualified_name.rpartition(".")[0] not in no_parent_sources[source.file_path] + or source.qualified_name.rpartition(".")[0] not in no_parent_sources[str(source.file_path)] ] deduped_no_parent_sources = [ source @@ -279,7 +284,7 @@ MAX_PROMPT_TOKENS = 4096 # 128000 # gpt-4-128k def get_constrained_function_context_and_helper_functions( function_to_optimize: FunctionToOptimize, - project_root_path: str, + project_root_path: Path, code_to_optimize: str, max_tokens: int = MAX_PROMPT_TOKENS, ) -> tuple[str, list[FunctionSource], set[tuple[str, str]]]: diff --git a/codeflash/optimization/optimizer.py b/codeflash/optimization/optimizer.py index 30b7ced0a..76c3107e8 100644 --- a/codeflash/optimization/optimizer.py +++ b/codeflash/optimization/optimizer.py @@ -2,11 +2,11 @@ from __future__ import annotations import concurrent.futures import os -import pathlib import subprocess import time import uuid from collections import defaultdict +from pathlib import Path from typing import TYPE_CHECKING import isort @@ -623,7 +623,7 @@ class Optimizer: def replace_function_and_helpers_with_optimized_code( self, code_context: CodeOptimizationContext, - function_to_optimize_file_path: str, + function_to_optimize_file_path: Path, optimized_code: str, qualified_function_name: str, ) -> bool: @@ -660,7 +660,7 @@ class Optimizer: def get_code_optimization_context( self, function_to_optimize: FunctionToOptimize, - project_root: str, + project_root: Path, original_source_code: str, ) -> Result[CodeOptimizationContext, str]: code_to_optimize, contextual_dunder_methods = extract_code( @@ -687,7 +687,7 @@ class Optimizer: optimizable_methods = [ FunctionToOptimize( df.qualified_name.split(".")[-1], - df.file_path, + Path(df.file_path), [FunctionParent(df.qualified_name.split(".")[0], "ClassDef")], None, None, @@ -730,18 +730,14 @@ class Optimizer: @staticmethod def cleanup_leftover_test_return_values() -> None: # remove leftovers from previous run - pathlib.Path(get_run_tmp_file("test_return_values_0.bin")).unlink( - missing_ok=True, - ) - pathlib.Path(get_run_tmp_file("test_return_values_0.sqlite")).unlink( - missing_ok=True, - ) + get_run_tmp_file(Path("test_return_values_0.bin")).unlink(missing_ok=True) + get_run_tmp_file(Path("test_return_values_0.sqlite")).unlink(missing_ok=True) def instrument_existing_tests( self, function_to_optimize: FunctionToOptimize, function_to_tests: dict[str, list[TestsInFile]], - ) -> set[str]: + ) -> set[Path]: relevant_test_files_count = 0 unique_instrumented_test_files = set() @@ -767,10 +763,8 @@ class Optimizer: ) if not success: continue - new_test_path = ( - f"{os.path.splitext(test_file)[0]}__perfinstrumented{os.path.splitext(test_file)[1]}" - ) - with pathlib.Path(new_test_path).open("w", encoding="utf8") as f: + new_test_path = Path(test_file).with_suffix(f"__perfinstrumented{Path(test_file).suffix}") + with new_test_path.open("w", encoding="utf8") as f: f.write(injected_test) unique_instrumented_test_files.add(new_test_path) if not self.test_files.get_by_original_file_path(test_file): @@ -793,7 +787,7 @@ class Optimizer: code_to_optimize_with_helpers: str, function_to_optimize: FunctionToOptimize, helper_functions: list[FunctionSource], - module_path: str, + module_path: Path, function_trace_id: str, run_experiment: bool = False, ) -> Result[tuple[GeneratedTestsList, OptimizationSet], str]: @@ -995,12 +989,12 @@ class Optimizer: first_test_types = [] first_test_functions = [] - pathlib.Path( - get_run_tmp_file(f"test_return_values_{optimization_candidate_index}.bin"), - ).unlink(missing_ok=True) - pathlib.Path( - get_run_tmp_file(f"test_return_values_{optimization_candidate_index}.sqlite"), - ).unlink(missing_ok=True) + Path(get_run_tmp_file(f"test_return_values_{optimization_candidate_index}.bin")).unlink( + missing_ok=True, + ) + Path(get_run_tmp_file(f"test_return_values_{optimization_candidate_index}.sqlite")).unlink( + missing_ok=True, + ) for test_file in instrumented_unittests_created_for_function: relevant_tests_in_file = [ @@ -1069,12 +1063,10 @@ class Optimizer: ) if best_runtime_until_now is None or total_candidate_timing < best_runtime_until_now: best_test_results = candidate_results - pathlib.Path(get_run_tmp_file(f"test_return_values_{optimization_candidate_index}.bin")).unlink( + Path(get_run_tmp_file(f"test_return_values_{optimization_candidate_index}.bin")).unlink( missing_ok=True, ) - pathlib.Path( - get_run_tmp_file(f"test_return_values_{optimization_candidate_index}.sqlite"), - ).unlink( + Path(get_run_tmp_file(f"test_return_values_{optimization_candidate_index}.sqlite")).unlink( missing_ok=True, ) if not equal_results: @@ -1142,7 +1134,7 @@ class Optimizer: source_code_being_tested: str, function_to_optimize: FunctionToOptimize, helper_function_names: list[str], - module_path: str, + module_path: Path, function_trace_id: str, ) -> GeneratedTestsList | None: futures = [ diff --git a/codeflash/tracer.py b/codeflash/tracer.py index d9d6090c9..616db3916 100644 --- a/codeflash/tracer.py +++ b/codeflash/tracer.py @@ -9,6 +9,8 @@ # Licensed under the Apache License, Version 2.0 (the "License"). # http://www.apache.org/licenses/LICENSE-2.0 # +from __future__ import annotations + import importlib.machinery import io import json @@ -23,20 +25,21 @@ import time from collections import defaultdict from copy import copy from io import StringIO +from pathlib import Path from types import FrameType -from typing import Any, List, Optional +from typing import Any, ClassVar, List, Optional import dill import isort from codeflash.cli_cmds.cli import project_root_from_module_root +from codeflash.cli_cmds.console import console from codeflash.code_utils.code_utils import module_name_from_file_path from codeflash.code_utils.config_parser import parse_config_file from codeflash.discovery.functions_to_optimize import filter_files_optimized from codeflash.tracing.replay_test import create_trace_replay_test from codeflash.tracing.tracing_utils import FunctionModules from codeflash.verification.verification_utils import get_test_file_path -from codeflash.cli_cmds.console import console class Tracer: @@ -49,7 +52,7 @@ class Tracer: output: str = "codeflash.trace", functions: Optional[List[str]] = None, disable: bool = False, - config_file_path: Optional[str] = None, + config_file_path: Path | None = None, max_function_count: int = 256, timeout: Optional[int] = None, # seconds ) -> None: @@ -77,13 +80,14 @@ class Tracer: self.disable = True return self.con = None - self.output_file = os.path.abspath(output) + self.output_file = Path(output).resolve() self.functions = functions self.function_modules: List[FunctionModules] = [] self.function_count = defaultdict(int) + self.current_file_path = Path(__file__).resolve() self.ignored_qualified_functions = { - f"{os.path.realpath(__file__)}:Tracer:__exit__", - f"{os.path.realpath(__file__)}:Tracer:__enter__", + f"{self.current_file_path}:Tracer:__exit__", + f"{self.current_file_path}:Tracer:__enter__", } self.max_function_count = max_function_count self.config, found_config_path = parse_config_file(config_file_path) @@ -99,13 +103,9 @@ class Tracer: "", "", } - self.file_being_called_from: str = str( - os.path.basename( - os.path.realpath(sys._getframe().f_back.f_code.co_filename), - ).replace( - ".", - "_", - ), + + self.file_being_called_from: str = str(Path(sys._getframe().f_back.f_code.co_filename).name).replace( + ".", "_" ) assert timeout is None or timeout > 0, "Timeout should be greater than 0" @@ -175,7 +175,7 @@ class Tracer: cur.execute( "INSERT INTO pstats VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)", ( - os.path.realpath(func[0]), + Path(func[0]).resolve(), func[1], func[2], func[3], @@ -245,7 +245,7 @@ class Tracer: if code.co_name in self.ignored_functions: return - if not os.path.exists(file_name): + if not Path(file_name).exists(): return if self.functions: if code.co_name not in self.functions: @@ -264,7 +264,7 @@ class Tracer: except: # someone can override the getattr method and raise an exception. I'm looking at you wrapt return - file_name = os.path.realpath(file_name) + file_name = Path(file_name).resolve() function_qualified_name = f"{file_name}:{(class_name + ':' if class_name else '')}{code.co_name}" if function_qualified_name in self.ignored_qualified_functions: return @@ -311,7 +311,7 @@ class Tracer: # We do not pickle self for __init__ to avoid recursion errors, and instead instantiate its class # directly with the rest of the arguments in the replay tests. We copy the arguments to avoid memory # leaks, bad references or side-effects when unpickling. - arguments = {k: v for k, v in arguments.items()} + arguments = dict(arguments.items()) if class_name and code.co_name == "__init__": del arguments["self"] local_vars = pickle.dumps( @@ -463,7 +463,7 @@ class Tracer: return 1 - dispatch = { + dispatch: ClassVar[dict[str, callable]] = { "call": trace_dispatch_call, "exception": trace_dispatch_exception, "return": trace_dispatch_return, @@ -648,7 +648,7 @@ def main(): # The script that we're profiling may chdir, so capture the absolute path # to the output file at startup. if args.outfile is not None: - args.outfile = os.path.abspath(args.outfile) + args.outfile = Path(args.outfile).resolve() if len(unknown_args) > 0: if args.module: @@ -661,7 +661,7 @@ def main(): } else: progname = unknown_args[0] - sys.path.insert(0, os.path.dirname(progname)) + sys.path.insert(0, str(Path(progname).parent)) with io.open_code(progname) as fp: code = compile(fp.read(), progname, "exec") spec = importlib.machinery.ModuleSpec(name="__main__", loader=None, origin=progname) diff --git a/tests/test_add_needed_imports_from_module.py b/tests/test_add_needed_imports_from_module.py index 4eb19c77a..194afdd4d 100644 --- a/tests/test_add_needed_imports_from_module.py +++ b/tests/test_add_needed_imports_from_module.py @@ -1,4 +1,6 @@ -from codeflash.code_utils.code_extractor import add_needed_imports_from_module +from pathlib import Path + +from codeflash.code_utils.code_extractor import add_needed_imports_from_module def test_add_needed_imports_from_module0() -> None: @@ -46,9 +48,9 @@ class Source: expected = """def heyjude() -> None: print("Hey Jude, don't make it bad") """ - src_path = "/home/roger/repos/codeflash/cli/codeflash/optimization/function_context.py" - dst_path = "/home/roger/repos/codeflash/cli/codeflash/optimization/function_context.py" - project_root = "/home/roger/repos/codeflash" + src_path = Path("/home/roger/repos/codeflash/cli/codeflash/optimization/function_context.py") + dst_path = Path("/home/roger/repos/codeflash/cli/codeflash/optimization/function_context.py") + project_root = Path("/home/roger/repos/codeflash") new_module = add_needed_imports_from_module( src_module, dst_module, @@ -120,9 +122,9 @@ def belongs_to_function(name: Name, function_name: str) -> bool: # The name is defined inside the function or is the function itself return f".{function_name}." in subname or f".{function_name}" == subname ''' - src_path = "/home/roger/repos/codeflash/cli/codeflash/optimization/function_context.py" - dst_path = "/home/roger/repos/codeflash/cli/codeflash/optimization/function_context.py" - project_root = "/home/roger/repos/codeflash" + src_path = Path("/home/roger/repos/codeflash/cli/codeflash/optimization/function_context.py") + dst_path = Path("/home/roger/repos/codeflash/cli/codeflash/optimization/function_context.py") + project_root = Path("/home/roger/repos/codeflash") new_module = add_needed_imports_from_module( src_module, dst_module, diff --git a/tests/test_code_replacement.py b/tests/test_code_replacement.py index fc7910e70..6c867cc8a 100644 --- a/tests/test_code_replacement.py +++ b/tests/test_code_replacement.py @@ -25,7 +25,7 @@ class JediDefinition: @dataclasses.dataclass class FakeFunctionSource: - file_path: str + file_path: Path qualified_name: str fully_qualified_name: str only_function_name: str @@ -81,11 +81,11 @@ print("Hello world") source_code=original_code, function_names=[function_name], optimized_code=optim_code, - file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()), - module_abspath=str(Path(__file__).resolve()), + file_path_of_module_with_function_to_optimize=Path(__file__).resolve(), + module_abspath=Path(__file__).resolve(), preexisting_objects=preexisting_objects, contextual_functions=contextual_functions, - project_root_path=str(Path(__file__).resolve().parent.resolve()), + project_root_path=Path(__file__).resolve().parent.resolve(), ) assert new_code == expected @@ -145,11 +145,11 @@ print("Hello world") source_code=original_code, function_names=[function_name], optimized_code=optim_code, - file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()), - module_abspath=str(Path(__file__).resolve()), + file_path_of_module_with_function_to_optimize=Path(__file__).resolve(), + module_abspath=Path(__file__).resolve(), preexisting_objects=preexisting_objects, contextual_functions=contextual_functions, - project_root_path=str(Path(__file__).resolve().parent.resolve()), + project_root_path=Path(__file__).resolve().parent.resolve(), ) assert new_code == expected @@ -206,11 +206,11 @@ print("Salut monde") source_code=original_code, function_names=function_names, optimized_code=optim_code, - file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()), - module_abspath=str(Path(__file__).resolve()), + file_path_of_module_with_function_to_optimize=Path(__file__).resolve(), + module_abspath=Path(__file__).resolve(), preexisting_objects=preexisting_objects, contextual_functions=contextual_functions, - project_root_path=str(Path(__file__).resolve().parent.resolve()), + project_root_path=Path(__file__).resolve().parent.resolve(), ) assert new_code == expected @@ -270,11 +270,11 @@ print("Salut monde") source_code=original_code, function_names=function_names, optimized_code=optim_code, - file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()), - module_abspath=str(Path(__file__).resolve()), + file_path_of_module_with_function_to_optimize=Path(__file__).resolve(), + module_abspath=Path(__file__).resolve(), preexisting_objects=preexisting_objects, contextual_functions=contextual_functions, - project_root_path=str(Path(__file__).resolve().parent.resolve()), + project_root_path=Path(__file__).resolve().parent.resolve(), ) assert new_code == expected @@ -326,11 +326,11 @@ def supersort(doink): source_code=original_code, function_names=function_names, optimized_code=optim_code, - file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()), - module_abspath=str(Path(__file__).resolve()), + file_path_of_module_with_function_to_optimize=Path(__file__).resolve(), + module_abspath=Path(__file__).resolve(), preexisting_objects=preexisting_objects, contextual_functions=contextual_functions, - project_root_path=str(Path(__file__).resolve().parent.resolve()), + project_root_path=Path(__file__).resolve().parent.resolve(), ) assert new_code == expected @@ -402,11 +402,11 @@ print("Not cool") source_code=original_code_main, function_names=["other_function"], optimized_code=optim_code, - file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()), - module_abspath=str(Path(__file__).resolve()), + file_path_of_module_with_function_to_optimize=Path(__file__).resolve(), + module_abspath=Path(__file__).resolve(), preexisting_objects=[("other_function", []), ("yet_another_function", []), ("blob", [])], contextual_functions=set(), - project_root_path=str(Path(__file__).resolve().parent.resolve()), + project_root_path=Path(__file__).resolve().parent.resolve(), ) assert new_main_code == expected_main @@ -414,11 +414,11 @@ print("Not cool") source_code=original_code_helper, function_names=["blob"], optimized_code=optim_code, - file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()), - module_abspath=str(Path(__file__).resolve()), + file_path_of_module_with_function_to_optimize=Path(__file__).resolve(), + module_abspath=Path(__file__).resolve(), preexisting_objects=[], contextual_functions=set(), - project_root_path=str(Path(__file__).resolve().parent.resolve()), + project_root_path=Path(__file__).resolve().parent.resolve(), ) assert new_helper_code == expected_helper @@ -619,11 +619,11 @@ class CacheConfig(BaseConfig): source_code=original_code, function_names=function_names, optimized_code=optim_code, - file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()), - module_abspath=str(Path(__file__).resolve()), + file_path_of_module_with_function_to_optimize=Path(__file__).resolve(), + module_abspath=Path(__file__).resolve(), preexisting_objects=preexisting_objects, contextual_functions=contextual_functions, - project_root_path=str(Path(__file__).resolve().parent.resolve()), + project_root_path=Path(__file__).resolve().parent.resolve(), ) assert new_code == expected @@ -694,11 +694,11 @@ def test_test_libcst_code_replacement8() -> None: source_code=original_code, function_names=function_names, optimized_code=optim_code, - file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()), - module_abspath=str(Path(__file__).resolve()), + file_path_of_module_with_function_to_optimize=Path(__file__).resolve(), + module_abspath=Path(__file__).resolve(), preexisting_objects=preexisting_objects, contextual_functions=contextual_functions, - project_root_path=str(Path(__file__).resolve().parent.resolve()), + project_root_path=Path(__file__).resolve().parent.resolve(), ) assert new_code == expected @@ -757,11 +757,11 @@ print("Hello world") source_code=original_code, function_names=[function_name], optimized_code=optim_code, - file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()), - module_abspath=str(Path(__file__).resolve()), + file_path_of_module_with_function_to_optimize=Path(__file__).resolve(), + module_abspath=Path(__file__).resolve(), preexisting_objects=preexisting_objects, contextual_functions=contextual_functions, - project_root_path=str(Path(__file__).resolve().parent.resolve()), + project_root_path=Path(__file__).resolve().parent.resolve(), ) assert new_code == expected @@ -817,17 +817,16 @@ class MainClass: ) func_top_optimize = FunctionToOptimize( function_name="main_method", - file_path=str(file_path), + file_path=file_path, parents=[FunctionParent("MainClass", "ClassDef")], ) - with open(file_path) as f: - original_code = f.read() - code_context = opt.get_code_optimization_context( - function_to_optimize=func_top_optimize, - project_root=str(file_path.parent), - original_source_code=original_code, - ).unwrap() - assert code_context.code_to_optimize_with_helpers == get_code_output + original_code = file_path.read_text() + code_context = opt.get_code_optimization_context( + function_to_optimize=func_top_optimize, + project_root=file_path.parent, + original_source_code=original_code, + ).unwrap() + assert code_context.code_to_optimize_with_helpers == get_code_output def test_code_replacement11() -> None: @@ -946,11 +945,11 @@ def test_test_libcst_code_replacement13() -> None: source_code=original_code, function_names=function_names, optimized_code=optim_code, - file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()), - module_abspath=str(Path(__file__).resolve()), + file_path_of_module_with_function_to_optimize=Path(__file__).resolve(), + module_abspath=Path(__file__).resolve(), preexisting_objects=preexisting_objects, contextual_functions=contextual_functions, - project_root_path=str(Path(__file__).resolve().parent.resolve()), + project_root_path=Path(__file__).resolve().parent.resolve(), ) assert new_code == original_code @@ -1143,7 +1142,9 @@ class TestResults(BaseModel): helper_functions = [ FakeFunctionSource( - file_path="/Users/saurabh/Library/CloudStorage/Dropbox/codeflash/cli/codeflash/verification/test_results.py", + file_path=Path( + "/Users/saurabh/Library/CloudStorage/Dropbox/codeflash/cli/codeflash/verification/test_results.py" + ), qualified_name="TestType", fully_qualified_name="codeflash.verification.test_results.TestType", only_function_name="TestType", @@ -1156,11 +1157,11 @@ class TestResults(BaseModel): source_code=original_code, function_names=["TestResults.get_test_pass_fail_report_by_type"], optimized_code=optim_code, - file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()), - module_abspath=str(Path(__file__).resolve()), + file_path_of_module_with_function_to_optimize=Path(__file__).resolve(), + module_abspath=Path(__file__).resolve(), preexisting_objects=preexisting_objects, contextual_functions=contextual_functions, - project_root_path=str(Path(__file__).parent.resolve()), + project_root_path=Path(__file__).parent.resolve(), ) helper_functions_by_module_abspath = defaultdict(set) @@ -1177,11 +1178,11 @@ class TestResults(BaseModel): source_code=new_code, function_names=list(qualified_names), optimized_code=optim_code, - file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()), + file_path_of_module_with_function_to_optimize=Path(__file__).resolve(), module_abspath=module_abspath, preexisting_objects=preexisting_objects, contextual_functions=contextual_functions, - project_root_path=str(Path(__file__).parent.resolve()), + project_root_path=Path(__file__).parent.resolve(), ) assert ( @@ -1361,12 +1362,16 @@ def cosine_similarity_top_k( return ret_idxs, scores ''' - preexisting_objects = [("cosine_similarity_top_k", []), ("Matrix", []), ("cosine_similarity", [])] + preexisting_objects: list[tuple[str, list[FunctionParent]]] = [ + ("cosine_similarity_top_k", []), + ("Matrix", []), + ("cosine_similarity", []), + ] - contextual_functions = set() + contextual_functions: set[tuple[str, str]] = set() helper_functions = [ FakeFunctionSource( - file_path=str((Path(__file__).parent / "code_to_optimize" / "math_utils.py").resolve()), + file_path=(Path(__file__).parent / "code_to_optimize" / "math_utils.py").resolve(), qualified_name="Matrix", fully_qualified_name="code_to_optimize.math_utils.Matrix", only_function_name="Matrix", @@ -1374,7 +1379,7 @@ def cosine_similarity_top_k( jedi_definition=JediDefinition(type="class"), ), FakeFunctionSource( - file_path=str((Path(__file__).parent / "code_to_optimize" / "math_utils.py").resolve()), + file_path=(Path(__file__).parent / "code_to_optimize" / "math_utils.py").resolve(), qualified_name="cosine_similarity", fully_qualified_name="code_to_optimize.math_utils.cosine_similarity", only_function_name="cosine_similarity", @@ -1387,11 +1392,11 @@ def cosine_similarity_top_k( source_code=original_code, function_names=["cosine_similarity_top_k"], optimized_code=optim_code, - file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()), - module_abspath=str((Path(__file__).parent / "code_to_optimize").resolve()), + file_path_of_module_with_function_to_optimize=Path(__file__).resolve(), + module_abspath=(Path(__file__).parent / "code_to_optimize").resolve(), preexisting_objects=preexisting_objects, contextual_functions=contextual_functions, - project_root_path=str(Path(__file__).parent.parent.resolve()), + project_root_path=Path(__file__).parent.parent.resolve(), ) assert ( new_code @@ -1452,11 +1457,11 @@ def cosine_similarity_top_k( source_code=new_code, function_names=list(qualified_names), optimized_code=optim_code, - file_path_of_module_with_function_to_optimize=str(Path(__file__).resolve()), + file_path_of_module_with_function_to_optimize=Path(__file__).resolve(), module_abspath=module_abspath, preexisting_objects=preexisting_objects, contextual_functions=contextual_functions, - project_root_path=str(Path(__file__).parent.parent.resolve()), + project_root_path=Path(__file__).parent.parent.resolve(), ) assert ( diff --git a/tests/test_code_utils.py b/tests/test_code_utils.py new file mode 100644 index 000000000..8d1f66241 --- /dev/null +++ b/tests/test_code_utils.py @@ -0,0 +1,92 @@ +import ast +from pathlib import Path + +import pytest + +from codeflash.code_utils.code_utils import get_imports_from_file, module_name_from_file_path + + +# tests for module_name_from_file_path +def test_module_name_from_file_path() -> None: + project_root_path = Path("/Users/codeflashuser/PycharmProjects/codeflash") + file_path = project_root_path / "cli/codeflash/code_utils/code_utils.py" + + module_name = module_name_from_file_path(file_path, project_root_path) + assert module_name == "cli.codeflash.code_utils.code_utils" + + +def test_module_name_from_file_path_with_subdirectory() -> None: + project_root_path = Path("/Users/codeflashuser/PycharmProjects/codeflash") + file_path = project_root_path / "cli/codeflash/code_utils/subdir/code_utils.py" + + module_name = module_name_from_file_path(file_path, project_root_path) + assert module_name == "cli.codeflash.code_utils.subdir.code_utils" + + +def test_module_name_from_file_path_with_different_root() -> None: + project_root_path = Path("/Users/codeflashuser/PycharmProjects") + file_path = project_root_path / "codeflash/cli/codeflash/code_utils/code_utils.py" + + module_name = module_name_from_file_path(file_path, project_root_path) + assert module_name == "codeflash.cli.codeflash.code_utils.code_utils" + + +def test_module_name_from_file_path_with_root_as_file() -> None: + project_root_path = Path("/Users/codeflashuser/PycharmProjects/codeflash/cli/codeflash/code_utils") + file_path = project_root_path / "code_utils.py" + + module_name = module_name_from_file_path(file_path, project_root_path) + assert module_name == "code_utils" + + +# def test_get_imports_from_file_with_file_path(tmp_path: Path): +# test_file = tmp_path / "test_file.py" +# test_file.write_text("import os\nfrom sys import path\n") + +# imports = get_imports_from_file(file_path=test_file) +# assert len(imports) == 2 +# assert isinstance(imports[0], ast.Import) +# assert isinstance(imports[1], ast.ImportFrom) +# assert imports[0].names[0].name == "os" +# assert imports[1].module == "sys" +# assert imports[1].names[0].name == "path" + + +# def test_get_imports_from_file_with_file_string(): +# file_string = "import os\nfrom sys import path\n" + +# imports = get_imports_from_file(file_string=file_string) +# assert len(imports) == 2 +# assert isinstance(imports[0], ast.Import) +# assert isinstance(imports[1], ast.ImportFrom) +# assert imports[0].names[0].name == "os" +# assert imports[1].module == "sys" +# assert imports[1].names[0].name == "path" + + +# def test_get_imports_from_file_with_file_ast(): +# file_string = "import os\nfrom sys import path\n" +# file_ast = ast.parse(file_string) + +# imports = get_imports_from_file(file_ast=file_ast) +# assert len(imports) == 2 +# assert isinstance(imports[0], ast.Import) +# assert isinstance(imports[1], ast.ImportFrom) +# assert imports[0].names[0].name == "os" +# assert imports[1].module == "sys" +# assert imports[1].names[0].name == "path" + + +# def test_get_imports_from_file_with_syntax_error(caplog): +# file_string = "import os\nfrom sys import path\ninvalid syntax" + +# imports = get_imports_from_file(file_string=file_string) +# assert len(imports) == 0 +# assert "Syntax error in code" in caplog.text + + +# def test_get_imports_from_file_with_no_input(): +# with pytest.raises( +# AssertionError, match="Must provide exactly one of file_path, file_string, or file_ast" +# ): +# get_imports_from_file() diff --git a/tests/test_formatter.py b/tests/test_formatter.py index e10553fb5..e27ec724b 100644 --- a/tests/test_formatter.py +++ b/tests/test_formatter.py @@ -1,7 +1,9 @@ import os import tempfile +from pathlib import Path import pytest + from codeflash.code_utils.config_parser import parse_config_file from codeflash.code_utils.formatter import format_code, sort_imports @@ -34,12 +36,13 @@ def test_sort_imports_without_formatting(): with tempfile.NamedTemporaryFile() as tmp: tmp.write(b"import sys\nimport unittest\nimport os\n") tmp.flush() - tmp_path = tmp.name + tmp_path = Path(tmp.name) new_code = format_code( formatter_cmds=["disabled"], path=tmp_path, ) + assert new_code is not None new_code = sort_imports(new_code) assert new_code == "import os\nimport sys\nimport unittest\n" @@ -108,7 +111,7 @@ ignore-paths = [] with tempfile.NamedTemporaryFile(suffix=".toml", delete=False) as tmp: tmp.write(config_data.encode()) tmp.flush() - tmp_path = tmp.name + tmp_path = Path(tmp.name) try: config, _ = parse_config_file(tmp_path) @@ -135,7 +138,7 @@ def foo(): actual = format_code( formatter_cmds=["black $file"], - path=tmp_path, + path=Path(tmp_path), ) assert actual == expected @@ -160,7 +163,7 @@ def foo(): actual = format_code( formatter_cmds=["black $file"], - path=tmp_path, + path=Path(tmp_path), ) assert actual == expected @@ -189,6 +192,6 @@ def foo(): actual = format_code( formatter_cmds=["ruff check --exit-zero --fix $file", "ruff format $file"], - path=tmp_path, + path=Path(tmp_path), ) assert actual == expected