ruff format

This commit is contained in:
Kevin Turcios 2024-10-12 20:58:44 -05:00
parent 87bfc79f39
commit cd4db2291a
57 changed files with 276 additions and 267 deletions

View file

@ -1,9 +1,16 @@
from __future__ import annotations from __future__ import annotations
from sqlalchemy import create_engine, Integer, String, ForeignKey from sqlalchemy import ForeignKey, Integer, String, create_engine
from sqlalchemy.engine.base import Engine from sqlalchemy.engine.base import Engine
from sqlalchemy.orm import Mapped, mapped_column from sqlalchemy.orm import (
from sqlalchemy.orm import sessionmaker, relationship, Relationship, Session, DeclarativeBase DeclarativeBase,
Mapped,
Relationship,
Session,
mapped_column,
relationship,
sessionmaker,
)
# Custom base class # Custom base class

View file

@ -1,10 +1,7 @@
from time import time
from typing import List from typing import List
from sqlalchemy import Boolean, Column, ForeignKey, Integer, Text, func from sqlalchemy import Boolean, Column, ForeignKey, Integer, Text, func
from sqlalchemy.orm import Session, relationship
from time import time
from sqlalchemy import Boolean, Column, ForeignKey, Integer, Text
from sqlalchemy.engine import Engine, create_engine from sqlalchemy.engine import Engine, create_engine
from sqlalchemy.orm import DeclarativeBase, Session, relationship, sessionmaker from sqlalchemy.orm import DeclarativeBase, Session, relationship, sessionmaker
from sqlalchemy.orm.relationships import Relationship from sqlalchemy.orm.relationships import Relationship

View file

@ -1,6 +1,7 @@
from code_to_optimize.final_test_set.gradient import gradient
import numpy as np import numpy as np
from code_to_optimize.final_test_set.gradient import gradient
def test_simple_case(): def test_simple_case():
# Test case with simple values # Test case with simple values

View file

@ -1,6 +1,7 @@
from code_to_optimize.final_test_set.hamming_distance import _hamming_distance
import numpy as np import numpy as np
from code_to_optimize.final_test_set.hamming_distance import _hamming_distance
def test_no_differences(): def test_no_differences():
a = np.array([1, 2, 3, 4]) a = np.array([1, 2, 3, 4])

View file

@ -1,6 +1,7 @@
from code_to_optimize.final_test_set.integration import integrate_f
import pytest import pytest
from code_to_optimize.final_test_set.integration import integrate_f
def isclose(a, b, rel_tol=1e-5, abs_tol=0.0): def isclose(a, b, rel_tol=1e-5, abs_tol=0.0):
""" """

View file

@ -1,6 +1,7 @@
from code_to_optimize.final_test_set.matrix_multiplication import matrix_multiply
import pytest import pytest
from code_to_optimize.final_test_set.matrix_multiplication import matrix_multiply
def test_matrix_multiplication_basic(): def test_matrix_multiplication_basic():
A = [[1, 2], [3, 4]] A = [[1, 2], [3, 4]]

View file

@ -1,6 +1,7 @@
from code_to_optimize.final_test_set.standardize_name import standardize_name
import pytest import pytest
from code_to_optimize.final_test_set.standardize_name import standardize_name
def test_exact_match(): def test_exact_match():
assert standardize_name("Brattle St") == "Brattle St" assert standardize_name("Brattle St") == "Brattle St"

View file

