From cd4db2291ac1921faec1ed254a3f83be733b4b64 Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Sat, 12 Oct 2024 20:58:44 -0500 Subject: [PATCH] ruff format --- code_to_optimize/User_post.py | 13 ++- code_to_optimize/book_catalog.py | 5 +- .../final_test_set/tests/test_gradient.py | 3 +- .../tests/test_hamming_distance.py | 3 +- .../final_test_set/tests/test_integration.py | 3 +- .../tests/test_matrix_multiplication.py | 3 +- .../tests/test_standardize_name.py | 3 +- .../test_bubble_sort_parametrized_loop.py | 3 +- .../unittest/test_bubble_sort_parametrized.py | 1 + codeflash/api/aiservice.py | 18 +++-- codeflash/api/cfapi.py | 4 +- codeflash/code_utils/env_utils.py | 2 +- codeflash/code_utils/github_utils.py | 2 +- codeflash/main.py | 3 +- codeflash/optimization/optimizer.py | 24 +++--- codeflash/result/create_pr.py | 15 ++-- codeflash/telemetry/posthog.py | 2 +- codeflash/tracing/profile_stats.py | 9 ++- codeflash/verification/comparator.py | 3 +- codeflash/verification/test_runner.py | 11 +-- codeflash/verification/verifier.py | 28 ++++--- pie_test_set/p02548.py | 11 ++- pie_test_set/p02624.py | 1 - pie_test_set/p02639.py | 6 +- pie_test_set/p02660.py | 16 ++-- pie_test_set/p02696.py | 3 +- pie_test_set/p02738.py | 1 - pie_test_set/p02782.py | 12 +-- pie_test_set/p02783.py | 2 - pie_test_set/p02786.py | 3 +- pie_test_set/p02840.py | 1 - pie_test_set/p02900.py | 2 +- pie_test_set/p02954.py | 10 ++- pie_test_set/p02957.py | 15 +++- pie_test_set/p02965.py | 17 ++-- pie_test_set/p02969.py | 3 +- pie_test_set/p02993.py | 3 +- pie_test_set/p03016.py | 25 ++---- pie_test_set/p03088.py | 3 +- pie_test_set/p03206.py | 3 +- pie_test_set/p03213.py | 3 +- pie_test_set/p03253.py | 3 +- pie_test_set/p03286.py | 3 +- pie_test_set/p03315.py | 11 +-- pie_test_set/p03502.py | 1 - pie_test_set/p03632.py | 3 +- pie_test_set/p03666.py | 9 ++- pie_test_set/p03797.py | 25 ++---- pie_test_set/p03999.py | 2 +- pie_test_set/p04019.py | 6 +- pie_test_set/p04040.py | 3 +- .../scripts/create_files_and_tests.py | 59 +++++++------- pie_test_set/scripts/run_pie_test_case.py | 8 +- tests/test_function_discovery.py | 5 +- tests/test_shell_utils.py | 6 +- tests/test_test_runner.py | 22 +++--- tests/test_unit_test_discovery.py | 79 ++++++++++--------- 57 files changed, 276 insertions(+), 267 deletions(-) diff --git a/code_to_optimize/User_post.py b/code_to_optimize/User_post.py index d310e0317..8f303bc88 100644 --- a/code_to_optimize/User_post.py +++ b/code_to_optimize/User_post.py @@ -1,9 +1,16 @@ from __future__ import annotations -from sqlalchemy import create_engine, Integer, String, ForeignKey +from sqlalchemy import ForeignKey, Integer, String, create_engine from sqlalchemy.engine.base import Engine -from sqlalchemy.orm import Mapped, mapped_column -from sqlalchemy.orm import sessionmaker, relationship, Relationship, Session, DeclarativeBase +from sqlalchemy.orm import ( + DeclarativeBase, + Mapped, + Relationship, + Session, + mapped_column, + relationship, + sessionmaker, +) # Custom base class diff --git a/code_to_optimize/book_catalog.py b/code_to_optimize/book_catalog.py index d81922ade..f4a206c17 100644 --- a/code_to_optimize/book_catalog.py +++ b/code_to_optimize/book_catalog.py @@ -1,10 +1,7 @@ +from time import time from typing import List from sqlalchemy import Boolean, Column, ForeignKey, Integer, Text, func -from sqlalchemy.orm import Session, relationship -from time import time - -from sqlalchemy import Boolean, Column, ForeignKey, Integer, Text from sqlalchemy.engine import Engine, create_engine from sqlalchemy.orm import DeclarativeBase, Session, relationship, sessionmaker from sqlalchemy.orm.relationships import Relationship diff --git a/code_to_optimize/final_test_set/tests/test_gradient.py b/code_to_optimize/final_test_set/tests/test_gradient.py index 218b8df06..1d142ee92 100644 --- a/code_to_optimize/final_test_set/tests/test_gradient.py +++ b/code_to_optimize/final_test_set/tests/test_gradient.py @@ -1,6 +1,7 @@ -from code_to_optimize.final_test_set.gradient import gradient import numpy as np +from code_to_optimize.final_test_set.gradient import gradient + def test_simple_case(): # Test case with simple values diff --git a/code_to_optimize/final_test_set/tests/test_hamming_distance.py b/code_to_optimize/final_test_set/tests/test_hamming_distance.py index 49bc1b038..6edb42a1f 100644 --- a/code_to_optimize/final_test_set/tests/test_hamming_distance.py +++ b/code_to_optimize/final_test_set/tests/test_hamming_distance.py @@ -1,6 +1,7 @@ -from code_to_optimize.final_test_set.hamming_distance import _hamming_distance import numpy as np +from code_to_optimize.final_test_set.hamming_distance import _hamming_distance + def test_no_differences(): a = np.array([1, 2, 3, 4]) diff --git a/code_to_optimize/final_test_set/tests/test_integration.py b/code_to_optimize/final_test_set/tests/test_integration.py index 9843005a6..c1bfeb492 100644 --- a/code_to_optimize/final_test_set/tests/test_integration.py +++ b/code_to_optimize/final_test_set/tests/test_integration.py @@ -1,6 +1,7 @@ -from code_to_optimize.final_test_set.integration import integrate_f import pytest +from code_to_optimize.final_test_set.integration import integrate_f + def isclose(a, b, rel_tol=1e-5, abs_tol=0.0): """ diff --git a/code_to_optimize/final_test_set/tests/test_matrix_multiplication.py b/code_to_optimize/final_test_set/tests/test_matrix_multiplication.py index fadaf29b1..01c94895d 100644 --- a/code_to_optimize/final_test_set/tests/test_matrix_multiplication.py +++ b/code_to_optimize/final_test_set/tests/test_matrix_multiplication.py @@ -1,6 +1,7 @@ -from code_to_optimize.final_test_set.matrix_multiplication import matrix_multiply import pytest +from code_to_optimize.final_test_set.matrix_multiplication import matrix_multiply + def test_matrix_multiplication_basic(): A = [[1, 2], [3, 4]] diff --git a/code_to_optimize/final_test_set/tests/test_standardize_name.py b/code_to_optimize/final_test_set/tests/test_standardize_name.py index 402b7a1d4..f35b6270c 100644 --- a/code_to_optimize/final_test_set/tests/test_standardize_name.py +++ b/code_to_optimize/final_test_set/tests/test_standardize_name.py @@ -1,6 +1,7 @@ -from code_to_optimize.final_test_set.standardize_name import standardize_name import pytest +from code_to_optimize.final_test_set.standardize_name import standardize_name + def test_exact_match(): assert standardize_name("Brattle St") == "Brattle St" diff --git a/code_to_optimize/tests/pytest/test_bubble_sort_parametrized_loop.py b/code_to_optimize/tests/pytest/test_bubble_sort_parametrized_loop.py index 5f34aa90a..00fa243ab 100644 --- a/code_to_optimize/tests/pytest/test_bubble_sort_parametrized_loop.py +++ b/code_to_optimize/tests/pytest/test_bubble_sort_parametrized_loop.py @@ -1,6 +1,7 @@ -from code_to_optimize.bubble_sort import sorter import pytest +from code_to_optimize.bubble_sort import sorter + @pytest.mark.parametrize( "input, expected_output", diff --git a/code_to_optimize/tests/unittest/test_bubble_sort_parametrized.py b/code_to_optimize/tests/unittest/test_bubble_sort_parametrized.py index 72c88e27c..59c86abc8 100644 --- a/code_to_optimize/tests/unittest/test_bubble_sort_parametrized.py +++ b/code_to_optimize/tests/unittest/test_bubble_sort_parametrized.py @@ -1,4 +1,5 @@ import unittest + from parameterized import parameterized from code_to_optimize.bubble_sort import sorter diff --git a/codeflash/api/aiservice.py b/codeflash/api/aiservice.py index f95eb72a6..f8cba4d26 100644 --- a/codeflash/api/aiservice.py +++ b/codeflash/api/aiservice.py @@ -3,7 +3,7 @@ from __future__ import annotations import json import os import platform -from typing import Any +from typing import TYPE_CHECKING, Any import requests from pydantic.dataclasses import dataclass @@ -11,11 +11,15 @@ from pydantic.json import pydantic_encoder from codeflash.cli_cmds.console import logger from codeflash.code_utils.env_utils import get_codeflash_api_key -from codeflash.discovery.functions_to_optimize import FunctionToOptimize -from codeflash.models.ExperimentMetadata import ExperimentMetadata from codeflash.telemetry.posthog import ph from codeflash.version import __version__ as codeflash_version +if TYPE_CHECKING: + from pathlib import Path + + from codeflash.discovery.functions_to_optimize import FunctionToOptimize + from codeflash.models.ExperimentMetadata import ExperimentMetadata + @dataclass(frozen=True) class OptimizedCandidate: @@ -161,8 +165,8 @@ class AiServiceClient: source_code_being_tested: str, function_to_optimize: FunctionToOptimize, helper_function_names: list[str], - module_path: str, - test_module_path: str, + module_path: Path, + test_module_path: Path, test_framework: str, test_timeout: int, trace_id: str, @@ -175,8 +179,8 @@ class AiServiceClient: - source_code_being_tested (str): The source code of the function being tested. - function_to_optimize (FunctionToOptimize): The function to optimize. - helper_function_names (list[Source]): List of helper function names. - - module_path (str): The module path where the function is located. - - test_module_path (str): The module path for the test code. + - module_path (Path): The module path where the function is located. + - test_module_path (Path): The module path for the test code. - test_framework (str): The test framework to use, e.g., "pytest". - test_timeout (int): The timeout for each test in seconds. - test_index (int): The index from 0-(n-1) if n tests are generated for a single trace_id diff --git a/codeflash/api/cfapi.py b/codeflash/api/cfapi.py index bfa344a08..7c745240a 100644 --- a/codeflash/api/cfapi.py +++ b/codeflash/api/cfapi.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import json import os from functools import lru_cache @@ -8,10 +10,10 @@ import requests from pydantic.json import pydantic_encoder from requests import Response +from codeflash.cli_cmds.console import logger from codeflash.code_utils.env_utils import ensure_codeflash_api_key, get_codeflash_api_key, get_pr_number from codeflash.code_utils.git_utils import get_repo_owner_and_name from codeflash.github.PrComment import FileDiffContent, PrComment -from codeflash.cli_cmds.console import logger if os.environ.get("CODEFLASH_CFAPI_SERVER", default="prod").lower() == "local": CFAPI_BASE_URL = "http://localhost:3001" diff --git a/codeflash/code_utils/env_utils.py b/codeflash/code_utils/env_utils.py index e11a6d266..e31279a30 100644 --- a/codeflash/code_utils/env_utils.py +++ b/codeflash/code_utils/env_utils.py @@ -2,8 +2,8 @@ import os from functools import lru_cache from typing import Optional -from codeflash.code_utils.shell_utils import read_api_key_from_shell_config from codeflash.cli_cmds.console import logger +from codeflash.code_utils.shell_utils import read_api_key_from_shell_config @lru_cache(maxsize=1) diff --git a/codeflash/code_utils/github_utils.py b/codeflash/code_utils/github_utils.py index 50b3f0de7..4df374cc8 100644 --- a/codeflash/code_utils/github_utils.py +++ b/codeflash/code_utils/github_utils.py @@ -1,10 +1,10 @@ -from codeflash.cli_cmds.console import logger from typing import Optional from git import Repo from codeflash.api.cfapi import is_github_app_installed_on_repo from codeflash.cli_cmds.cli_common import apologize_and_exit +from codeflash.cli_cmds.console import logger from codeflash.code_utils.compat import LF from codeflash.code_utils.git_utils import get_repo_owner_and_name diff --git a/codeflash/main.py b/codeflash/main.py index ec0bf178b..c5578e008 100644 --- a/codeflash/main.py +++ b/codeflash/main.py @@ -4,14 +4,13 @@ solved problem, please reach out to us at careers@codeflash.ai. We're hiring! from pathlib import Path - from codeflash.cli_cmds.cli import parse_args, process_pyproject_config from codeflash.cli_cmds.cmd_init import CODEFLASH_LOGO, ask_run_end_to_end_test +from codeflash.cli_cmds.console import paneled_text from codeflash.code_utils.config_parser import parse_config_file from codeflash.optimization import optimizer from codeflash.telemetry import posthog from codeflash.telemetry.sentry import init_sentry -from codeflash.cli_cmds.console import paneled_text def main() -> None: diff --git a/codeflash/optimization/optimizer.py b/codeflash/optimization/optimizer.py index f8402f610..f6ce274c3 100644 --- a/codeflash/optimization/optimizer.py +++ b/codeflash/optimization/optimizer.py @@ -750,9 +750,10 @@ class Optimizer: for tests_in_file in function_to_tests.get(func_qualname): test_file_invocation_positions[tests_in_file.test_file].append(tests_in_file.position) for test_file, positions in test_file_invocation_positions.items(): + path_obj_test_file = Path(test_file) relevant_test_files_count += 1 success, injected_test = inject_profiling_into_existing_test( - test_file, + path_obj_test_file, positions, function_to_optimize, self.args.project_root, @@ -761,15 +762,15 @@ class Optimizer: if not success: continue new_test_path = Path(test_file).with_suffix(f"__perfinstrumented{Path(test_file).suffix}") - with new_test_path.open("w", encoding="utf8") as f: - f.write(injected_test) + with new_test_path.open("w", encoding="utf8") as _f: + _f.write(injected_test) unique_instrumented_test_files.add(new_test_path) - if not self.test_files.get_by_original_file_path(test_file): + if not self.test_files.get_by_original_file_path(path_obj_test_file): self.test_files.add( TestFile( instrumented_file_path=new_test_path, original_source=None, - original_file_path=test_file, + original_file_path=Path(test_file), test_type=TestType.EXISTING_UNIT_TEST, ), ) @@ -986,10 +987,10 @@ class Optimizer: first_test_types = [] first_test_functions = [] - Path(get_run_tmp_file(f"test_return_values_{optimization_candidate_index}.bin")).unlink( + get_run_tmp_file(Path(f"test_return_values_{optimization_candidate_index}.sqlite")).unlink( missing_ok=True, ) - Path(get_run_tmp_file(f"test_return_values_{optimization_candidate_index}.sqlite")).unlink( + get_run_tmp_file(Path(f"test_return_values_{optimization_candidate_index}.sqlite")).unlink( missing_ok=True, ) @@ -1060,10 +1061,11 @@ class Optimizer: ) if best_runtime_until_now is None or total_candidate_timing < best_runtime_until_now: best_test_results = candidate_results - Path(get_run_tmp_file(f"test_return_values_{optimization_candidate_index}.bin")).unlink( + get_run_tmp_file(Path(f"test_return_values_{optimization_candidate_index}.bin")).unlink( missing_ok=True, ) - Path(get_run_tmp_file(f"test_return_values_{optimization_candidate_index}.sqlite")).unlink( + + get_run_tmp_file(Path(f"test_return_values_{optimization_candidate_index}.sqlite")).unlink( missing_ok=True, ) if not equal_results: @@ -1108,12 +1110,12 @@ class Optimizer: ) except subprocess.TimeoutExpired: logger.exception( - f'Error running tests in {", ".join(test_files)}.\nTimeout Error', + f'Error running tests in {", ".join(str(f) for f in test_files.test_files)}.\nTimeout Error', ) return TestResults() if run_result.returncode != 0: logger.debug( - f'Nonzero return code {run_result.returncode} when running tests in {", ".join([f.instrumented_file_path for f in test_files.test_files])}.\n' + f'Nonzero return code {run_result.returncode} when running tests in {", ".join([str(f.instrumented_file_path) for f in test_files.test_files])}.\n' f"stdout: {run_result.stdout}\n" f"stderr: {run_result.stderr}\n", ) diff --git a/codeflash/result/create_pr.py b/codeflash/result/create_pr.py index f91420f08..7462d2968 100644 --- a/codeflash/result/create_pr.py +++ b/codeflash/result/create_pr.py @@ -1,13 +1,12 @@ from __future__ import annotations -from codeflash.cli_cmds.console import logger -import os.path -import pathlib +from pathlib import Path from typing import Dict, Optional import git from codeflash.api import cfapi +from codeflash.cli_cmds.console import logger from codeflash.code_utils import env_utils from codeflash.code_utils.code_replacer import is_zero_diff from codeflash.code_utils.git_utils import ( @@ -30,7 +29,7 @@ def existing_tests_source_for( existing_tests_unique = set() if test_files: for test_file in test_files: - existing_tests_unique.add("- " + os.path.relpath(test_file.test_file, tests_root)) + existing_tests_unique.add("- " + str(Path(test_file.test_file).relative_to(tests_root))) return "\n".join(sorted(existing_tests_unique)) @@ -48,9 +47,9 @@ def check_create_pr( if pr_number is not None: logger.info(f"Suggesting changes to PR #{pr_number} ...") owner, repo = get_repo_owner_and_name(git_repo) - relative_path = str(pathlib.Path(os.path.relpath(explanation.file_path, git_root_dir())).as_posix()) + relative_path = Path(explanation.file_path).relative_to(git_root_dir()).as_posix() build_file_changes = { - str(pathlib.Path(os.path.relpath(p, git_root_dir())).as_posix()): FileDiffContent( + Path(p).relative_to(git_root_dir()).as_posix(): FileDiffContent( oldContent=original_code[p], newContent=new_code[p], ) @@ -93,14 +92,14 @@ def check_create_pr( if not check_and_push_branch(git_repo, wait_for_push=True): logger.warning("⏭️ Branch is not pushed, skipping PR creation...") return - relative_path = str(pathlib.Path(os.path.relpath(explanation.file_path, git_root_dir())).as_posix()) + relative_path = Path(explanation.file_path).relative_to(git_root_dir()).as_posix() base_branch = get_current_branch() response = cfapi.create_pr( owner=owner, repo=repo, base_branch=base_branch, file_changes={ - str(pathlib.Path(os.path.relpath(p, git_root_dir())).as_posix()): FileDiffContent( + Path(p).relative_to(git_root_dir()).as_posix(): FileDiffContent( oldContent=original_code[p], newContent=new_code[p], ) diff --git a/codeflash/telemetry/posthog.py b/codeflash/telemetry/posthog.py index b91ea1833..ee57533c3 100644 --- a/codeflash/telemetry/posthog.py +++ b/codeflash/telemetry/posthog.py @@ -1,10 +1,10 @@ import logging -from codeflash.cli_cmds.console import logger from typing import Any, Dict, Optional from posthog import Posthog from codeflash.api.cfapi import get_user_id +from codeflash.cli_cmds.console import logger from codeflash.version import __version__, __version_tuple__ _posthog = None diff --git a/codeflash/tracing/profile_stats.py b/codeflash/tracing/profile_stats.py index dd67dc43f..41581783f 100644 --- a/codeflash/tracing/profile_stats.py +++ b/codeflash/tracing/profile_stats.py @@ -1,14 +1,15 @@ import json -import os.path import pstats import sqlite3 from copy import copy +from pathlib import Path + from codeflash.cli_cmds.console import logger class ProfileStats(pstats.Stats): def __init__(self, trace_file_path: str, time_unit: str = "ns"): - assert os.path.isfile(trace_file_path), f"Trace file {trace_file_path} does not exist" + assert Path(trace_file_path).is_file(), f"Trace file {trace_file_path} does not exist" assert time_unit in ["ns", "us", "ms", "s"], f"Invalid time unit {time_unit}" self.trace_file_path = trace_file_path self.time_unit = time_unit @@ -72,8 +73,8 @@ class ProfileStats(pstats.Stats): return self -def get_trace_total_run_time_ns(trace_file_path: str) -> int: - if not os.path.isfile(trace_file_path): +def get_trace_total_run_time_ns(trace_file_path: Path) -> int: + if not trace_file_path.is_file(): return 0 con = sqlite3.connect(trace_file_path) cur = con.cursor() diff --git a/codeflash/verification/comparator.py b/codeflash/verification/comparator.py index 353a83283..bce18746b 100644 --- a/codeflash/verification/comparator.py +++ b/codeflash/verification/comparator.py @@ -1,13 +1,14 @@ import datetime import decimal import enum -from codeflash.cli_cmds.console import logger import math import types from typing import Any import sentry_sdk +from codeflash.cli_cmds.console import logger + try: import numpy as np diff --git a/codeflash/verification/test_runner.py b/codeflash/verification/test_runner.py index 6201c1d5f..bade2271d 100644 --- a/codeflash/verification/test_runner.py +++ b/codeflash/verification/test_runner.py @@ -3,6 +3,7 @@ from __future__ import annotations import os import shlex import subprocess +from pathlib import Path from codeflash.code_utils.code_utils import get_run_tmp_file from codeflash.code_utils.config_consts import TOTAL_LOOPING_TIME @@ -13,7 +14,7 @@ from codeflash.verification.test_results import TestType def run_tests( test_paths: TestFiles, test_framework: str, - cwd: str | None = None, + cwd: Path | None = None, test_env: dict[str, str] | None = None, pytest_timeout: int | None = None, pytest_cmd: str = "pytest", @@ -22,7 +23,7 @@ def run_tests( pytest_target_runtime_seconds: float = TOTAL_LOOPING_TIME, pytest_min_loops: int = 5, pytest_max_loops: int = 100_000, -) -> tuple[str, subprocess.CompletedProcess]: +) -> tuple[Path, subprocess.CompletedProcess]: assert test_framework in ["pytest", "unittest"] # TODO: Make this work for replay tests for i, test_file in enumerate(test_paths): @@ -30,10 +31,10 @@ def run_tests( only_run_these_test_functions and test_file.test_type == TestType.REPLAY_TEST ): # "__replay_test" in test_path: # TODO: This might not work for replay tests - test_paths[i] = test_file.instrumented_file_path + "::" + only_run_these_test_functions + test_paths[i] = str(test_file.instrumented_file_path) + "::" + only_run_these_test_functions if test_framework == "pytest": - result_file_path = get_run_tmp_file("pytest_results.xml") + result_file_path = get_run_tmp_file(Path("pytest_results.xml")) pytest_cmd_list = shlex.split(pytest_cmd, posix=os.name != "nt") pytest_test_env = test_env.copy() @@ -62,7 +63,7 @@ def run_tests( check=False, ) elif test_framework == "unittest": - result_file_path = get_run_tmp_file("unittest_results.xml") + result_file_path = get_run_tmp_file(Path("unittest_results.xml")) results = subprocess.run( ["python", "-m", "xmlrunner"] + (["-v"] if verbose else []) diff --git a/codeflash/verification/verifier.py b/codeflash/verification/verifier.py index cd59091bb..16aade07f 100644 --- a/codeflash/verification/verifier.py +++ b/codeflash/verification/verifier.py @@ -1,25 +1,29 @@ from __future__ import annotations import ast +from pathlib import Path +from typing import TYPE_CHECKING -from codeflash.api.aiservice import AiServiceClient from codeflash.cli_cmds.console import logger from codeflash.code_utils.code_utils import get_run_tmp_file, module_name_from_file_path -from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.verification.verification_utils import ( ModifyInspiredTests, - TestConfig, delete_multiple_if_name_main, get_test_file_path, ) +if TYPE_CHECKING: + from codeflash.api.aiservice import AiServiceClient + from codeflash.discovery.functions_to_optimize import FunctionToOptimize + from codeflash.verification.verification_utils import TestConfig + def generate_tests( aiservice_client: AiServiceClient, source_code_being_tested: str, function_to_optimize: FunctionToOptimize, helper_function_names: list[str], - module_path: str, + module_path: Path, test_cfg: TestConfig, test_timeout: int, use_cached_tests: bool, @@ -31,19 +35,22 @@ def generate_tests( if use_cached_tests: import importlib - module = importlib.import_module(module_path) + module = importlib.import_module(str(module_path)) generated_test_source = module.CACHED_TESTS instrumented_test_source = module.CACHED_INSTRUMENTED_TESTS - path = get_run_tmp_file("").replace("\\", "\\\\") # Escape backslash for windows paths + temp_run_dir = get_run_tmp_file(Path("")) + path = str(temp_run_dir).replace("\\", "\\\\") # Escape backslash for windows paths instrumented_test_source = instrumented_test_source.replace( "{codeflash_run_tmp_dir_client_side}", path, ) logger.info(f"Using cached tests from {module_path}.CACHED_TESTS") else: - test_module_path = module_name_from_file_path( - get_test_file_path(test_cfg.tests_root, function_to_optimize.function_name, 0), - test_cfg.project_root_path, + test_module_path = Path( + module_name_from_file_path( + get_test_file_path(test_cfg.tests_root, function_to_optimize.function_name, 0), + test_cfg.project_root_path, + ), ) response = aiservice_client.generate_regression_tests( source_code_being_tested=source_code_being_tested, @@ -58,7 +65,8 @@ def generate_tests( ) if response and isinstance(response, tuple) and len(response) == 2: generated_test_source, instrumented_test_source = response - path = get_run_tmp_file("").replace("\\", "\\\\") # Escape backslash for windows paths + temp_run_dir = get_run_tmp_file(Path("")) + path = str(temp_run_dir).replace("\\", "\\\\") instrumented_test_source = instrumented_test_source.replace( "{codeflash_run_tmp_dir_client_side}", path, diff --git a/pie_test_set/p02548.py b/pie_test_set/p02548.py index 68311769f..126911a07 100644 --- a/pie_test_set/p02548.py +++ b/pie_test_set/p02548.py @@ -1,5 +1,14 @@ def problem_p02548(input_data): - import sys, os, math, bisect, itertools, collections, heapq, queue, copy, array + import array + import bisect + import collections + import copy + import heapq + import itertools + import math + import os + import queue + import sys # from scipy.sparse.csgraph import csgraph_from_dense, floyd_warshall diff --git a/pie_test_set/p02624.py b/pie_test_set/p02624.py index 3656db152..06bfb341f 100644 --- a/pie_test_set/p02624.py +++ b/pie_test_set/p02624.py @@ -46,7 +46,6 @@ def problem_p02624(input_data): if sys.argv[-1] == "ONLINE_JUDGE": import numba - from numba.pycc import CC i8 = numba.int64 diff --git a/pie_test_set/p02639.py b/pie_test_set/p02639.py index 883dc062c..1acf1b9aa 100644 --- a/pie_test_set/p02639.py +++ b/pie_test_set/p02639.py @@ -1,9 +1,7 @@ def problem_p02639(input_data): - from sys import stdin, stdout - - from math import gcd, sqrt - from collections import deque + from math import gcd, sqrt + from sys import stdin, stdout input = stdin.readline diff --git a/pie_test_set/p02660.py b/pie_test_set/p02660.py index e303179a8..5d2091927 100644 --- a/pie_test_set/p02660.py +++ b/pie_test_set/p02660.py @@ -11,19 +11,13 @@ def problem_p02660(input_data): return list(map(int, input_data.split())) - from collections import defaultdict, deque - - from sys import exit - - import math - import copy - - from bisect import bisect_left, bisect_right - - from heapq import * - + import math import sys + from bisect import bisect_left, bisect_right + from collections import defaultdict, deque + from heapq import * + from sys import exit # sys.setrecursionlimit(1000000) diff --git a/pie_test_set/p02696.py b/pie_test_set/p02696.py index 00bde66c4..0eb44af1c 100644 --- a/pie_test_set/p02696.py +++ b/pie_test_set/p02696.py @@ -1,7 +1,6 @@ def problem_p02696(input_data): - from sys import stdin - import sys + from sys import stdin A, B, N = [int(x) for x in stdin.readline().rstrip().split()] diff --git a/pie_test_set/p02738.py b/pie_test_set/p02738.py index 676bb5b2e..e5a92d2cc 100644 --- a/pie_test_set/p02738.py +++ b/pie_test_set/p02738.py @@ -1,6 +1,5 @@ def problem_p02738(input_data): from functools import lru_cache, reduce - from itertools import accumulate N, M = list(map(int, input_data.split())) diff --git a/pie_test_set/p02782.py b/pie_test_set/p02782.py index 4d5fdbef4..30cb16629 100644 --- a/pie_test_set/p02782.py +++ b/pie_test_set/p02782.py @@ -1,16 +1,12 @@ def problem_p02782(input_data): + import collections + import heapq + import sys + from functools import cmp_to_key from sys import stdin - import sys - import numpy as np - import collections - - from functools import cmp_to_key - - import heapq - ## input functions for me def rsa(sep=""): diff --git a/pie_test_set/p02783.py b/pie_test_set/p02783.py index 62f643f80..263b3b8a2 100644 --- a/pie_test_set/p02783.py +++ b/pie_test_set/p02783.py @@ -1,8 +1,6 @@ def problem_p02783(input_data): import collections - import itertools as it - import math # import numpy as np diff --git a/pie_test_set/p02786.py b/pie_test_set/p02786.py index 52efccbc8..1e7fe2147 100644 --- a/pie_test_set/p02786.py +++ b/pie_test_set/p02786.py @@ -1,7 +1,6 @@ def problem_p02786(input_data): - from functools import lru_cache - import sys + from functools import lru_cache sys.setrecursionlimit(10**7) diff --git a/pie_test_set/p02840.py b/pie_test_set/p02840.py index 2b0ddf68f..c0ffcbfc9 100644 --- a/pie_test_set/p02840.py +++ b/pie_test_set/p02840.py @@ -1,6 +1,5 @@ def problem_p02840(input_data): from fractions import gcd - from itertools import accumulate n, x, d = list(map(int, input_data.split())) diff --git a/pie_test_set/p02900.py b/pie_test_set/p02900.py index 730d589d9..83e62127d 100644 --- a/pie_test_set/p02900.py +++ b/pie_test_set/p02900.py @@ -1,5 +1,5 @@ def problem_p02900(input_data): - from math import sqrt, ceil + from math import ceil, sqrt a, b = list(map(int, input_data.split())) diff --git a/pie_test_set/p02954.py b/pie_test_set/p02954.py index c49f08c1e..994e73e9e 100644 --- a/pie_test_set/p02954.py +++ b/pie_test_set/p02954.py @@ -1,7 +1,11 @@ def problem_p02954(input_data): - import sys, math, itertools, bisect, copy, re - - from collections import Counter, deque, defaultdict + import bisect + import copy + import itertools + import math + import re + import sys + from collections import Counter, defaultdict, deque # from itertools import accumulate, permutations, combinations, takewhile, compress, cycle diff --git a/pie_test_set/p02957.py b/pie_test_set/p02957.py index 07a9f5f45..2756be486 100644 --- a/pie_test_set/p02957.py +++ b/pie_test_set/p02957.py @@ -1,5 +1,18 @@ def problem_p02957(input_data): - import math, string, itertools, fractions, heapq, collections, re, array, bisect, sys, random, time, queue, copy + import array + import bisect + import collections + import copy + import fractions + import heapq + import itertools + import math + import queue + import random + import re + import string + import sys + import time sys.setrecursionlimit(10**7) diff --git a/pie_test_set/p02965.py b/pie_test_set/p02965.py index c5c1f82e7..7b7932a39 100644 --- a/pie_test_set/p02965.py +++ b/pie_test_set/p02965.py @@ -1,19 +1,12 @@ def problem_p02965(input_data): - from collections import defaultdict, deque, Counter - - from heapq import heappush, heappop, heapify - import math - - from bisect import bisect_left, bisect_right - import random - - from itertools import permutations, accumulate, combinations - - import sys - import string + import sys + from bisect import bisect_left, bisect_right + from collections import Counter, defaultdict, deque + from heapq import heapify, heappop, heappush + from itertools import accumulate, combinations, permutations INF = float("inf") diff --git a/pie_test_set/p02969.py b/pie_test_set/p02969.py index 09cbc32d7..35fb47934 100644 --- a/pie_test_set/p02969.py +++ b/pie_test_set/p02969.py @@ -3,9 +3,8 @@ def problem_p02969(input_data): input = sys.stdin.readline - import math - import collections + import math def I(): return int(eval(input_data)) diff --git a/pie_test_set/p02993.py b/pie_test_set/p02993.py index 6aeaee908..8b9326b50 100644 --- a/pie_test_set/p02993.py +++ b/pie_test_set/p02993.py @@ -1,7 +1,8 @@ def problem_p02993(input_data): #!/usr/bin/env python3 - import sys, math + import math + import sys input = lambda: sys.stdin.buffer.readline().rstrip().decode("utf-8") diff --git a/pie_test_set/p03016.py b/pie_test_set/p03016.py index a86231f30..7b563d904 100644 --- a/pie_test_set/p03016.py +++ b/pie_test_set/p03016.py @@ -1,27 +1,16 @@ def problem_p03016(input_data): - from collections import defaultdict, deque, Counter - - from heapq import heappush, heappop, heapify - - import math - import bisect - + import math import random - - from itertools import permutations, accumulate, combinations, product - - import sys - import string - + import sys from bisect import bisect_left, bisect_right - - from math import factorial, ceil, floor - - from operator import mul - + from collections import Counter, defaultdict, deque from functools import reduce + from heapq import heapify, heappop, heappush + from itertools import accumulate, combinations, permutations, product + from math import ceil, factorial, floor + from operator import mul sys.setrecursionlimit(2147483647) diff --git a/pie_test_set/p03088.py b/pie_test_set/p03088.py index 39a04cf99..5d044cbbb 100644 --- a/pie_test_set/p03088.py +++ b/pie_test_set/p03088.py @@ -1,7 +1,6 @@ def problem_p03088(input_data): - from itertools import product - from collections import defaultdict + from itertools import product MOD = 10**9 + 7 diff --git a/pie_test_set/p03206.py b/pie_test_set/p03206.py index 7245745cc..0f0ff476a 100644 --- a/pie_test_set/p03206.py +++ b/pie_test_set/p03206.py @@ -2,11 +2,10 @@ def problem_p03206(input_data): # encoding:utf-8 import copy + import random import numpy as np - import random - d = int(eval(input_data)) christmas = "Christmas" diff --git a/pie_test_set/p03213.py b/pie_test_set/p03213.py index 1c0cb0e74..5a9594056 100644 --- a/pie_test_set/p03213.py +++ b/pie_test_set/p03213.py @@ -1,7 +1,6 @@ def problem_p03213(input_data): - from operator import mul - from functools import reduce + from operator import mul nCr = {} diff --git a/pie_test_set/p03253.py b/pie_test_set/p03253.py index 7b5868f67..d4ac690dc 100644 --- a/pie_test_set/p03253.py +++ b/pie_test_set/p03253.py @@ -1,7 +1,6 @@ def problem_p03253(input_data): - from math import floor, sqrt - from collections import defaultdict + from math import floor, sqrt def factors(n): diff --git a/pie_test_set/p03286.py b/pie_test_set/p03286.py index fbbb2da62..111366d96 100644 --- a/pie_test_set/p03286.py +++ b/pie_test_set/p03286.py @@ -1,9 +1,8 @@ def problem_p03286(input_data): # coding: utf-8 - import sys - import bisect + import sys """Template""" diff --git a/pie_test_set/p03315.py b/pie_test_set/p03315.py index 383cc5ab0..b4970be2d 100644 --- a/pie_test_set/p03315.py +++ b/pie_test_set/p03315.py @@ -1,15 +1,10 @@ def problem_p03315(input_data): - import math - - import queue - import bisect - import heapq - - import time - import itertools + import math + import queue + import time mod = int(1e9 + 7) diff --git a/pie_test_set/p03502.py b/pie_test_set/p03502.py index a1c734221..d20e6472b 100644 --- a/pie_test_set/p03502.py +++ b/pie_test_set/p03502.py @@ -10,7 +10,6 @@ def problem_p03502(input_data): import sys # import itertools - import numpy as np read = sys.stdin.buffer.read diff --git a/pie_test_set/p03632.py b/pie_test_set/p03632.py index 4a1ac3387..8e1af69d0 100644 --- a/pie_test_set/p03632.py +++ b/pie_test_set/p03632.py @@ -3,9 +3,8 @@ def problem_p03632(input_data): sys.setrecursionlimit(4100000) - import math - import itertools + import math INF = float("inf") diff --git a/pie_test_set/p03666.py b/pie_test_set/p03666.py index 3cc3b2f34..38a62cfd4 100644 --- a/pie_test_set/p03666.py +++ b/pie_test_set/p03666.py @@ -1,5 +1,12 @@ def problem_p03666(input_data): - import sys, queue, math, copy, itertools, bisect, collections, heapq + import bisect + import collections + import copy + import heapq + import itertools + import math + import queue + import sys def main(): diff --git a/pie_test_set/p03797.py b/pie_test_set/p03797.py index 36aa71f3b..bebd18a3e 100644 --- a/pie_test_set/p03797.py +++ b/pie_test_set/p03797.py @@ -1,27 +1,16 @@ def problem_p03797(input_data): - from collections import defaultdict, deque, Counter - - from heapq import heappush, heappop, heapify - - import math - import bisect - + import math import random - - from itertools import permutations, accumulate, combinations, product - - import sys - import string - + import sys from bisect import bisect_left, bisect_right - - from math import factorial, ceil, floor - - from operator import mul - + from collections import Counter, defaultdict, deque from functools import reduce + from heapq import heapify, heappop, heappush + from itertools import accumulate, combinations, permutations, product + from math import ceil, factorial, floor + from operator import mul sys.setrecursionlimit(2147483647) diff --git a/pie_test_set/p03999.py b/pie_test_set/p03999.py index 5940ea9f8..3285c6c85 100644 --- a/pie_test_set/p03999.py +++ b/pie_test_set/p03999.py @@ -1,6 +1,6 @@ def problem_p03999(input_data): - from itertools import combinations, chain from functools import reduce + from itertools import chain, combinations def eval_str(string): diff --git a/pie_test_set/p04019.py b/pie_test_set/p04019.py index 0c0601fa4..47a04e0e5 100644 --- a/pie_test_set/p04019.py +++ b/pie_test_set/p04019.py @@ -1,7 +1,11 @@ def problem_p04019(input_data): #!/usr/bin/env python3 - import sys, math, itertools, collections, bisect + import bisect + import collections + import itertools + import math + import sys input = lambda: sys.stdin.buffer.readline().rstrip().decode("utf-8") diff --git a/pie_test_set/p04040.py b/pie_test_set/p04040.py index 5c2ad9879..2bc465b28 100644 --- a/pie_test_set/p04040.py +++ b/pie_test_set/p04040.py @@ -1,7 +1,6 @@ def problem_p04040(input_data): - from sys import stdin, stdout, stderr, setrecursionlimit - from functools import lru_cache + from sys import setrecursionlimit, stderr, stdin, stdout setrecursionlimit(10**7) diff --git a/pie_test_set/scripts/create_files_and_tests.py b/pie_test_set/scripts/create_files_and_tests.py index 93782dc19..9155fe46b 100644 --- a/pie_test_set/scripts/create_files_and_tests.py +++ b/pie_test_set/scripts/create_files_and_tests.py @@ -1,8 +1,8 @@ +import io import json -import os import subprocess import unittest.mock -import io +from pathlib import Path def create_files() -> None: @@ -12,10 +12,11 @@ def create_files() -> None: "original_data/val.jsonl", "original_data/train.jsonl", ]: - if not os.path.exists(jsonl_file): + jsonl_path = Path(jsonl_file) + if not jsonl_path.exists(): print(f"File {jsonl_file} does not exist.") continue - with open(jsonl_file, "r") as file: + with jsonl_path.open("r") as file: for line in file: test_case = json.loads(line) problem_id = test_case["problem_id"] @@ -26,12 +27,12 @@ def create_files() -> None: problems.add(problem_id) # Create a new Python file for each problem_id - file_path = f"../{problem_id}.py" - if os.path.exists(file_path): + file_path = Path(f"../{problem_id}.py") + if file_path.exists(): print(f"File {file_path} already exists.") continue - with open(file_path, "w") as code_file: + with file_path.open("w") as code_file: # Write the input code into a new function in the file code_file.write(f"def problem_{problem_id}(input_data):\n") # Replace input() calls with input_data handling @@ -43,42 +44,38 @@ def create_files() -> None: # Ensure the result is returned instead of printed try: # Run black to reformat the code file - subprocess.run(["black", file_path], check=True) + subprocess.run(["black", str(file_path)], check=True) except subprocess.CalledProcessError: print(f"Failed to format {file_path} with black.") - os.remove(file_path) + file_path.unlink() continue # Create test cases for each problem_id - test_dir = f"../tests/public_test_cases/{problem_id}" - if not os.path.exists(test_dir): + test_dir = Path(f"../tests/public_test_cases/{problem_id}") + if not test_dir.exists(): print(f"Directory {test_dir} does not exist.") continue test_files = sorted( - [f for f in os.listdir(test_dir) if f.startswith("input")], - key=lambda x: int(x.split(".")[1]), + [f for f in test_dir.iterdir() if f.name.startswith("input")], + key=lambda x: int(x.stem.split(".")[1]), ) - test_code_file_path = f"../tests/test_{problem_id}.py" - with open(test_code_file_path, "w") as test_code_file: + test_code_file_path = Path(f"../tests/test_{problem_id}.py") + with test_code_file_path.open("w") as test_code_file: test_code_file.write("\n") test_code_file.write( f"from code_to_optimize.pie_test_set.{problem_id} import problem_{problem_id}\n\n" ) for test_file in test_files: - input_num = test_file.split(".")[1] - output_file = f"output.{input_num}.txt" - with open(f"{test_dir}/{test_file}", "r") as input_f, open( - f"{test_dir}/{output_file}", "r" - ) as output_f: + input_num = test_file.stem.split(".")[1] + output_file = test_dir / f"output.{input_num}.txt" + with test_file.open("r") as input_f, output_file.open("r") as output_f: input_content = input_f.read() expected_output = output_f.read() if "\n" in input_content.strip() or "\n" in expected_output.strip(): - print( - f"Multiple lines detected in input or output for {problem_id}, skipping." - ) - os.remove(file_path) - if os.path.exists(test_code_file_path): - os.remove(test_code_file_path) + print(f"Multiple lines detected in input or output for {problem_id}, skipping.") + file_path.unlink() + if test_code_file_path.exists(): + test_code_file_path.unlink() break else: input_content = input_content.strip() @@ -86,20 +83,18 @@ def create_files() -> None: test_case_code = generate_test_case_code( problem_id, input_num, input_content, expected_output ) - with open(test_code_file_path, "a") as test_code_file: + with test_code_file_path.open("a") as test_code_file: test_code_file.write(test_case_code) try: # Run black to reformat the test file - subprocess.run(["black", test_code_file_path], check=True) + subprocess.run(["black", str(test_code_file_path)], check=True) except subprocess.CalledProcessError: print(f"Failed to format {test_code_file_path} with black.") - os.remove(test_code_file_path) + test_code_file_path.unlink() break -def generate_test_case_code( - problem_id: str, input_num: str, input_content: str, expected_output: str -) -> str: +def generate_test_case_code(problem_id: str, input_num: str, input_content: str, expected_output: str) -> str: return ( f"def test_problem_{problem_id}_{input_num}():\n" f" actual_output = problem_{problem_id}({input_content!r})\n" diff --git a/pie_test_set/scripts/run_pie_test_case.py b/pie_test_set/scripts/run_pie_test_case.py index cc46d73b2..5d13d1ee6 100644 --- a/pie_test_set/scripts/run_pie_test_case.py +++ b/pie_test_set/scripts/run_pie_test_case.py @@ -1,9 +1,9 @@ -import os import subprocess +from pathlib import Path def run_pie_test_case(script_path, test_input, expected_output): - assert os.path.exists(script_path), f"Script file does not exist: {script_path}" + assert Path(script_path).exists(), f"Script file does not exist: {script_path}" process = subprocess.Popen( ["python", script_path], stdin=subprocess.PIPE, @@ -17,6 +17,4 @@ def run_pie_test_case(script_path, test_input, expected_output): if stderr: print(f"Error in stderr: {stderr}") assert False, f"Script error: {stderr}" - assert ( - stdout.strip() == expected_output - ), f"Expected '{expected_output}' but got '{stdout.strip()}'" + assert stdout.strip() == expected_output, f"Expected '{expected_output}' but got '{stdout.strip()}'" diff --git a/tests/test_function_discovery.py b/tests/test_function_discovery.py index 51901e831..774440f24 100644 --- a/tests/test_function_discovery.py +++ b/tests/test_function_discovery.py @@ -1,4 +1,3 @@ -import os.path import tempfile from pathlib import Path @@ -124,8 +123,8 @@ def functionA(): only_get_this_function="A.functionA", test_cfg=test_config, ignore_paths=[Path("/bruh/")], - project_root=Path(os.path.dirname(f.name)), - module_root=Path(os.path.dirname(f.name)), + project_root=path_obj_name.parent, + module_root=path_obj_name.parent, ) assert len(functions) == 1 for file in functions: diff --git a/tests/test_shell_utils.py b/tests/test_shell_utils.py index 38a110b8e..5d5fac574 100644 --- a/tests/test_shell_utils.py +++ b/tests/test_shell_utils.py @@ -1,5 +1,6 @@ import os import unittest +from pathlib import Path from unittest.mock import mock_open, patch from returns.result import Failure, Success @@ -49,8 +50,9 @@ class TestReadApiKeyFromShellConfig(unittest.TestCase): def tearDown(self): """Cleanup the temporary shell configuration file after testing.""" - if os.path.exists(self.test_rc_path): - os.remove(self.test_rc_path) + test_rc_path = Path(self.test_rc_path) + if test_rc_path.exists(): + test_rc_path.unlink() del os.environ["SHELL"] # Remove the SHELL environment variable def test_valid_api_key(self): diff --git a/tests/test_test_runner.py b/tests/test_test_runner.py index f87590a24..92cf6d89a 100644 --- a/tests/test_test_runner.py +++ b/tests/test_test_runner.py @@ -1,6 +1,6 @@ import os -import pathlib import tempfile +from pathlib import Path import pytest @@ -28,7 +28,7 @@ class TestUnittestRunnerSorter(unittest.TestCase): gc.enable() print(f"#####test_sorter__unit_test_0:TestUnittestRunnerSorter.test_sort:sorter:0#####{duration}^^^^^") """ - cur_dir_path = os.path.dirname(os.path.abspath(__file__)) + cur_dir_path = Path(__file__).resolve().parent config = TestConfig( tests_root=cur_dir_path, project_root_path=cur_dir_path, @@ -37,18 +37,20 @@ class TestUnittestRunnerSorter(unittest.TestCase): with tempfile.NamedTemporaryFile(prefix="test_xx", suffix=".py", dir=cur_dir_path) as fp: test_files = TestFiles( - test_files=[TestFile(instrumented_file_path=str(fp.name), test_type=TestType.EXISTING_UNIT_TEST)], + test_files=[ + TestFile(instrumented_file_path=Path(fp.name), test_type=TestType.EXISTING_UNIT_TEST) + ], ) fp.write(code.encode("utf-8")) fp.flush() result_file, process = run_tests( test_files, test_framework=config.test_framework, - cwd=config.project_root_path, + cwd=Path(config.project_root_path), ) results = parse_test_xml(result_file, test_files, config, process) assert results[0].did_pass, "Test did not pass as expected" - pathlib.Path(result_file).unlink(missing_ok=True) + result_file.unlink(missing_ok=True) @pytest.mark.skip(reason="not testing the actual code path") @@ -63,7 +65,7 @@ def test_sort(): output = sorter(arr) assert output == [0, 1, 2, 3, 4, 5] """ - cur_dir_path = os.path.dirname(os.path.abspath(__file__)) + cur_dir_path = Path(__file__).resolve().parent config = TestConfig( tests_root=cur_dir_path, project_root_path=cur_dir_path, @@ -71,7 +73,9 @@ def test_sort(): ) with tempfile.NamedTemporaryFile(prefix="test_xx", suffix=".py", dir=cur_dir_path) as fp: test_files = TestFiles( - test_files=[TestFile(instrumented_file_path=str(fp.name), test_type=TestType.EXISTING_UNIT_TEST)], + test_files=[ + TestFile(instrumented_file_path=Path(fp.name), test_type=TestType.EXISTING_UNIT_TEST) + ], ) fp.write(code.encode("utf-8")) fp.flush() @@ -79,7 +83,7 @@ def test_sort(): result_file, process = run_tests( test_files, test_framework=config.test_framework, - cwd=os.path.join(cur_dir_path), + cwd=Path(config.project_root_path), test_env=test_env, pytest_timeout=1, pytest_min_loops=1, @@ -93,4 +97,4 @@ def test_sort(): run_result=process, ) assert results[0].did_pass, "Test did not pass as expected" - pathlib.Path(result_file).unlink(missing_ok=True) + result_file.unlink(missing_ok=True) diff --git a/tests/test_unit_test_discovery.py b/tests/test_unit_test_discovery.py index 40b70e276..4c46148d8 100644 --- a/tests/test_unit_test_discovery.py +++ b/tests/test_unit_test_discovery.py @@ -1,17 +1,17 @@ import os -import pathlib import tempfile +from pathlib import Path from codeflash.discovery.discover_unit_tests import discover_unit_tests from codeflash.verification.verification_utils import TestConfig def test_unit_test_discovery_pytest(): - project_path = pathlib.Path(__file__).parent.parent.resolve() / "code_to_optimize" + project_path = Path(__file__).parent.parent.resolve() / "code_to_optimize" tests_path = project_path / "tests" / "pytest" test_config = TestConfig( - tests_root=str(tests_path), - project_root_path=str(project_path), + tests_root=tests_path, + project_root_path=project_path, test_framework="pytest", ) tests = discover_unit_tests(test_config) @@ -20,14 +20,14 @@ def test_unit_test_discovery_pytest(): def test_unit_test_discovery_unittest(): - project_path = pathlib.Path(__file__).parent.parent.resolve() / "code_to_optimize" + project_path = Path(__file__).parent.parent.resolve() / "code_to_optimize" test_path = project_path / "tests" / "unittest" test_config = TestConfig( - tests_root=str(project_path), - project_root_path=str(project_path), + tests_root=project_path, + project_root_path=project_path, test_framework="unittest", ) - os.chdir(str(project_path)) + os.chdir(project_path) tests = discover_unit_tests(test_config) # assert len(tests) > 0 # Unittest discovery within a pytest environment does not work @@ -36,7 +36,7 @@ def test_unit_test_discovery_unittest(): def test_discover_tests_pytest_with_temp_dir_root(): with tempfile.TemporaryDirectory() as tmpdirname: # Create a dummy test file - test_file_path = pathlib.Path(tmpdirname) / "test_dummy.py" + test_file_path = Path(tmpdirname) / "test_dummy.py" test_file_content = ( "import pytest\n" "from dummy_code import dummy_function\n\n" @@ -47,16 +47,17 @@ def test_discover_tests_pytest_with_temp_dir_root(): " assert dummy_function() is True\n" ) test_file_path.write_text(test_file_content) + path_obj_tempdirname = Path(tmpdirname) # Create a file that the test file is testing - code_file_path = pathlib.Path(tmpdirname) / "dummy_code.py" + code_file_path = path_obj_tempdirname / "dummy_code.py" code_file_content = "def dummy_function():\n return True\n" code_file_path.write_text(code_file_content) # Create a TestConfig with the temporary directory as the root test_config = TestConfig( - tests_root=str(tmpdirname), - project_root_path=str(tmpdirname), + tests_root=path_obj_tempdirname, + project_root_path=path_obj_tempdirname, test_framework="pytest", ) @@ -79,13 +80,14 @@ def test_discover_tests_pytest_with_temp_dir_root(): def test_discover_tests_pytest_with_multi_level_dirs(): with tempfile.TemporaryDirectory() as tmpdirname: + path_obj_tmpdirname = Path(tmpdirname) # Create multi-level directories - level1_dir = pathlib.Path(tmpdirname) / "level1" + level1_dir = path_obj_tmpdirname / "level1" level2_dir = level1_dir / "level2" level2_dir.mkdir(parents=True) # Create code files at each level - root_code_file_path = pathlib.Path(tmpdirname) / "root_code.py" + root_code_file_path = path_obj_tmpdirname / "root_code.py" root_code_file_content = "def root_function():\n return True\n" root_code_file_path.write_text(root_code_file_content) @@ -98,7 +100,7 @@ def test_discover_tests_pytest_with_multi_level_dirs(): level2_code_file_path.write_text(level2_code_file_content) # Create a test file at the root level - root_test_file_path = pathlib.Path(tmpdirname) / "test_root.py" + root_test_file_path = path_obj_tmpdirname / "test_root.py" root_test_file_content = ( "from root_code import root_function\n\n" "def test_root_function():\n" @@ -129,8 +131,8 @@ def test_discover_tests_pytest_with_multi_level_dirs(): # Create a TestConfig with the temporary directory as the root test_config = TestConfig( - tests_root=str(tmpdirname), - project_root_path=str(tmpdirname), + tests_root=path_obj_tmpdirname, + project_root_path=path_obj_tmpdirname, test_framework="pytest", ) @@ -150,15 +152,16 @@ def test_discover_tests_pytest_with_multi_level_dirs(): def test_discover_tests_pytest_dirs(): with tempfile.TemporaryDirectory() as tmpdirname: + path_obj_tmpdirname = Path(tmpdirname) # Create multi-level directories - level1_dir = pathlib.Path(tmpdirname) / "level1" + level1_dir = Path(tmpdirname) / "level1" level2_dir = level1_dir / "level2" level2_dir.mkdir(parents=True) level3_dir = level1_dir / "level3" level3_dir.mkdir(parents=True) # Create code files at each level - root_code_file_path = pathlib.Path(tmpdirname) / "root_code.py" + root_code_file_path = path_obj_tmpdirname / "root_code.py" root_code_file_content = "def root_function():\n return True\n" root_code_file_path.write_text(root_code_file_content) @@ -175,7 +178,7 @@ def test_discover_tests_pytest_dirs(): level3_code_file_path.write_text(level3_code_file_content) # Create a test file at the root level - root_test_file_path = pathlib.Path(tmpdirname) / "test_root.py" + root_test_file_path = path_obj_tmpdirname / "test_root.py" root_test_file_content = ( "from root_code import root_function\n\n" "def test_root_function():\n" @@ -215,8 +218,8 @@ def test_discover_tests_pytest_dirs(): # Create a TestConfig with the temporary directory as the root test_config = TestConfig( - tests_root=str(tmpdirname), - project_root_path=str(tmpdirname), + tests_root=path_obj_tmpdirname, + project_root_path=path_obj_tmpdirname, test_framework="pytest", ) @@ -239,13 +242,14 @@ def test_discover_tests_pytest_dirs(): def test_discover_tests_pytest_with_class(): with tempfile.TemporaryDirectory() as tmpdirname: + path_obj_tmpdirname = Path(tmpdirname) # Create a code file with a class - code_file_path = pathlib.Path(tmpdirname) / "some_class_code.py" + code_file_path = path_obj_tmpdirname / "some_class_code.py" code_file_content = "class SomeClass:\n def some_method(self):\n return True\n" code_file_path.write_text(code_file_content) # Create a test file with a test class and a test method - test_file_path = pathlib.Path(tmpdirname) / "test_some_class.py" + test_file_path = path_obj_tmpdirname / "test_some_class.py" test_file_content = ( "from some_class_code import SomeClass\n\n" "def test_some_method():\n" @@ -256,8 +260,8 @@ def test_discover_tests_pytest_with_class(): # Create a TestConfig with the temporary directory as the root test_config = TestConfig( - tests_root=str(tmpdirname), - project_root_path=str(tmpdirname), + tests_root=path_obj_tmpdirname, + project_root_path=path_obj_tmpdirname, test_framework="pytest", ) @@ -273,8 +277,9 @@ def test_discover_tests_pytest_with_class(): def test_discover_tests_pytest_with_double_nested_directories(): with tempfile.TemporaryDirectory() as tmpdirname: + path_obj_tmpdirname = Path(tmpdirname) # Create nested directories - nested_dir = pathlib.Path(tmpdirname) / "nested" / "more_nested" + nested_dir = path_obj_tmpdirname / "nested" / "more_nested" nested_dir.mkdir(parents=True) # Create a code file with a class in the nested directory @@ -294,8 +299,8 @@ def test_discover_tests_pytest_with_double_nested_directories(): # Create a TestConfig with the temporary directory as the root test_config = TestConfig( - tests_root=str(tmpdirname), - project_root_path=str(tmpdirname), + tests_root=path_obj_tmpdirname, + project_root_path=path_obj_tmpdirname, test_framework="pytest", ) @@ -313,8 +318,9 @@ def test_discover_tests_pytest_with_double_nested_directories(): def test_discover_tests_with_code_in_dir_and_test_in_subdir(): with tempfile.TemporaryDirectory() as tmpdirname: + path_obj_tmpdirname = Path(tmpdirname) # Create a directory for the code file - code_dir = pathlib.Path(tmpdirname) / "code" + code_dir = path_obj_tmpdirname / "code" code_dir.mkdir() # Create a code file in the code directory @@ -340,8 +346,8 @@ def test_discover_tests_with_code_in_dir_and_test_in_subdir(): # Create a TestConfig with the code directory as the root test_config = TestConfig( - tests_root=str(test_subdir), - project_root_path=str(tmpdirname), + tests_root=test_subdir, + project_root_path=path_obj_tmpdirname, test_framework="pytest", ) @@ -355,8 +361,9 @@ def test_discover_tests_with_code_in_dir_and_test_in_subdir(): def test_discover_tests_pytest_with_nested_class(): with tempfile.TemporaryDirectory() as tmpdirname: + path_obj_tmpdirname = Path(tmpdirname) # Create a code file with a nested class - code_file_path = pathlib.Path(tmpdirname) / "nested_class_code.py" + code_file_path = path_obj_tmpdirname / "nested_class_code.py" code_file_content = ( "class OuterClass:\n" " class InnerClass:\n" @@ -366,7 +373,7 @@ def test_discover_tests_pytest_with_nested_class(): code_file_path.write_text(code_file_content) # Create a test file with a test for the nested class method - test_file_path = pathlib.Path(tmpdirname) / "test_nested_class.py" + test_file_path = path_obj_tmpdirname / "test_nested_class.py" test_file_content = ( "from nested_class_code import OuterClass\n\n" "def test_inner_method():\n" @@ -377,8 +384,8 @@ def test_discover_tests_pytest_with_nested_class(): # Create a TestConfig with the temporary directory as the root test_config = TestConfig( - tests_root=str(tmpdirname), - project_root_path=str(tmpdirname), + tests_root=path_obj_tmpdirname, + project_root_path=path_obj_tmpdirname, test_framework="pytest", )