Merge branch 'main' of github.com:codeflash-ai/codeflash into bootstrapped-benchmarking
Commit: fbab91ab4f
14 changed files with 219 additions and 64 deletions
@@ -3,7 +3,7 @@ Business Source License 1.1
 Parameters

 Licensor: CodeFlash Inc.
-Licensed Work: Codeflash Client version 0.8.x
+Licensed Work: Codeflash Client version 0.9.x
 The Licensed Work is (c) 2024 CodeFlash Inc.

 Additional Use Grant: None. Production use of the Licensed Work is only permitted
@@ -13,7 +13,7 @@ Additional Use Grant: None. Production use of the Licensed Work is only permitted
 Platform. Please visit codeflash.ai for further
 information.

-Change Date: 2028-12-03
+Change Date: 2029-01-06

 Change License: MIT
@@ -12,9 +12,11 @@ from codeflash.cli_cmds.console import logger


 def get_qualified_name(module_name: str, full_qualified_name: str) -> str:
     if not full_qualified_name.startswith(module_name):
-        raise ValueError(f"{full_qualified_name} does not start with {module_name}")
+        msg = f"{full_qualified_name} does not start with {module_name}"
+        raise ValueError(msg)
     if module_name == full_qualified_name:
-        raise ValueError(f"{full_qualified_name} is the same as {module_name}")
+        msg = f"{full_qualified_name} is the same as {module_name}"
+        raise ValueError(msg)
     return full_qualified_name[len(module_name) + 1 :]
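Note: assigning the message to a variable before raising is likely motivated by Ruff's error-message rules (no f-string literal directly inside raise); the behavior of get_qualified_name is unchanged. A minimal sketch of that behavior, mirroring the new tests further down (the import path is taken from the test file in this same commit):

    from codeflash.code_utils.code_utils import get_qualified_name

    get_qualified_name("codeflash", "codeflash.utils.module")   # -> "utils.module"
    get_qualified_name("codeflash", "otherflash.utils.module")  # raises ValueError: ... does not start with codeflash
    get_qualified_name("codeflash", "codeflash")                # raises ValueError: ... is the same as codeflash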
@@ -94,10 +96,15 @@ def is_class_defined_in_file(class_name: str, file_path: Path) -> bool:


 def validate_python_code(code: str) -> str:
-    """Validates a string of python code by attempting to compile it"""
+    """Validate a string of Python code by attempting to compile it."""
     try:
         compile(code, "<string>", "exec")
     except SyntaxError as e:
         msg = f"Invalid Python code: {e.msg} (line {e.lineno}, column {e.offset})"
         raise ValueError(msg) from e
     return code
+
+
+def cleanup_paths(paths: list[Path]) -> None:
+    for path in paths:
+        path.unlink(missing_ok=True)
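Note: cleanup_paths centralizes the unlink loops that the optimizer previously repeated inline. Because Path.unlink(missing_ok=True) ignores absent files, a mix of existing and non-existing paths is safe, which is exactly what the new test_cleanup_paths fixture exercises. A quick sketch:

    import tempfile
    from pathlib import Path

    from codeflash.code_utils.code_utils import cleanup_paths

    tmp = Path(tempfile.mkdtemp())
    existing = tmp / "a.txt"
    existing.touch()
    missing = tmp / "never_created.txt"

    cleanup_paths([existing, missing])  # the missing file raises no error
    assert not existing.exists()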
@@ -7,3 +7,4 @@ MAX_TEST_FUNCTION_RUNS = 50
 MAX_CUMULATIVE_TEST_RUNTIME_NANOSECONDS = 100e6  # 100ms
 N_TESTS_TO_GENERATE = 2
 TOTAL_LOOPING_TIME = 10.0  # 10 second candidate benchmarking budget
+COVERAGE_THRESHOLD = 50.0
@@ -1,4 +1,4 @@
-from typing import Dict, Union
+from typing import Union

 from pydantic import BaseModel
 from pydantic.dataclasses import dataclass
@@ -19,7 +19,7 @@ class PrComment:
     winning_behavioral_test_results: TestResults
     winning_benchmarking_test_results: TestResults

-    def to_json(self) -> Dict[str, Union[str, Dict[str, Dict[str, int]], int, str]]:
+    def to_json(self) -> dict[str, Union[dict[str, dict[str, int]], int, str]]:
         return {
             "optimization_explanation": self.optimization_explanation,
             "best_runtime": humanize_runtime(self.best_runtime),
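Note: the new signature uses builtin generics (PEP 585, dict[...]) instead of typing.Dict, which is valid at runtime on Python 3.9+. A small illustrative sketch, not taken from the repository; the return values are hypothetical stand-ins for two of the real keys shown above:

    from typing import Union

    PrCommentJson = dict[str, Union[dict[str, dict[str, int]], int, str]]  # builtin generics, Python 3.9+

    def to_json_sketch() -> PrCommentJson:
        return {"optimization_explanation": "replaced O(n^2) loop", "best_runtime": "1.2ms"}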
@@ -25,6 +25,7 @@ from codeflash.code_utils import env_utils
 from codeflash.code_utils.code_extractor import add_needed_imports_from_module, extract_code, find_preexisting_objects
 from codeflash.code_utils.code_replacer import normalize_code, normalize_node, replace_function_definitions_in_module
 from codeflash.code_utils.code_utils import (
+    cleanup_paths,
     file_name_from_test_module_name,
     get_run_tmp_file,
     module_name_from_file_path,
@@ -62,7 +63,7 @@ from codeflash.models.models import (
 )
 from codeflash.optimization.function_context import get_constrained_function_context_and_helper_functions
 from codeflash.result.create_pr import check_create_pr, existing_tests_source_for
-from codeflash.result.critic import performance_gain, quantity_of_tests_critic, speedup_critic
+from codeflash.result.critic import coverage_critic, performance_gain, quantity_of_tests_critic, speedup_critic
 from codeflash.result.explanation import Explanation
 from codeflash.telemetry.posthog_cf import ph
 from codeflash.verification.concolic_testing import generate_concolic_tests
@@ -335,23 +336,28 @@ class Optimizer:
             function_to_optimize=function_to_optimize, function_to_tests=function_to_all_tests
         )

-        baseline_result = self.establish_original_code_baseline(
+        baseline_result = self.establish_original_code_baseline(  # this needs better typing
             function_name=function_to_optimize_qualified_name,
             function_file_path=function_to_optimize.file_path,
             code_context=code_context,
         )

         console.rule()
-        if not is_successful(baseline_result):
-            for generated_test_path in generated_test_paths:
-                generated_test_path.unlink(missing_ok=True)
-            for generated_perf_test_path in generated_perf_test_paths:
-                generated_perf_test_path.unlink(missing_ok=True)
+        paths_to_cleanup = (
+            generated_test_paths + generated_perf_test_paths + list(instrumented_unittests_created_for_function)
+        )

-            for instrumented_path in instrumented_unittests_created_for_function:
-                instrumented_path.unlink(missing_ok=True)
+        if not is_successful(baseline_result):
+            cleanup_paths(paths_to_cleanup)
             return Failure(baseline_result.failure())

         original_code_baseline, test_functions_to_remove = baseline_result.unwrap()
+        if isinstance(original_code_baseline, OriginalCodeBaseline) and not coverage_critic(
+            original_code_baseline.coverage_results, self.args.test_framework
+        ):
+            cleanup_paths(paths_to_cleanup)
+            return Failure("The threshold for test coverage was not met.")
+
         best_optimization = None

         for u, candidates in enumerate([optimizations_set.control, optimizations_set.experiment]):
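Note: the Failure/is_successful/unwrap calls above appear to come from the dry-python "returns" result container, so the baseline step either unwraps into a value or short-circuits with a Failure after the generated test files are cleaned up. A minimal, hypothetical sketch of that flow (names invented for illustration):

    from returns.pipeline import is_successful
    from returns.result import Failure, Result, Success

    def establish_baseline(ok: bool) -> Result[int, str]:
        # Stand-in for establish_original_code_baseline: either a measured runtime or an error.
        return Success(42) if ok else Failure("baseline tests did not pass")

    baseline = establish_baseline(ok=False)
    if not is_successful(baseline):
        # cleanup_paths(paths_to_cleanup) would run here before bailing out
        print(baseline.failure())   # "baseline tests did not pass"
    else:
        runtime = baseline.unwrap()  # 42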
@@ -1,8 +1,9 @@
 from __future__ import annotations

+from codeflash.cli_cmds.console import logger
 from codeflash.code_utils import env_utils
-from codeflash.code_utils.config_consts import MIN_IMPROVEMENT_THRESHOLD
-from codeflash.models.models import OptimizedCandidateResult
+from codeflash.code_utils.config_consts import COVERAGE_THRESHOLD, MIN_IMPROVEMENT_THRESHOLD
+from codeflash.models.models import CoverageData, OptimizedCandidateResult
 from codeflash.verification.test_results import TestType
@@ -53,7 +54,14 @@ def quantity_of_tests_critic(candidate_result: OptimizedCandidateResult) -> bool:
     elif pass_count >= 2:
         return True
     # If only one test passed, check if it's a REPLAY_TEST
-    if pass_count == 1 and report[TestType.REPLAY_TEST]["passed"] == 1:
-        return True
+    return bool(pass_count == 1 and report[TestType.REPLAY_TEST]["passed"] == 1)
+
+
+def coverage_critic(original_code_coverage: CoverageData | None, test_framework: str) -> bool:
+    """Check if the coverage meets the threshold."""
+    if test_framework == "unittest":
+        logger.debug("Coverage critic is not implemented for unittest yet.")
+        return True
+    if original_code_coverage:
+        return original_code_coverage.coverage >= COVERAGE_THRESHOLD
+    return False
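Note: the new coverage gate is easy to reason about: pytest runs must reach COVERAGE_THRESHOLD (50.0, added to config_consts above), missing coverage data fails, and unittest runs are waved through until coverage collection is implemented for them. A standalone sketch mirroring the decision table asserted by the new test_coverage_critic:

    COVERAGE_THRESHOLD = 50.0  # value taken from codeflash.code_utils.config_consts

    def coverage_gate(coverage_percent: float | None, test_framework: str) -> bool:
        # Illustration of coverage_critic's decisions, not the library source.
        if test_framework == "unittest":
            return True                      # not implemented for unittest yet, never blocks
        if coverage_percent is None:
            return False                     # no parsed coverage data -> reject
        return coverage_percent >= COVERAGE_THRESHOLD

    assert coverage_gate(100.0, "pytest") is True
    assert coverage_gate(50.0, "pytest") is True   # threshold is inclusive
    assert coverage_gate(30.0, "pytest") is False
    assert coverage_gate(None, "pytest") is False
    assert coverage_gate(0.0, "unittest") is True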
@@ -2,6 +2,7 @@ import datetime
 import decimal
 import enum
 import math
+import re
 import types
 from typing import Any
@@ -168,7 +169,10 @@ def comparator(orig: Any, new: Any) -> bool:
     ):
         return orig == new

-    if isinstance(orig, (datetime.datetime, datetime.date, datetime.timedelta, datetime.time, datetime.timezone)):
+    # re.Pattern can be made better by DFA Minimization and then comparing
+    if isinstance(
+        orig, (datetime.datetime, datetime.date, datetime.timedelta, datetime.time, datetime.timezone, re.Pattern)
+    ):
         return orig == new

     # If the object passed has a user defined __eq__ method, use that
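Note: re.Pattern defines no custom __eq__, so the == fallback here effectively compares identity; in CPython, re.compile caches compiled patterns, so the same pattern/flags pair yields the same object while any difference does not. The assertions added to test_comparator.py rely on exactly this:

    import re

    a = re.compile("a")
    b = re.compile("a")                  # same pattern and flags -> same cached object
    c = re.compile("b")                  # different pattern
    d = re.compile("a", re.IGNORECASE)   # same pattern, different flags

    assert a == b
    assert a != c
    assert a != d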
@@ -73,10 +73,6 @@ def generate_concolic_tests(
         ph("cli-optimize-concolic-tests", {"num_tests": num_discovered_concolic_tests})

     else:
-        (
-            logger.warning(
-                "Error running CrossHair Cover" f"{': ' + cover_result.stderr if cover_result.stderr else '.'}"
-            )
-        )
+        logger.debug(f"Error running CrossHair Cover {': ' + cover_result.stderr if cover_result.stderr else '.'}")
     console.rule()
     return function_to_concolic_tests, concolic_test_suite_code
@@ -478,7 +478,7 @@ def parse_test_results(
     results = merge_test_results(test_results_xml, test_results_bin_file, test_config.test_framework)

     all_args = False
-    if coverage_file and source_file and code_context and function_name:
+    if coverage_file and coverage_file.exists() and source_file and code_context and function_name:
         all_args = True
         coverage = CoverageData.load_from_coverage_file(
             coverage_file_path=coverage_file,
@@ -1,3 +1,3 @@
 # These version placeholders will be replaced by poetry-dynamic-versioning during `poetry build`.
-__version__ = "0.8.4"
-__version_tuple__ = (0, 8, 4)
+__version__ = "0.9.0"
+__version_tuple__ = (0, 9, 0)
@@ -7,7 +7,7 @@ license = "BSL-1.1"
 authors = ["CodeFlash Inc. <contact@codeflash.ai>"]
 homepage = "https://codeflash.ai"
 readme = "README.md"
-packages = [{ include = "codeflash" }]
+packages = [{ include = "codeflash", format = ["sdist"] }]
 keywords = [
     "codeflash",
     "performance",
@@ -174,6 +174,7 @@ ignore = [
     "TD002",
     "TD003",
     "TD004",
+    "PLR2004"
 ]

 [tool.ruff.lint.flake8-type-checking]
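Note: PLR2004 is Ruff's "magic value used in comparison" rule; ignoring it project-wide keeps literal thresholds like the ones in critic.py acceptable without a named constant for every comparison. A hypothetical snippet of the kind of code the rule would otherwise flag:

    def enough_tests_passed(pass_count: int) -> bool:
        # Without the ignore, Ruff flags the literal 2 here as a "magic value" (PLR2004).
        return pass_count >= 2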
@@ -5,17 +5,48 @@ from pathlib import Path
 import pytest

 from codeflash.code_utils.code_utils import (
+    cleanup_paths,
+    file_name_from_test_module_name,
     file_path_from_module_name,
     get_all_function_names,
     get_imports_from_file,
+    get_qualified_name,
     get_run_tmp_file,
     is_class_defined_in_file,
     module_name_from_file_path,
     path_belongs_to_site_packages,
-    file_name_from_test_module_name,
 )


+@pytest.fixture
+def multiple_existing_and_non_existing_files(tmp_path: Path) -> list[Path]:
+    existing_files = [tmp_path / f"existing_file{i}.txt" for i in range(3)]
+    non_existing_files = [tmp_path / f"non_existing_file{i}.txt" for i in range(2)]
+    for file in existing_files:
+        file.touch()
+    return existing_files + non_existing_files
+
+
+def test_get_qualified_name_valid() -> None:
+    module_name = "codeflash"
+    full_qualified_name = "codeflash.utils.module"
+
+    result = get_qualified_name(module_name, full_qualified_name)
+    assert result == "utils.module"
+
+
+def test_get_qualified_name_invalid_prefix() -> None:
+    module_name = "codeflash"
+    full_qualified_name = "otherflash.utils.module"
+    with pytest.raises(ValueError, match="does not start with codeflash"):
+        get_qualified_name(module_name, full_qualified_name)
+
+
+def test_get_qualified_name_same_name() -> None:
+    module_name = "codeflash"
+    full_qualified_name = "codeflash"
+    with pytest.raises(ValueError, match="is the same as codeflash"):
+        get_qualified_name(module_name, full_qualified_name)
+
+
 # tests for module_name_from_file_path
 def test_module_name_from_file_path() -> None:
     project_root_path = Path("/Users/codeflashuser/PycharmProjects/codeflash")
@@ -49,7 +80,6 @@ def test_module_name_from_file_path_with_root_as_file() -> None:
     assert module_name == "code_utils"

-
 # tests for get_imports_from_file
 def test_get_imports_from_file_with_file_path(tmp_path: Path) -> None:
     test_file = tmp_path / "test_file.py"
     test_file.write_text("import os\nfrom sys import path\n")
@@ -61,8 +91,6 @@ def test_get_imports_from_file_with_file_path(tmp_path: Path) -> None:
     assert imports[0].names[0].name == "os"
     assert imports[1].module == "sys"
     assert imports[1].names[0].name == "path"

-
-
 def test_get_imports_from_file_with_file_string() -> None:
     file_string = "import os\nfrom sys import path\n"
@@ -74,7 +102,6 @@ def test_get_imports_from_file_with_file_string() -> None:
     assert imports[1].module == "sys"
     assert imports[1].names[0].name == "path"

-
 def test_get_imports_from_file_with_file_ast() -> None:
     file_string = "import os\nfrom sys import path\n"
     file_ast = ast.parse(file_string)
@@ -87,8 +114,7 @@ def test_get_imports_from_file_with_file_ast() -> None:
     assert imports[1].module == "sys"
     assert imports[1].names[0].name == "path"

-
-def test_get_imports_from_file_with_syntax_error(caplog) -> None:
+def test_get_imports_from_file_with_syntax_error(caplog: pytest.LogCaptureFixture) -> None:
     file_string = "import os\nfrom sys import path\ninvalid syntax"

     imports = get_imports_from_file(file_string=file_string)
@@ -147,8 +173,7 @@ async def bar():
     assert success is True
     assert function_names == ["foo", "bar"]

-
-def test_get_all_function_names_with_syntax_error(caplog) -> None:
+def test_get_all_function_names_with_syntax_error(caplog: pytest.LogCaptureFixture) -> None:
     code = """
 def foo():
     pass
@@ -209,25 +234,21 @@ def test_get_run_tmp_file_reuses_temp_directory() -> None:
     assert tmp_file_path1.parent.name.startswith("codeflash_")
     assert tmp_file_path1.parent.exists()

-
 # tests for path_belongs_to_site_packages
-def test_path_belongs_to_site_packages_with_site_package_path(monkeypatch) -> None:
+def test_path_belongs_to_site_packages_with_site_package_path(monkeypatch: pytest.MonkeyPatch) -> None:
     site_packages = [Path("/usr/local/lib/python3.9/site-packages")]
     monkeypatch.setattr(site, "getsitepackages", lambda: site_packages)

     file_path = Path("/usr/local/lib/python3.9/site-packages/some_package")
     assert path_belongs_to_site_packages(file_path) is True

-
-def test_path_belongs_to_site_packages_with_non_site_package_path(monkeypatch) -> None:
+def test_path_belongs_to_site_packages_with_non_site_package_path(monkeypatch: pytest.MonkeyPatch) -> None:
     site_packages = [Path("/usr/local/lib/python3.9/site-packages")]
     monkeypatch.setattr(site, "getsitepackages", lambda: site_packages)

     file_path = Path("/usr/local/lib/python3.9/other_directory/some_package")
     assert path_belongs_to_site_packages(file_path) is False

-
-def test_path_belongs_to_site_packages_with_relative_path(monkeypatch) -> None:
+def test_path_belongs_to_site_packages_with_relative_path(monkeypatch: pytest.MonkeyPatch) -> None:
     site_packages = [Path("/usr/local/lib/python3.9/site-packages")]
     monkeypatch.setattr(site, "getsitepackages", lambda: site_packages)

@@ -273,7 +294,7 @@ def test_is_class_defined_in_file_with_non_existing_file() -> None:


 @pytest.fixture
-def base_dir(tmp_path):
+def base_dir(tmp_path: Path) -> Path:
     base_dir = tmp_path / "project"
     base_dir.mkdir(parents=True, exist_ok=True)
     (base_dir / "test_module.py").touch()
@@ -282,26 +303,32 @@ def base_dir(tmp_path):
     return base_dir


-def test_existing_module(base_dir):
+def test_existing_module(base_dir: Path) -> None:
     result = file_name_from_test_module_name("test_module", base_dir)
     assert result == base_dir / "test_module.py"


-def test_existing_submodule(base_dir):
+def test_existing_submodule(base_dir: Path) -> None:
     result = file_name_from_test_module_name("subdir.test_submodule", base_dir)
     assert result == base_dir / "subdir" / "test_submodule.py"


-def test_non_existing_module(base_dir):
+def test_non_existing_module(base_dir: Path) -> None:
     result = file_name_from_test_module_name("non_existing_module", base_dir)
     assert result is None


-def test_partial_module_name(base_dir):
+def test_partial_module_name(base_dir: Path) -> None:
     result = file_name_from_test_module_name("subdir.test_submodule.TestClass", base_dir)
     assert result == base_dir / "subdir" / "test_submodule.py"


-def test_partial_module_name2(base_dir):
+def test_partial_module_name2(base_dir: Path) -> None:
     result = file_name_from_test_module_name("subdir.test_submodule.TestClass.TestClass2", base_dir)
     assert result == base_dir / "subdir" / "test_submodule.py"
+
+
+def test_cleanup_paths(multiple_existing_and_non_existing_files: list[Path]) -> None:
+    cleanup_paths(multiple_existing_and_non_existing_files)
+    for file in multiple_existing_and_non_existing_files:
+        assert not file.exists()
@@ -1,6 +1,7 @@
 import dataclasses
 import datetime
 import decimal
+import re
 from enum import Enum, Flag, IntFlag, auto

 import pydantic
@@ -189,6 +190,14 @@ def test_standard_python_library_objects():
     assert comparator(a, b)
     assert not comparator(a, c)

+    a: re.Pattern = re.compile("a")
+    b: re.Pattern = re.compile("a")
+    c: re.Pattern = re.compile("b")
+    d: re.Pattern = re.compile("a", re.IGNORECASE)
+    assert comparator(a, b)
+    assert not comparator(a, c)
+    assert not comparator(a, d)
+

 def test_numpy():
     try:
@@ -218,7 +227,6 @@ def test_numpy():
         j = np.float64(1.0)
         k = np.float64(1.0)
         assert not comparator(h, j)
-        print(comparator(j, k))
         assert comparator(j, k)

         l = np.int32(1)
@@ -1,12 +1,20 @@
 import os
+from pathlib import Path
+from unittest.mock import Mock

 from codeflash.code_utils.env_utils import get_pr_number
-from codeflash.models.models import OptimizedCandidateResult
-from codeflash.result.critic import performance_gain, quantity_of_tests_critic, speedup_critic
+from codeflash.models.models import (
+    CodeOptimizationContext,
+    CoverageData,
+    CoverageStatus,
+    FunctionCoverage,
+    OptimizedCandidateResult,
+)
+from codeflash.result.critic import coverage_critic, performance_gain, quantity_of_tests_critic, speedup_critic
 from codeflash.verification.test_results import FunctionTestInvocation, InvocationId, TestResults, TestType


-def test_performance_gain():
+def test_performance_gain() -> None:
     assert performance_gain(original_runtime_ns=1000, optimized_runtime_ns=0) == 0.0

     assert performance_gain(original_runtime_ns=1000, optimized_runtime_ns=500) == 1.0
@@ -18,7 +26,7 @@ def test_performance_gain():
     assert performance_gain(original_runtime_ns=1000, optimized_runtime_ns=1100) == -0.09090909090909091


-def test_speedup_critic():
+def test_speedup_critic() -> None:
     original_code_runtime = 1000
     best_runtime_until_now = 1000
     candidate_result = OptimizedCandidateResult(
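Note: the assertions in the two hunks above pin down the metric: performance_gain is the relative speedup (original_runtime_ns - optimized_runtime_ns) / optimized_runtime_ns, with a guard returning 0.0 when the optimized runtime is zero. A worked check against the asserted values, using a hypothetical re-implementation:

    def relative_gain(original_runtime_ns: int, optimized_runtime_ns: int) -> float:
        # Illustration of the formula implied by the tests, not the library source.
        if optimized_runtime_ns == 0:
            return 0.0
        return (original_runtime_ns - optimized_runtime_ns) / optimized_runtime_ns

    assert relative_gain(1000, 0) == 0.0
    assert relative_gain(1000, 500) == 1.0                    # 2x faster -> +100%
    assert relative_gain(1000, 1100) == -0.09090909090909091  # ~9% slower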
@@ -58,7 +66,7 @@ def test_speedup_critic():
     assert speedup_critic(candidate_result, original_code_runtime, best_runtime_until_now)  # 6% improvement


-def test_generated_test_critic():
+def test_generated_test_critic() -> None:
     test_1 = FunctionTestInvocation(
         id=InvocationId(
             test_module_path="",
@@ -67,7 +75,7 @@ def test_generated_test_critic():
             function_getting_tested="sorter",
             iteration_id="",
         ),
-        file_name="test_1",
+        file_name=Path("test_1"),
         did_pass=True,
         runtime=0,
         test_framework="pytest",
@@ -85,7 +93,7 @@ def test_generated_test_critic():
             function_getting_tested="sorter",
             iteration_id="",
         ),
-        file_name="test_2",
+        file_name=Path("test_2"),
         did_pass=True,
         runtime=0,
         test_framework="pytest",
@@ -103,7 +111,7 @@ def test_generated_test_critic():
             function_getting_tested="sorter",
             iteration_id="",
         ),
-        file_name="test_3",
+        file_name=Path("test_3"),
         did_pass=True,
         runtime=0,
         test_framework="pytest",
@@ -121,7 +129,7 @@ def test_generated_test_critic():
             function_getting_tested="sorter",
             iteration_id="",
         ),
-        file_name="test_4",
+        file_name=Path("test_4"),
         did_pass=False,
         runtime=0,
         test_framework="pytest",
@@ -139,7 +147,7 @@ def test_generated_test_critic():
             function_getting_tested="sorter",
             iteration_id="",
         ),
-        file_name="test_5",
+        file_name=Path("test_5"),
         did_pass=True,
         runtime=0,
         test_framework="pytest",
@@ -157,7 +165,7 @@ def test_generated_test_critic():
             function_getting_tested="sorter",
             iteration_id="",
         ),
-        file_name="test_6",
+        file_name=Path("test_6"),
         did_pass=True,
         runtime=0,
         test_framework="pytest",
@@ -313,3 +321,92 @@ def test_generated_test_critic():
     assert quantity_of_tests_critic(candidate_result)

     del os.environ["CODEFLASH_PR_NUMBER"]
+
+
+def test_coverage_critic() -> None:
+    mock_code_context = Mock(spec=CodeOptimizationContext)
+
+    passing_coverage = CoverageData(
+        file_path=Path("test_file.py"),
+        coverage=100.0,
+        function_name="test_function",
+        functions_being_tested=["function1", "function2"],
+        graph={},
+        code_context=mock_code_context,
+        main_func_coverage=FunctionCoverage(
+            name="test_function",
+            coverage=100.0,
+            executed_lines=[10],
+            unexecuted_lines=[2],
+            executed_branches=[[5]],
+            unexecuted_branches=[[1]]
+        ),
+        dependent_func_coverage=None,
+        status=CoverageStatus.PARSED_SUCCESSFULLY
+    )
+
+    assert coverage_critic(passing_coverage, "pytest") is True
+
+    border_coverage = CoverageData(
+        file_path=Path("test_file.py"),
+        coverage=50.0,
+        function_name="test_function",
+        functions_being_tested=["function1", "function2"],
+        graph={},
+        code_context=mock_code_context,
+        main_func_coverage=FunctionCoverage(
+            name="test_function",
+            coverage=50.0,
+            executed_lines=[10],
+            unexecuted_lines=[2],
+            executed_branches=[[5]],
+            unexecuted_branches=[[1]]
+        ),
+        dependent_func_coverage=None,
+        status=CoverageStatus.PARSED_SUCCESSFULLY
+    )
+
+    assert coverage_critic(border_coverage, "pytest") is True
+
+    failing_coverage = CoverageData(
+        file_path=Path("test_file.py"),
+        coverage=30.0,
+        function_name="test_function",
+        functions_being_tested=["function1", "function2"],
+        graph={},
+        code_context=mock_code_context,
+        main_func_coverage=FunctionCoverage(
+            name="test_function",
+            coverage=0.0,
+            executed_lines=[],
+            unexecuted_lines=[10],
+            executed_branches=[],
+            unexecuted_branches=[[5]]
+        ),
+        dependent_func_coverage=None,
+        status=CoverageStatus.PARSED_SUCCESSFULLY
+    )
+
+    assert coverage_critic(failing_coverage, "pytest") is False
+
+    unittest_coverage = CoverageData(
+        file_path=Path("test_file.py"),
+        coverage=0,
+        function_name="test_function",
+        functions_being_tested=["function1", "function2"],
+        graph={},
+        code_context=mock_code_context,
+        main_func_coverage=FunctionCoverage(
+            name="test_function",
+            coverage=0,
+            executed_lines=[10],
+            unexecuted_lines=[2],
+            executed_branches=[[5]],
+            unexecuted_branches=[[1]]
+        ),
+        dependent_func_coverage=None,
+        status=CoverageStatus.PARSED_SUCCESSFULLY
+    )
+
+    assert coverage_critic(unittest_coverage, "unittest") is True