@ -1,6 +1,7 @@
from code_to_optimize.bubble_sort import sorter
import pytest import pytest
from code_to_optimize.bubble_sort import sorter
@pytest.mark.parametrize( @pytest.mark.parametrize(
"input, expected_output", "input, expected_output",

View file

@ -1,4 +1,5 @@
import unittest import unittest
from parameterized import parameterized from parameterized import parameterized
from code_to_optimize.bubble_sort import sorter from code_to_optimize.bubble_sort import sorter

View file

@ -3,7 +3,7 @@ from __future__ import annotations
import json import json
import os import os
import platform import platform
from typing import Any from typing import TYPE_CHECKING, Any
import requests import requests
from pydantic.dataclasses import dataclass from pydantic.dataclasses import dataclass
@ -11,11 +11,15 @@ from pydantic.json import pydantic_encoder
from codeflash.cli_cmds.console import logger from codeflash.cli_cmds.console import logger
from codeflash.code_utils.env_utils import get_codeflash_api_key from codeflash.code_utils.env_utils import get_codeflash_api_key
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.ExperimentMetadata import ExperimentMetadata
from codeflash.telemetry.posthog import ph from codeflash.telemetry.posthog import ph
from codeflash.version import __version__ as codeflash_version from codeflash.version import __version__ as codeflash_version
if TYPE_CHECKING:
from pathlib import Path
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.ExperimentMetadata import ExperimentMetadata
@dataclass(frozen=True) @dataclass(frozen=True)
class OptimizedCandidate: class OptimizedCandidate:
@ -161,8 +165,8 @@ class AiServiceClient:
source_code_being_tested: str, source_code_being_tested: str,
function_to_optimize: FunctionToOptimize, function_to_optimize: FunctionToOptimize,
helper_function_names: list[str], helper_function_names: list[str],
module_path: str, module_path: Path,
test_module_path: str, test_module_path: Path,
test_framework: str, test_framework: str,
test_timeout: int, test_timeout: int,
trace_id: str, trace_id: str,
@ -175,8 +179,8 @@ class AiServiceClient:
- source_code_being_tested (str): The source code of the function being tested. - source_code_being_tested (str): The source code of the function being tested.
- function_to_optimize (FunctionToOptimize): The function to optimize. - function_to_optimize (FunctionToOptimize): The function to optimize.
- helper_function_names (list[Source]): List of helper function names. - helper_function_names (list[Source]): List of helper function names.
- module_path (str): The module path where the function is located. - module_path (Path): The module path where the function is located.
- test_module_path (str): The module path for the test code. - test_module_path (Path): The module path for the test code.
- test_framework (str): The test framework to use, e.g., "pytest". - test_framework (str): The test framework to use, e.g., "pytest".
- test_timeout (int): The timeout for each test in seconds. - test_timeout (int): The timeout for each test in seconds.
- test_index (int): The index from 0-(n-1) if n tests are generated for a single trace_id - test_index (int): The index from 0-(n-1) if n tests are generated for a single trace_id

View file

@ -1,3 +1,5 @@
from __future__ import annotations
import json import json
import os import os
from functools import lru_cache from functools import lru_cache
@ -8,10 +10,10 @@ import requests
from pydantic.json import pydantic_encoder from pydantic.json import pydantic_encoder
from requests import Response from requests import Response
from codeflash.cli_cmds.console import logger
from codeflash.code_utils.env_utils import ensure_codeflash_api_key, get_codeflash_api_key, get_pr_number from codeflash.code_utils.env_utils import ensure_codeflash_api_key, get_codeflash_api_key, get_pr_number
from codeflash.code_utils.git_utils import get_repo_owner_and_name from codeflash.code_utils.git_utils import get_repo_owner_and_name
from codeflash.github.PrComment import FileDiffContent, PrComment from codeflash.github.PrComment import FileDiffContent, PrComment
from codeflash.cli_cmds.console import logger
if os.environ.get("CODEFLASH_CFAPI_SERVER", default="prod").lower() == "local": if os.environ.get("CODEFLASH_CFAPI_SERVER", default="prod").lower() == "local":
CFAPI_BASE_URL = "http://localhost:3001" CFAPI_BASE_URL = "http://localhost:3001"

View file

@ -2,8 +2,8 @@ import os
from functools import lru_cache from functools import lru_cache
from typing import Optional from typing import Optional
from codeflash.code_utils.shell_utils import read_api_key_from_shell_config
from codeflash.cli_cmds.console import logger from codeflash.cli_cmds.console import logger
from codeflash.code_utils.shell_utils import read_api_key_from_shell_config
@lru_cache(maxsize=1) @lru_cache(maxsize=1)

View file

@ -1,10 +1,10 @@
from codeflash.cli_cmds.console import logger
from typing import Optional from typing import Optional
from git import Repo from git import Repo
from codeflash.api.cfapi import is_github_app_installed_on_repo from codeflash.api.cfapi import is_github_app_installed_on_repo
from codeflash.cli_cmds.cli_common import apologize_and_exit from codeflash.cli_cmds.cli_common import apologize_and_exit
from codeflash.cli_cmds.console import logger
from codeflash.code_utils.compat import LF from codeflash.code_utils.compat import LF
from codeflash.code_utils.git_utils import get_repo_owner_and_name from codeflash.code_utils.git_utils import get_repo_owner_and_name

View file

@ -4,14 +4,13 @@ solved problem, please reach out to us at careers@codeflash.ai. We're hiring!
from pathlib import Path from pathlib import Path
from codeflash.cli_cmds.cli import parse_args, process_pyproject_config from codeflash.cli_cmds.cli import parse_args, process_pyproject_config
from codeflash.cli_cmds.cmd_init import CODEFLASH_LOGO, ask_run_end_to_end_test from codeflash.cli_cmds.cmd_init import CODEFLASH_LOGO, ask_run_end_to_end_test
from codeflash.cli_cmds.console import paneled_text
from codeflash.code_utils.config_parser import parse_config_file from codeflash.code_utils.config_parser import parse_config_file
from codeflash.optimization import optimizer from codeflash.optimization import optimizer
from codeflash.telemetry import posthog from codeflash.telemetry import posthog
from codeflash.telemetry.sentry import init_sentry from codeflash.telemetry.sentry import init_sentry
from codeflash.cli_cmds.console import paneled_text
def main() -> None: def main() -> None:

View file

@ -750,9 +750,10 @@ class Optimizer:
for tests_in_file in function_to_tests.get(func_qualname): for tests_in_file in function_to_tests.get(func_qualname):
test_file_invocation_positions[tests_in_file.test_file].append(tests_in_file.position) test_file_invocation_positions[tests_in_file.test_file].append(tests_in_file.position)
for test_file, positions in test_file_invocation_positions.items(): for test_file, positions in test_file_invocation_positions.items():
path_obj_test_file = Path(test_file)
relevant_test_files_count += 1 relevant_test_files_count += 1
success, injected_test = inject_profiling_into_existing_test( success, injected_test = inject_profiling_into_existing_test(
test_file, path_obj_test_file,
positions, positions,
function_to_optimize, function_to_optimize,
self.args.project_root, self.args.project_root,
@ -761,15 +762,15 @@ class Optimizer:
if not success: if not success:
continue continue
new_test_path = Path(test_file).with_suffix(f"__perfinstrumented{Path(test_file).suffix}") new_test_path = Path(test_file).with_suffix(f"__perfinstrumented{Path(test_file).suffix}")
with new_test_path.open("w", encoding="utf8") as f: with new_test_path.open("w", encoding="utf8") as _f:
f.write(injected_test) _f.write(injected_test)
unique_instrumented_test_files.add(new_test_path) unique_instrumented_test_files.add(new_test_path)
if not self.test_files.get_by_original_file_path(test_file): if not self.test_files.get_by_original_file_path(path_obj_test_file):
self.test_files.add( self.test_files.add(
TestFile( TestFile(
instrumented_file_path=new_test_path, instrumented_file_path=new_test_path,
original_source=None, original_source=None,
original_file_path=test_file, original_file_path=Path(test_file),
test_type=TestType.EXISTING_UNIT_TEST, test_type=TestType.EXISTING_UNIT_TEST,
), ),
) )
@ -986,10 +987,10 @@ class Optimizer:
first_test_types = [] first_test_types = []
first_test_functions = [] first_test_functions = []
Path(get_run_tmp_file(f"test_return_values_{optimization_candidate_index}.bin")).unlink( get_run_tmp_file(Path(f"test_return_values_{optimization_candidate_index}.sqlite")).unlink(
missing_ok=True, missing_ok=True,
) )
Path(get_run_tmp_file(f"test_return_values_{optimization_candidate_index}.sqlite")).unlink( get_run_tmp_file(Path(f"test_return_values_{optimization_candidate_index}.sqlite")).unlink(
missing_ok=True, missing_ok=True,
) )
@ -1060,10 +1061,11 @@ class Optimizer:
) )
if best_runtime_until_now is None or total_candidate_timing < best_runtime_until_now: if best_runtime_until_now is None or total_candidate_timing < best_runtime_until_now:
best_test_results = candidate_results best_test_results = candidate_results
Path(get_run_tmp_file(f"test_return_values_{optimization_candidate_index}.bin")).unlink( get_run_tmp_file(Path(f"test_return_values_{optimization_candidate_index}.bin")).unlink(
missing_ok=True, missing_ok=True,
) )
Path(get_run_tmp_file(f"test_return_values_{optimization_candidate_index}.sqlite")).unlink(
get_run_tmp_file(Path(f"test_return_values_{optimization_candidate_index}.sqlite")).unlink(
missing_ok=True, missing_ok=True,
) )
if not equal_results: if not equal_results:
@ -1108,12 +1110,12 @@ class Optimizer:
) )
except subprocess.TimeoutExpired: except subprocess.TimeoutExpired:
logger.exception( logger.exception(
f'Error running tests in {", ".join(test_files)}.\nTimeout Error', f'Error running tests in {", ".join(str(f) for f in test_files.test_files)}.\nTimeout Error',
) )
return TestResults() return TestResults()
if run_result.returncode != 0: if run_result.returncode != 0:
logger.debug( logger.debug(
f'Nonzero return code {run_result.returncode} when running tests in {", ".join([f.instrumented_file_path for f in test_files.test_files])}.\n' f'Nonzero return code {run_result.returncode} when running tests in {", ".join([str(f.instrumented_file_path) for f in test_files.test_files])}.\n'
f"stdout: {run_result.stdout}\n" f"stdout: {run_result.stdout}\n"
f"stderr: {run_result.stderr}\n", f"stderr: {run_result.stderr}\n",
) )

View file

@ -1,13 +1,12 @@
from __future__ import annotations from __future__ import annotations
from codeflash.cli_cmds.console import logger from pathlib import Path
import os.path
import pathlib
from typing import Dict, Optional from typing import Dict, Optional
import git import git
from codeflash.api import cfapi from codeflash.api import cfapi
from codeflash.cli_cmds.console import logger
from codeflash.code_utils import env_utils from codeflash.code_utils import env_utils
from codeflash.code_utils.code_replacer import is_zero_diff from codeflash.code_utils.code_replacer import is_zero_diff
from codeflash.code_utils.git_utils import ( from codeflash.code_utils.git_utils import (
@ -30,7 +29,7 @@ def existing_tests_source_for(
existing_tests_unique = set() existing_tests_unique = set()
if test_files: if test_files:
for test_file in test_files: for test_file in test_files:
existing_tests_unique.add("- " + os.path.relpath(test_file.test_file, tests_root)) existing_tests_unique.add("- " + str(Path(test_file.test_file).relative_to(tests_root)))
return "\n".join(sorted(existing_tests_unique)) return "\n".join(sorted(existing_tests_unique))
@ -48,9 +47,9 @@ def check_create_pr(
if pr_number is not None: if pr_number is not None:
logger.info(f"Suggesting changes to PR #{pr_number} ...") logger.info(f"Suggesting changes to PR #{pr_number} ...")
owner, repo = get_repo_owner_and_name(git_repo) owner, repo = get_repo_owner_and_name(git_repo)
relative_path = str(pathlib.Path(os.path.relpath(explanation.file_path, git_root_dir())).as_posix()) relative_path = Path(explanation.file_path).relative_to(git_root_dir()).as_posix()
build_file_changes = { build_file_changes = {
str(pathlib.Path(os.path.relpath(p, git_root_dir())).as_posix()): FileDiffContent( Path(p).relative_to(git_root_dir()).as_posix(): FileDiffContent(
oldContent=original_code[p], oldContent=original_code[p],
newContent=new_code[p], newContent=new_code[p],
) )
@ -93,14 +92,14 @@ def check_create_pr(
if not check_and_push_branch(git_repo, wait_for_push=True): if not check_and_push_branch(git_repo, wait_for_push=True):
logger.warning("⏭️ Branch is not pushed, skipping PR creation...") logger.warning("⏭️ Branch is not pushed, skipping PR creation...")
return return
relative_path = str(pathlib.Path(os.path.relpath(explanation.file_path, git_root_dir())).as_posix()) relative_path = Path(explanation.file_path).relative_to(git_root_dir()).as_posix()
base_branch = get_current_branch() base_branch = get_current_branch()
response = cfapi.create_pr( response = cfapi.create_pr(
owner=owner, owner=owner,
repo=repo, repo=repo,
base_branch=base_branch, base_branch=base_branch,
file_changes={ file_changes={
str(pathlib.Path(os.path.relpath(p, git_root_dir())).as_posix()): FileDiffContent( Path(p).relative_to(git_root_dir()).as_posix(): FileDiffContent(
oldContent=original_code[p], oldContent=original_code[p],
newContent=new_code[p], newContent=new_code[p],
) )

View file

@ -1,10 +1,10 @@
import logging import logging
from codeflash.cli_cmds.console import logger
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
from posthog import Posthog from posthog import Posthog
from codeflash.api.cfapi import get_user_id from codeflash.api.cfapi import get_user_id
from codeflash.cli_cmds.console import logger
from codeflash.version import __version__, __version_tuple__ from codeflash.version import __version__, __version_tuple__
_posthog = None _posthog = None

View file

@ -1,14 +1,15 @@
import json import json
import os.path
import pstats import pstats
import sqlite3 import sqlite3
from copy import copy from copy import copy
from pathlib import Path
from codeflash.cli_cmds.console import logger from codeflash.cli_cmds.console import logger
class ProfileStats(pstats.Stats): class ProfileStats(pstats.Stats):
def __init__(self, trace_file_path: str, time_unit: str = "ns"): def __init__(self, trace_file_path: str, time_unit: str = "ns"):
assert os.path.isfile(trace_file_path), f"Trace file {trace_file_path} does not exist" assert Path(trace_file_path).is_file(), f"Trace file {trace_file_path} does not exist"
assert time_unit in ["ns", "us", "ms", "s"], f"Invalid time unit {time_unit}" assert time_unit in ["ns", "us", "ms", "s"], f"Invalid time unit {time_unit}"
self.trace_file_path = trace_file_path self.trace_file_path = trace_file_path
self.time_unit = time_unit self.time_unit = time_unit
@ -72,8 +73,8 @@ class ProfileStats(pstats.Stats):
return self return self
def get_trace_total_run_time_ns(trace_file_path: str) -> int: def get_trace_total_run_time_ns(trace_file_path: Path) -> int:
if not os.path.isfile(trace_file_path): if not trace_file_path.is_file():
return 0 return 0
con = sqlite3.connect(trace_file_path) con = sqlite3.connect(trace_file_path)
cur = con.cursor() cur = con.cursor()

View file

@ -1,13 +1,14 @@
import datetime import datetime
import decimal import decimal
import enum import enum
from codeflash.cli_cmds.console import logger
import math import math
import types import types
from typing import Any from typing import Any
import sentry_sdk import sentry_sdk
from codeflash.cli_cmds.console import logger
try: try:
import numpy as np import numpy as np

View file

@ -3,6 +3,7 @@ from __future__ import annotations
import os import os
import shlex import shlex
import subprocess import subprocess
from pathlib import Path
from codeflash.code_utils.code_utils import get_run_tmp_file from codeflash.code_utils.code_utils import get_run_tmp_file
from codeflash.code_utils.config_consts import TOTAL_LOOPING_TIME from codeflash.code_utils.config_consts import TOTAL_LOOPING_TIME
@ -13,7 +14,7 @@ from codeflash.verification.test_results import TestType
def run_tests( def run_tests(
test_paths: TestFiles, test_paths: TestFiles,
test_framework: str, test_framework: str,
cwd: str | None = None, cwd: Path | None = None,
test_env: dict[str, str] | None = None, test_env: dict[str, str] | None = None,
pytest_timeout: int | None = None, pytest_timeout: int | None = None,
pytest_cmd: str = "pytest", pytest_cmd: str = "pytest",
@ -22,7 +23,7 @@ def run_tests(
pytest_target_runtime_seconds: float = TOTAL_LOOPING_TIME, pytest_target_runtime_seconds: float = TOTAL_LOOPING_TIME,
pytest_min_loops: int = 5, pytest_min_loops: int = 5,
pytest_max_loops: int = 100_000, pytest_max_loops: int = 100_000,
) -> tuple[str, subprocess.CompletedProcess]: ) -> tuple[Path, subprocess.CompletedProcess]:
assert test_framework in ["pytest", "unittest"] assert test_framework in ["pytest", "unittest"]
# TODO: Make this work for replay tests # TODO: Make this work for replay tests
for i, test_file in enumerate(test_paths): for i, test_file in enumerate(test_paths):
@ -30,10 +31,10 @@ def run_tests(
only_run_these_test_functions and test_file.test_type == TestType.REPLAY_TEST only_run_these_test_functions and test_file.test_type == TestType.REPLAY_TEST
): # "__replay_test" in test_path: ): # "__replay_test" in test_path:
# TODO: This might not work for replay tests # TODO: This might not work for replay tests
test_paths[i] = test_file.instrumented_file_path + "::" + only_run_these_test_functions test_paths[i] = str(test_file.instrumented_file_path) + "::" + only_run_these_test_functions
if test_framework == "pytest": if test_framework == "pytest":
result_file_path = get_run_tmp_file("pytest_results.xml") result_file_path = get_run_tmp_file(Path("pytest_results.xml"))
pytest_cmd_list = shlex.split(pytest_cmd, posix=os.name != "nt") pytest_cmd_list = shlex.split(pytest_cmd, posix=os.name != "nt")
pytest_test_env = test_env.copy() pytest_test_env = test_env.copy()
@ -62,7 +63,7 @@ def run_tests(
check=False, check=False,
) )
elif test_framework == "unittest": elif test_framework == "unittest":
result_file_path = get_run_tmp_file("unittest_results.xml") result_file_path = get_run_tmp_file(Path("unittest_results.xml"))
results = subprocess.run( results = subprocess.run(
["python", "-m", "xmlrunner"] ["python", "-m", "xmlrunner"]
+ (["-v"] if verbose else []) + (["-v"] if verbose else [])

View file

@ -1,25 +1,29 @@
from __future__ import annotations from __future__ import annotations
import ast import ast
from pathlib import Path
from typing import TYPE_CHECKING
from codeflash.api.aiservice import AiServiceClient
from codeflash.cli_cmds.console import logger from codeflash.cli_cmds.console import logger
from codeflash.code_utils.code_utils import get_run_tmp_file, module_name_from_file_path from codeflash.code_utils.code_utils import get_run_tmp_file, module_name_from_file_path
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.verification.verification_utils import ( from codeflash.verification.verification_utils import (
ModifyInspiredTests, ModifyInspiredTests,
TestConfig,
delete_multiple_if_name_main, delete_multiple_if_name_main,
get_test_file_path, get_test_file_path,
) )
if TYPE_CHECKING:
from codeflash.api.aiservice import AiServiceClient
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.verification.verification_utils import TestConfig
def generate_tests( def generate_tests(
aiservice_client: AiServiceClient, aiservice_client: AiServiceClient,
source_code_being_tested: str, source_code_being_tested: str,
function_to_optimize: FunctionToOptimize, function_to_optimize: FunctionToOptimize,
helper_function_names: list[str], helper_function_names: list[str],
module_path: str, module_path: Path,
test_cfg: TestConfig, test_cfg: TestConfig,
test_timeout: int, test_timeout: int,
use_cached_tests: bool, use_cached_tests: bool,
@ -31,19 +35,22 @@ def generate_tests(
if use_cached_tests: if use_cached_tests:
import importlib import importlib
module = importlib.import_module(module_path) module = importlib.import_module(str(module_path))
generated_test_source = module.CACHED_TESTS generated_test_source = module.CACHED_TESTS
instrumented_test_source = module.CACHED_INSTRUMENTED_TESTS instrumented_test_source = module.CACHED_INSTRUMENTED_TESTS
path = get_run_tmp_file("").replace("\\", "\\\\") # Escape backslash for windows paths temp_run_dir = get_run_tmp_file(Path(""))
path = str(temp_run_dir).replace("\\", "\\\\") # Escape backslash for windows paths
instrumented_test_source = instrumented_test_source.replace( instrumented_test_source = instrumented_test_source.replace(
"{codeflash_run_tmp_dir_client_side}", "{codeflash_run_tmp_dir_client_side}",
path, path,
) )
logger.info(f"Using cached tests from {module_path}.CACHED_TESTS") logger.info(f"Using cached tests from {module_path}.CACHED_TESTS")
else: else:
test_module_path = module_name_from_file_path( test_module_path = Path(
get_test_file_path(test_cfg.tests_root, function_to_optimize.function_name, 0), module_name_from_file_path(
test_cfg.project_root_path, get_test_file_path(test_cfg.tests_root, function_to_optimize.function_name, 0),
test_cfg.project_root_path,
),
) )
response = aiservice_client.generate_regression_tests( response = aiservice_client.generate_regression_tests(
source_code_being_tested=source_code_being_tested, source_code_being_tested=source_code_being_tested,
@ -58,7 +65,8 @@ def generate_tests(
) )
if response and isinstance(response, tuple) and len(response) == 2: if response and isinstance(response, tuple) and len(response) == 2:
generated_test_source, instrumented_test_source = response generated_test_source, instrumented_test_source = response
path = get_run_tmp_file("").replace("\\", "\\\\") # Escape backslash for windows paths temp_run_dir = get_run_tmp_file(Path(""))
path = str(temp_run_dir).replace("\\", "\\\\")
instrumented_test_source = instrumented_test_source.replace( instrumented_test_source = instrumented_test_source.replace(
"{codeflash_run_tmp_dir_client_side}", "{codeflash_run_tmp_dir_client_side}",
path, path,

View file

@ -1,5 +1,14 @@
def problem_p02548(input_data): def problem_p02548(input_data):
import sys, os, math, bisect, itertools, collections, heapq, queue, copy, array import array
import bisect
import collections
import copy
import heapq
import itertools
import math
import os
import queue
import sys
# from scipy.sparse.csgraph import csgraph_from_dense, floyd_warshall # from scipy.sparse.csgraph import csgraph_from_dense, floyd_warshall

View file

@ -46,7 +46,6 @@ def problem_p02624(input_data):
if sys.argv[-1] == "ONLINE_JUDGE": if sys.argv[-1] == "ONLINE_JUDGE":
import numba import numba
from numba.pycc import CC from numba.pycc import CC
i8 = numba.int64 i8 = numba.int64

View file

@ -1,9 +1,7 @@
def problem_p02639(input_data): def problem_p02639(input_data):
from sys import stdin, stdout
from math import gcd, sqrt
from collections import deque from collections import deque
from math import gcd, sqrt
from sys import stdin, stdout
input = stdin.readline input = stdin.readline

View file

@ -11,19 +11,13 @@ def problem_p02660(input_data):
return list(map(int, input_data.split())) return list(map(int, input_data.split()))
from collections import defaultdict, deque
from sys import exit
import math
import copy import copy
import math
from bisect import bisect_left, bisect_right
from heapq import *
import sys import sys
from bisect import bisect_left, bisect_right
from collections import defaultdict, deque
from heapq import *
from sys import exit
# sys.setrecursionlimit(1000000) # sys.setrecursionlimit(1000000)

View file

@ -1,7 +1,6 @@
def problem_p02696(input_data): def problem_p02696(input_data):
from sys import stdin
import sys import sys
from sys import stdin
A, B, N = [int(x) for x in stdin.readline().rstrip().split()] A, B, N = [int(x) for x in stdin.readline().rstrip().split()]

View file

@ -1,6 +1,5 @@
def problem_p02738(input_data): def problem_p02738(input_data):
from functools import lru_cache, reduce from functools import lru_cache, reduce
from itertools import accumulate from itertools import accumulate
N, M = list(map(int, input_data.split())) N, M = list(map(int, input_data.split()))

View file

@ -1,16 +1,12 @@
def problem_p02782(input_data): def problem_p02782(input_data):
import collections
import heapq
import sys
from functools import cmp_to_key
from sys import stdin from sys import stdin
import sys
import numpy as np import numpy as np
import collections
from functools import cmp_to_key
import heapq
## input functions for me ## input functions for me
def rsa(sep=""): def rsa(sep=""):

View file

@ -1,8 +1,6 @@
def problem_p02783(input_data): def problem_p02783(input_data):
import collections import collections
import itertools as it import itertools as it
import math import math
# import numpy as np # import numpy as np

View file

@ -1,7 +1,6 @@
def problem_p02786(input_data): def problem_p02786(input_data):
from functools import lru_cache
import sys import sys
from functools import lru_cache
sys.setrecursionlimit(10**7) sys.setrecursionlimit(10**7)

View file

@ -1,6 +1,5 @@
def problem_p02840(input_data): def problem_p02840(input_data):
from fractions import gcd from fractions import gcd
from itertools import accumulate from itertools import accumulate
n, x, d = list(map(int, input_data.split())) n, x, d = list(map(int, input_data.split()))

View file

@ -1,5 +1,5 @@
def problem_p02900(input_data): def problem_p02900(input_data):
from math import sqrt, ceil from math import ceil, sqrt
a, b = list(map(int, input_data.split())) a, b = list(map(int, input_data.split()))

View file

@ -1,7 +1,11 @@
def problem_p02954(input_data): def problem_p02954(input_data):
import sys, math, itertools, bisect, copy, re import bisect
import copy
from collections import Counter, deque, defaultdict import itertools
import math
import re
import sys
from collections import Counter, defaultdict, deque
# from itertools import accumulate, permutations, combinations, takewhile, compress, cycle # from itertools import accumulate, permutations, combinations, takewhile, compress, cycle

View file

@ -1,5 +1,18 @@
def problem_p02957(input_data): def problem_p02957(input_data):
import math, string, itertools, fractions, heapq, collections, re, array, bisect, sys, random, time, queue, copy import array
import bisect
import collections
import copy
import fractions
import heapq
import itertools
import math
import queue
import random
import re
import string
import sys
import time
sys.setrecursionlimit(10**7) sys.setrecursionlimit(10**7)

View file

@ -1,19 +1,12 @@
def problem_p02965(input_data): def problem_p02965(input_data):
from collections import defaultdict, deque, Counter
from heapq import heappush, heappop, heapify
import math import math
from bisect import bisect_left, bisect_right
import random import random
from itertools import permutations, accumulate, combinations
import sys
import string import string
import sys
from bisect import bisect_left, bisect_right
from collections import Counter, defaultdict, deque
from heapq import heapify, heappop, heappush
from itertools import accumulate, combinations, permutations
INF = float("inf") INF = float("inf")

View file

@ -3,9 +3,8 @@ def problem_p02969(input_data):
input = sys.stdin.readline input = sys.stdin.readline
import math
import collections import collections
import math
def I(): def I():
return int(eval(input_data)) return int(eval(input_data))

View file

@ -1,7 +1,8 @@
def problem_p02993(input_data): def problem_p02993(input_data):
#!/usr/bin/env python3 #!/usr/bin/env python3
import sys, math import math
import sys
input = lambda: sys.stdin.buffer.readline().rstrip().decode("utf-8") input = lambda: sys.stdin.buffer.readline().rstrip().decode("utf-8")

View file

@ -1,27 +1,16 @@
def problem_p03016(input_data): def problem_p03016(input_data):
from collections import defaultdict, deque, Counter
from heapq import heappush, heappop, heapify
import math
import bisect import bisect
import math
import random import random
from itertools import permutations, accumulate, combinations, product
import sys
import string import string
import sys
from bisect import bisect_left, bisect_right from bisect import bisect_left, bisect_right
from collections import Counter, defaultdict, deque
from math import factorial, ceil, floor
from operator import mul
from functools import reduce from functools import reduce
from heapq import heapify, heappop, heappush
from itertools import accumulate, combinations, permutations, product
from math import ceil, factorial, floor
from operator import mul
sys.setrecursionlimit(2147483647) sys.setrecursionlimit(2147483647)

View file

@ -1,7 +1,6 @@
def problem_p03088(input_data): def problem_p03088(input_data):
from itertools import product
from collections import defaultdict from collections import defaultdict
from itertools import product
MOD = 10**9 + 7 MOD = 10**9 + 7

View file

@ -2,11 +2,10 @@ def problem_p03206(input_data):
# encoding:utf-8 # encoding:utf-8
import copy import copy
import random
import numpy as np import numpy as np
import random
d = int(eval(input_data)) d = int(eval(input_data))
christmas = "Christmas" christmas = "Christmas"

View file

@ -1,7 +1,6 @@
def problem_p03213(input_data): def problem_p03213(input_data):
from operator import mul
from functools import reduce from functools import reduce
from operator import mul
nCr = {} nCr = {}

View file

@ -1,7 +1,6 @@
def problem_p03253(input_data): def problem_p03253(input_data):
from math import floor, sqrt
from collections import defaultdict from collections import defaultdict
from math import floor, sqrt
def factors(n): def factors(n):

View file

@ -1,9 +1,8 @@
def problem_p03286(input_data): def problem_p03286(input_data):
# coding: utf-8 # coding: utf-8
import sys
import bisect import bisect
import sys
"""Template""" """Template"""

View file

@ -1,15 +1,10 @@
def problem_p03315(input_data): def problem_p03315(input_data):
import math
import queue
import bisect import bisect
import heapq import heapq
import time
import itertools import itertools
import math
import queue
import time
mod = int(1e9 + 7) mod = int(1e9 + 7)

View file

@ -10,7 +10,6 @@ def problem_p03502(input_data):
import sys import sys
# import itertools # import itertools
import numpy as np import numpy as np
read = sys.stdin.buffer.read read = sys.stdin.buffer.read

View file

@ -3,9 +3,8 @@ def problem_p03632(input_data):
sys.setrecursionlimit(4100000) sys.setrecursionlimit(4100000)
import math
import itertools import itertools
import math
INF = float("inf") INF = float("inf")

View file

@ -1,5 +1,12 @@
def problem_p03666(input_data): def problem_p03666(input_data):
import sys, queue, math, copy, itertools, bisect, collections, heapq import bisect
import collections
import copy
import heapq
import itertools
import math
import queue
import sys
def main(): def main():

View file

@ -1,27 +1,16 @@
def problem_p03797(input_data): def problem_p03797(input_data):
from collections import defaultdict, deque, Counter
from heapq import heappush, heappop, heapify
import math
import bisect import bisect
import math
import random import random
from itertools import permutations, accumulate, combinations, product
import sys
import string import string
import sys
from bisect import bisect_left, bisect_right from bisect import bisect_left, bisect_right
from collections import Counter, defaultdict, deque
from math import factorial, ceil, floor
from operator import mul
from functools import reduce from functools import reduce
from heapq import heapify, heappop, heappush
from itertools import accumulate, combinations, permutations, product
from math import ceil, factorial, floor
from operator import mul
sys.setrecursionlimit(2147483647) sys.setrecursionlimit(2147483647)

View file

@ -1,6 +1,6 @@
def problem_p03999(input_data): def problem_p03999(input_data):
from itertools import combinations, chain
from functools import reduce from functools import reduce
from itertools import chain, combinations
def eval_str(string): def eval_str(string):

View file

@ -1,7 +1,11 @@
def problem_p04019(input_data): def problem_p04019(input_data):
#!/usr/bin/env python3 #!/usr/bin/env python3
import sys, math, itertools, collections, bisect import bisect
import collections
import itertools
import math
import sys
input = lambda: sys.stdin.buffer.readline().rstrip().decode("utf-8") input = lambda: sys.stdin.buffer.readline().rstrip().decode("utf-8")

View file

@ -1,7 +1,6 @@
def problem_p04040(input_data): def problem_p04040(input_data):
from sys import stdin, stdout, stderr, setrecursionlimit
from functools import lru_cache from functools import lru_cache
from sys import setrecursionlimit, stderr, stdin, stdout
setrecursionlimit(10**7) setrecursionlimit(10**7)

View file

@ -1,8 +1,8 @@
import io
import json import json
import os
import subprocess import subprocess
import unittest.mock import unittest.mock
import io from pathlib import Path
def create_files() -> None: def create_files() -> None:
@ -12,10 +12,11 @@ def create_files() -> None:
"original_data/val.jsonl", "original_data/val.jsonl",
"original_data/train.jsonl", "original_data/train.jsonl",
]: ]:
if not os.path.exists(jsonl_file): jsonl_path = Path(jsonl_file)
if not jsonl_path.exists():
print(f"File {jsonl_file} does not exist.") print(f"File {jsonl_file} does not exist.")
continue continue
with open(jsonl_file, "r") as file: with jsonl_path.open("r") as file:
for line in file: for line in file:
test_case = json.loads(line) test_case = json.loads(line)
problem_id = test_case["problem_id"] problem_id = test_case["problem_id"]
@ -26,12 +27,12 @@ def create_files() -> None:
problems.add(problem_id) problems.add(problem_id)
# Create a new Python file for each problem_id # Create a new Python file for each problem_id
file_path = f"../{problem_id}.py" file_path = Path(f"../{problem_id}.py")
if os.path.exists(file_path): if file_path.exists():
print(f"File {file_path} already exists.") print(f"File {file_path} already exists.")
continue continue
with open(file_path, "w") as code_file: with file_path.open("w") as code_file:
# Write the input code into a new function in the file # Write the input code into a new function in the file
code_file.write(f"def problem_{problem_id}(input_data):\n") code_file.write(f"def problem_{problem_id}(input_data):\n")
# Replace input() calls with input_data handling # Replace input() calls with input_data handling
@ -43,42 +44,38 @@ def create_files() -> None:
# Ensure the result is returned instead of printed # Ensure the result is returned instead of printed
try: try:
# Run black to reformat the code file # Run black to reformat the code file
subprocess.run(["black", file_path], check=True) subprocess.run(["black", str(file_path)], check=True)
except subprocess.CalledProcessError: except subprocess.CalledProcessError:
print(f"Failed to format {file_path} with black.") print(f"Failed to format {file_path} with black.")
os.remove(file_path) file_path.unlink()
continue continue
# Create test cases for each problem_id # Create test cases for each problem_id
test_dir = f"../tests/public_test_cases/{problem_id}" test_dir = Path(f"../tests/public_test_cases/{problem_id}")
if not os.path.exists(test_dir): if not test_dir.exists():
print(f"Directory {test_dir} does not exist.") print(f"Directory {test_dir} does not exist.")
continue continue
test_files = sorted( test_files = sorted(
[f for f in os.listdir(test_dir) if f.startswith("input")], [f for f in test_dir.iterdir() if f.name.startswith("input")],
key=lambda x: int(x.split(".")[1]), key=lambda x: int(x.stem.split(".")[1]),
) )
test_code_file_path = f"../tests/test_{problem_id}.py" test_code_file_path = Path(f"../tests/test_{problem_id}.py")
with open(test_code_file_path, "w") as test_code_file: with test_code_file_path.open("w") as test_code_file:
test_code_file.write("\n") test_code_file.write("\n")
test_code_file.write( test_code_file.write(
f"from code_to_optimize.pie_test_set.{problem_id} import problem_{problem_id}\n\n" f"from code_to_optimize.pie_test_set.{problem_id} import problem_{problem_id}\n\n"
) )
for test_file in test_files: for test_file in test_files:
input_num = test_file.split(".")[1] input_num = test_file.stem.split(".")[1]
output_file = f"output.{input_num}.txt" output_file = test_dir / f"output.{input_num}.txt"
with open(f"{test_dir}/{test_file}", "r") as input_f, open( with test_file.open("r") as input_f, output_file.open("r") as output_f:
f"{test_dir}/{output_file}", "r"
) as output_f:
input_content = input_f.read() input_content = input_f.read()
expected_output = output_f.read() expected_output = output_f.read()
if "\n" in input_content.strip() or "\n" in expected_output.strip(): if "\n" in input_content.strip() or "\n" in expected_output.strip():
print( print(f"Multiple lines detected in input or output for {problem_id}, skipping.")
f"Multiple lines detected in input or output for {problem_id}, skipping." file_path.unlink()
) if test_code_file_path.exists():
os.remove(file_path) test_code_file_path.unlink()
if os.path.exists(test_code_file_path):
os.remove(test_code_file_path)
break break
else: else:
input_content = input_content.strip() input_content = input_content.strip()
@ -86,20 +83,18 @@ def create_files() -> None:
test_case_code = generate_test_case_code( test_case_code = generate_test_case_code(
problem_id, input_num, input_content, expected_output problem_id, input_num, input_content, expected_output
) )
with open(test_code_file_path, "a") as test_code_file: with test_code_file_path.open("a") as test_code_file:
test_code_file.write(test_case_code) test_code_file.write(test_case_code)
try: try:
# Run black to reformat the test file # Run black to reformat the test file
subprocess.run(["black", test_code_file_path], check=True) subprocess.run(["black", str(test_code_file_path)], check=True)
except subprocess.CalledProcessError: except subprocess.CalledProcessError:
print(f"Failed to format {test_code_file_path} with black.") print(f"Failed to format {test_code_file_path} with black.")
os.remove(test_code_file_path) test_code_file_path.unlink()
break break
def generate_test_case_code( def generate_test_case_code(problem_id: str, input_num: str, input_content: str, expected_output: str) -> str:
problem_id: str, input_num: str, input_content: str, expected_output: str
) -> str:
return ( return (
f"def test_problem_{problem_id}_{input_num}():\n" f"def test_problem_{problem_id}_{input_num}():\n"
f" actual_output = problem_{problem_id}({input_content!r})\n" f" actual_output = problem_{problem_id}({input_content!r})\n"

View file

@ -1,9 +1,9 @@
import os
import subprocess import subprocess
from pathlib import Path
def run_pie_test_case(script_path, test_input, expected_output): def run_pie_test_case(script_path, test_input, expected_output):
assert os.path.exists(script_path), f"Script file does not exist: {script_path}" assert Path(script_path).exists(), f"Script file does not exist: {script_path}"
process = subprocess.Popen( process = subprocess.Popen(
["python", script_path], ["python", script_path],
stdin=subprocess.PIPE, stdin=subprocess.PIPE,
@ -17,6 +17,4 @@ def run_pie_test_case(script_path, test_input, expected_output):
if stderr: if stderr:
print(f"Error in stderr: {stderr}") print(f"Error in stderr: {stderr}")
assert False, f"Script error: {stderr}" assert False, f"Script error: {stderr}"
assert ( assert stdout.strip() == expected_output, f"Expected '{expected_output}' but got '{stdout.strip()}'"
stdout.strip() == expected_output
), f"Expected '{expected_output}' but got '{stdout.strip()}'"

View file

@ -1,4 +1,3 @@
import os.path
import tempfile import tempfile
from pathlib import Path from pathlib import Path
@ -124,8 +123,8 @@ def functionA():
only_get_this_function="A.functionA", only_get_this_function="A.functionA",
test_cfg=test_config, test_cfg=test_config,
ignore_paths=[Path("/bruh/")], ignore_paths=[Path("/bruh/")],
project_root=Path(os.path.dirname(f.name)), project_root=path_obj_name.parent,
module_root=Path(os.path.dirname(f.name)), module_root=path_obj_name.parent,
) )
assert len(functions) == 1 assert len(functions) == 1
for file in functions: for file in functions:

View file

@ -1,5 +1,6 @@
import os import os
import unittest import unittest
from pathlib import Path
from unittest.mock import mock_open, patch from unittest.mock import mock_open, patch
from returns.result import Failure, Success from returns.result import Failure, Success
@ -49,8 +50,9 @@ class TestReadApiKeyFromShellConfig(unittest.TestCase):
def tearDown(self): def tearDown(self):
"""Cleanup the temporary shell configuration file after testing.""" """Cleanup the temporary shell configuration file after testing."""
if os.path.exists(self.test_rc_path): test_rc_path = Path(self.test_rc_path)
os.remove(self.test_rc_path) if test_rc_path.exists():
test_rc_path.unlink()
del os.environ["SHELL"] # Remove the SHELL environment variable del os.environ["SHELL"] # Remove the SHELL environment variable
def test_valid_api_key(self): def test_valid_api_key(self):

View file

@ -1,6 +1,6 @@
import os import os
import pathlib
import tempfile import tempfile
from pathlib import Path
import pytest import pytest
@ -28,7 +28,7 @@ class TestUnittestRunnerSorter(unittest.TestCase):
gc.enable() gc.enable()
print(f"#####test_sorter__unit_test_0:TestUnittestRunnerSorter.test_sort:sorter:0#####{duration}^^^^^") print(f"#####test_sorter__unit_test_0:TestUnittestRunnerSorter.test_sort:sorter:0#####{duration}^^^^^")
""" """
cur_dir_path = os.path.dirname(os.path.abspath(__file__)) cur_dir_path = Path(__file__).resolve().parent
config = TestConfig( config = TestConfig(
tests_root=cur_dir_path, tests_root=cur_dir_path,
project_root_path=cur_dir_path, project_root_path=cur_dir_path,
@ -37,18 +37,20 @@ class TestUnittestRunnerSorter(unittest.TestCase):
with tempfile.NamedTemporaryFile(prefix="test_xx", suffix=".py", dir=cur_dir_path) as fp: with tempfile.NamedTemporaryFile(prefix="test_xx", suffix=".py", dir=cur_dir_path) as fp:
test_files = TestFiles( test_files = TestFiles(
test_files=[TestFile(instrumented_file_path=str(fp.name), test_type=TestType.EXISTING_UNIT_TEST)], test_files=[
TestFile(instrumented_file_path=Path(fp.name), test_type=TestType.EXISTING_UNIT_TEST)
],
) )
fp.write(code.encode("utf-8")) fp.write(code.encode("utf-8"))
fp.flush() fp.flush()
result_file, process = run_tests( result_file, process = run_tests(
test_files, test_files,
test_framework=config.test_framework, test_framework=config.test_framework,
cwd=config.project_root_path, cwd=Path(config.project_root_path),
) )
results = parse_test_xml(result_file, test_files, config, process) results = parse_test_xml(result_file, test_files, config, process)
assert results[0].did_pass, "Test did not pass as expected" assert results[0].did_pass, "Test did not pass as expected"
pathlib.Path(result_file).unlink(missing_ok=True) result_file.unlink(missing_ok=True)
@pytest.mark.skip(reason="not testing the actual code path") @pytest.mark.skip(reason="not testing the actual code path")
@ -63,7 +65,7 @@ def test_sort():
output = sorter(arr) output = sorter(arr)
assert output == [0, 1, 2, 3, 4, 5] assert output == [0, 1, 2, 3, 4, 5]
""" """
cur_dir_path = os.path.dirname(os.path.abspath(__file__)) cur_dir_path = Path(__file__).resolve().parent
config = TestConfig( config = TestConfig(
tests_root=cur_dir_path, tests_root=cur_dir_path,
project_root_path=cur_dir_path, project_root_path=cur_dir_path,
@ -71,7 +73,9 @@ def test_sort():
) )
with tempfile.NamedTemporaryFile(prefix="test_xx", suffix=".py", dir=cur_dir_path) as fp: with tempfile.NamedTemporaryFile(prefix="test_xx", suffix=".py", dir=cur_dir_path) as fp:
test_files = TestFiles( test_files = TestFiles(
test_files=[TestFile(instrumented_file_path=str(fp.name), test_type=TestType.EXISTING_UNIT_TEST)], test_files=[
TestFile(instrumented_file_path=Path(fp.name), test_type=TestType.EXISTING_UNIT_TEST)
],
) )
fp.write(code.encode("utf-8")) fp.write(code.encode("utf-8"))
fp.flush() fp.flush()
@ -79,7 +83,7 @@ def test_sort():
result_file, process = run_tests( result_file, process = run_tests(
test_files, test_files,
test_framework=config.test_framework, test_framework=config.test_framework,
cwd=os.path.join(cur_dir_path), cwd=Path(config.project_root_path),
test_env=test_env, test_env=test_env,
pytest_timeout=1, pytest_timeout=1,
pytest_min_loops=1, pytest_min_loops=1,
@ -93,4 +97,4 @@ def test_sort():
run_result=process, run_result=process,
) )
assert results[0].did_pass, "Test did not pass as expected" assert results[0].did_pass, "Test did not pass as expected"
pathlib.Path(result_file).unlink(missing_ok=True) result_file.unlink(missing_ok=True)

View file

@ -1,17 +1,17 @@
import os import os
import pathlib
import tempfile import tempfile
from pathlib import Path
from codeflash.discovery.discover_unit_tests import discover_unit_tests from codeflash.discovery.discover_unit_tests import discover_unit_tests
from codeflash.verification.verification_utils import TestConfig from codeflash.verification.verification_utils import TestConfig
def test_unit_test_discovery_pytest(): def test_unit_test_discovery_pytest():
project_path = pathlib.Path(__file__).parent.parent.resolve() / "code_to_optimize" project_path = Path(__file__).parent.parent.resolve() / "code_to_optimize"
tests_path = project_path / "tests" / "pytest" tests_path = project_path / "tests" / "pytest"
test_config = TestConfig( test_config = TestConfig(
tests_root=str(tests_path), tests_root=tests_path,
project_root_path=str(project_path), project_root_path=project_path,
test_framework="pytest", test_framework="pytest",
) )
tests = discover_unit_tests(test_config) tests = discover_unit_tests(test_config)
@ -20,14 +20,14 @@ def test_unit_test_discovery_pytest():
def test_unit_test_discovery_unittest(): def test_unit_test_discovery_unittest():
project_path = pathlib.Path(__file__).parent.parent.resolve() / "code_to_optimize" project_path = Path(__file__).parent.parent.resolve() / "code_to_optimize"
test_path = project_path / "tests" / "unittest" test_path = project_path / "tests" / "unittest"
test_config = TestConfig( test_config = TestConfig(
tests_root=str(project_path), tests_root=project_path,
project_root_path=str(project_path), project_root_path=project_path,
test_framework="unittest", test_framework="unittest",
) )
os.chdir(str(project_path)) os.chdir(project_path)
tests = discover_unit_tests(test_config) tests = discover_unit_tests(test_config)
# assert len(tests) > 0 # assert len(tests) > 0
# Unittest discovery within a pytest environment does not work # Unittest discovery within a pytest environment does not work
@ -36,7 +36,7 @@ def test_unit_test_discovery_unittest():
def test_discover_tests_pytest_with_temp_dir_root(): def test_discover_tests_pytest_with_temp_dir_root():
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
# Create a dummy test file # Create a dummy test file
test_file_path = pathlib.Path(tmpdirname) / "test_dummy.py" test_file_path = Path(tmpdirname) / "test_dummy.py"
test_file_content = ( test_file_content = (
"import pytest\n" "import pytest\n"
"from dummy_code import dummy_function\n\n" "from dummy_code import dummy_function\n\n"
@ -47,16 +47,17 @@ def test_discover_tests_pytest_with_temp_dir_root():
" assert dummy_function() is True\n" " assert dummy_function() is True\n"
) )
test_file_path.write_text(test_file_content) test_file_path.write_text(test_file_content)
path_obj_tempdirname = Path(tmpdirname)
# Create a file that the test file is testing # Create a file that the test file is testing
code_file_path = pathlib.Path(tmpdirname) / "dummy_code.py" code_file_path = path_obj_tempdirname / "dummy_code.py"
code_file_content = "def dummy_function():\n return True\n" code_file_content = "def dummy_function():\n return True\n"
code_file_path.write_text(code_file_content) code_file_path.write_text(code_file_content)
# Create a TestConfig with the temporary directory as the root # Create a TestConfig with the temporary directory as the root
test_config = TestConfig( test_config = TestConfig(
tests_root=str(tmpdirname), tests_root=path_obj_tempdirname,
project_root_path=str(tmpdirname), project_root_path=path_obj_tempdirname,
test_framework="pytest", test_framework="pytest",
) )
@ -79,13 +80,14 @@ def test_discover_tests_pytest_with_temp_dir_root():
def test_discover_tests_pytest_with_multi_level_dirs(): def test_discover_tests_pytest_with_multi_level_dirs():
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
path_obj_tmpdirname = Path(tmpdirname)
# Create multi-level directories # Create multi-level directories
level1_dir = pathlib.Path(tmpdirname) / "level1" level1_dir = path_obj_tmpdirname / "level1"
level2_dir = level1_dir / "level2" level2_dir = level1_dir / "level2"
level2_dir.mkdir(parents=True) level2_dir.mkdir(parents=True)
# Create code files at each level # Create code files at each level
root_code_file_path = pathlib.Path(tmpdirname) / "root_code.py" root_code_file_path = path_obj_tmpdirname / "root_code.py"
root_code_file_content = "def root_function():\n return True\n" root_code_file_content = "def root_function():\n return True\n"
root_code_file_path.write_text(root_code_file_content) root_code_file_path.write_text(root_code_file_content)
@ -98,7 +100,7 @@ def test_discover_tests_pytest_with_multi_level_dirs():
level2_code_file_path.write_text(level2_code_file_content) level2_code_file_path.write_text(level2_code_file_content)
# Create a test file at the root level # Create a test file at the root level
root_test_file_path = pathlib.Path(tmpdirname) / "test_root.py" root_test_file_path = path_obj_tmpdirname / "test_root.py"
root_test_file_content = ( root_test_file_content = (
"from root_code import root_function\n\n" "from root_code import root_function\n\n"
"def test_root_function():\n" "def test_root_function():\n"
@ -129,8 +131,8 @@ def test_discover_tests_pytest_with_multi_level_dirs():
# Create a TestConfig with the temporary directory as the root # Create a TestConfig with the temporary directory as the root
test_config = TestConfig( test_config = TestConfig(
tests_root=str(tmpdirname), tests_root=path_obj_tmpdirname,
project_root_path=str(tmpdirname), project_root_path=path_obj_tmpdirname,
test_framework="pytest", test_framework="pytest",
) )
@ -150,15 +152,16 @@ def test_discover_tests_pytest_with_multi_level_dirs():
def test_discover_tests_pytest_dirs(): def test_discover_tests_pytest_dirs():
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
path_obj_tmpdirname = Path(tmpdirname)
# Create multi-level directories # Create multi-level directories
level1_dir = pathlib.Path(tmpdirname) / "level1" level1_dir = Path(tmpdirname) / "level1"
level2_dir = level1_dir / "level2" level2_dir = level1_dir / "level2"
level2_dir.mkdir(parents=True) level2_dir.mkdir(parents=True)
level3_dir = level1_dir / "level3" level3_dir = level1_dir / "level3"
level3_dir.mkdir(parents=True) level3_dir.mkdir(parents=True)
# Create code files at each level # Create code files at each level
root_code_file_path = pathlib.Path(tmpdirname) / "root_code.py" root_code_file_path = path_obj_tmpdirname / "root_code.py"
root_code_file_content = "def root_function():\n return True\n" root_code_file_content = "def root_function():\n return True\n"
root_code_file_path.write_text(root_code_file_content) root_code_file_path.write_text(root_code_file_content)
@ -175,7 +178,7 @@ def test_discover_tests_pytest_dirs():
level3_code_file_path.write_text(level3_code_file_content) level3_code_file_path.write_text(level3_code_file_content)
# Create a test file at the root level # Create a test file at the root level
root_test_file_path = pathlib.Path(tmpdirname) / "test_root.py" root_test_file_path = path_obj_tmpdirname / "test_root.py"
root_test_file_content = ( root_test_file_content = (
"from root_code import root_function\n\n" "from root_code import root_function\n\n"
"def test_root_function():\n" "def test_root_function():\n"
@ -215,8 +218,8 @@ def test_discover_tests_pytest_dirs():
# Create a TestConfig with the temporary directory as the root # Create a TestConfig with the temporary directory as the root
test_config = TestConfig( test_config = TestConfig(
tests_root=str(tmpdirname), tests_root=path_obj_tmpdirname,
project_root_path=str(tmpdirname), project_root_path=path_obj_tmpdirname,
test_framework="pytest", test_framework="pytest",
) )
@ -239,13 +242,14 @@ def test_discover_tests_pytest_dirs():
def test_discover_tests_pytest_with_class(): def test_discover_tests_pytest_with_class():
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
path_obj_tmpdirname = Path(tmpdirname)
# Create a code file with a class # Create a code file with a class
code_file_path = pathlib.Path(tmpdirname) / "some_class_code.py" code_file_path = path_obj_tmpdirname / "some_class_code.py"
code_file_content = "class SomeClass:\n def some_method(self):\n return True\n" code_file_content = "class SomeClass:\n def some_method(self):\n return True\n"
code_file_path.write_text(code_file_content) code_file_path.write_text(code_file_content)
# Create a test file with a test class and a test method # Create a test file with a test class and a test method
test_file_path = pathlib.Path(tmpdirname) / "test_some_class.py" test_file_path = path_obj_tmpdirname / "test_some_class.py"
test_file_content = ( test_file_content = (
"from some_class_code import SomeClass\n\n" "from some_class_code import SomeClass\n\n"
"def test_some_method():\n" "def test_some_method():\n"
@ -256,8 +260,8 @@ def test_discover_tests_pytest_with_class():
# Create a TestConfig with the temporary directory as the root # Create a TestConfig with the temporary directory as the root
test_config = TestConfig( test_config = TestConfig(
tests_root=str(tmpdirname), tests_root=path_obj_tmpdirname,
project_root_path=str(tmpdirname), project_root_path=path_obj_tmpdirname,
test_framework="pytest", test_framework="pytest",
) )
@ -273,8 +277,9 @@ def test_discover_tests_pytest_with_class():
def test_discover_tests_pytest_with_double_nested_directories(): def test_discover_tests_pytest_with_double_nested_directories():
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
path_obj_tmpdirname = Path(tmpdirname)
# Create nested directories # Create nested directories
nested_dir = pathlib.Path(tmpdirname) / "nested" / "more_nested" nested_dir = path_obj_tmpdirname / "nested" / "more_nested"
nested_dir.mkdir(parents=True) nested_dir.mkdir(parents=True)
# Create a code file with a class in the nested directory # Create a code file with a class in the nested directory
@ -294,8 +299,8 @@ def test_discover_tests_pytest_with_double_nested_directories():
# Create a TestConfig with the temporary directory as the root # Create a TestConfig with the temporary directory as the root
test_config = TestConfig( test_config = TestConfig(
tests_root=str(tmpdirname), tests_root=path_obj_tmpdirname,
project_root_path=str(tmpdirname), project_root_path=path_obj_tmpdirname,
test_framework="pytest", test_framework="pytest",
) )
@ -313,8 +318,9 @@ def test_discover_tests_pytest_with_double_nested_directories():
def test_discover_tests_with_code_in_dir_and_test_in_subdir(): def test_discover_tests_with_code_in_dir_and_test_in_subdir():
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
path_obj_tmpdirname = Path(tmpdirname)
# Create a directory for the code file # Create a directory for the code file
code_dir = pathlib.Path(tmpdirname) / "code" code_dir = path_obj_tmpdirname / "code"
code_dir.mkdir() code_dir.mkdir()
# Create a code file in the code directory # Create a code file in the code directory
@ -340,8 +346,8 @@ def test_discover_tests_with_code_in_dir_and_test_in_subdir():
# Create a TestConfig with the code directory as the root # Create a TestConfig with the code directory as the root
test_config = TestConfig( test_config = TestConfig(
tests_root=str(test_subdir), tests_root=test_subdir,
project_root_path=str(tmpdirname), project_root_path=path_obj_tmpdirname,
test_framework="pytest", test_framework="pytest",
) )
@ -355,8 +361,9 @@ def test_discover_tests_with_code_in_dir_and_test_in_subdir():
def test_discover_tests_pytest_with_nested_class(): def test_discover_tests_pytest_with_nested_class():
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
path_obj_tmpdirname = Path(tmpdirname)
# Create a code file with a nested class # Create a code file with a nested class
code_file_path = pathlib.Path(tmpdirname) / "nested_class_code.py" code_file_path = path_obj_tmpdirname / "nested_class_code.py"
code_file_content = ( code_file_content = (
"class OuterClass:\n" "class OuterClass:\n"
" class InnerClass:\n" " class InnerClass:\n"
@ -366,7 +373,7 @@ def test_discover_tests_pytest_with_nested_class():
code_file_path.write_text(code_file_content) code_file_path.write_text(code_file_content)
# Create a test file with a test for the nested class method # Create a test file with a test for the nested class method
test_file_path = pathlib.Path(tmpdirname) / "test_nested_class.py" test_file_path = path_obj_tmpdirname / "test_nested_class.py"
test_file_content = ( test_file_content = (
"from nested_class_code import OuterClass\n\n" "from nested_class_code import OuterClass\n\n"
"def test_inner_method():\n" "def test_inner_method():\n"
@ -377,8 +384,8 @@ def test_discover_tests_pytest_with_nested_class():
# Create a TestConfig with the temporary directory as the root # Create a TestConfig with the temporary directory as the root
test_config = TestConfig( test_config = TestConfig(
tests_root=str(tmpdirname), tests_root=path_obj_tmpdirname,
project_root_path=str(tmpdirname), project_root_path=path_obj_tmpdirname,
test_framework="pytest", test_framework="pytest",
) )