Merge branch 'main' of github.com:codeflash-ai/codeflash into bootstrapped-benchmarking

This commit is contained in:
RD 2025-01-08 17:48:29 -08:00
commit fbab91ab4f
14 changed files with 219 additions and 64 deletions

View file

@ -3,7 +3,7 @@ Business Source License 1.1
Parameters
Licensor: CodeFlash Inc.
Licensed Work: Codeflash Client version 0.8.x
Licensed Work: Codeflash Client version 0.9.x
The Licensed Work is (c) 2024 CodeFlash Inc.
Additional Use Grant: None. Production use of the Licensed Work is only permitted
@ -13,7 +13,7 @@ Additional Use Grant: None. Production use of the Licensed Work is only permitte
Platform. Please visit codeflash.ai for further
information.
Change Date: 2028-12-03
Change Date: 2029-01-06
Change License: MIT

View file

@ -12,9 +12,11 @@ from codeflash.cli_cmds.console import logger
def get_qualified_name(module_name: str, full_qualified_name: str) -> str:
if not full_qualified_name.startswith(module_name):
raise ValueError(f"{full_qualified_name} does not start with {module_name}")
msg = f"{full_qualified_name} does not start with {module_name}"
raise ValueError(msg)
if module_name == full_qualified_name:
raise ValueError(f"{full_qualified_name} is the same as {module_name}")
msg = f"{full_qualified_name} is the same as {module_name}"
raise ValueError(msg)
return full_qualified_name[len(module_name) + 1 :]
@ -94,10 +96,15 @@ def is_class_defined_in_file(class_name: str, file_path: Path) -> bool:
def validate_python_code(code: str) -> str:
"""Validates a string of python code by attempting to compile it"""
"""Validate a string of Python code by attempting to compile it."""
try:
compile(code, "<string>", "exec")
except SyntaxError as e:
msg = f"Invalid Python code: {e.msg} (line {e.lineno}, column {e.offset})"
raise ValueError(msg) from e
return code
def cleanup_paths(paths: list[Path]) -> None:
    """Delete every file in *paths*; files that are already gone are skipped silently."""
    for candidate in paths:
        # missing_ok=True makes the removal idempotent — no error if absent.
        candidate.unlink(missing_ok=True)

View file

@ -7,3 +7,4 @@ MAX_TEST_FUNCTION_RUNS = 50
MAX_CUMULATIVE_TEST_RUNTIME_NANOSECONDS = 100e6 # 100ms
N_TESTS_TO_GENERATE = 2
TOTAL_LOOPING_TIME = 10.0 # 10 second candidate benchmarking budget
COVERAGE_THRESHOLD = 50.0

View file

@ -1,4 +1,4 @@
from typing import Dict, Union
from typing import Union
from pydantic import BaseModel
from pydantic.dataclasses import dataclass
@ -19,7 +19,7 @@ class PrComment:
winning_behavioral_test_results: TestResults
winning_benchmarking_test_results: TestResults
def to_json(self) -> Dict[str, Union[str, Dict[str, Dict[str, int]], int, str]]:
def to_json(self) -> dict[str, Union[dict[str, dict[str, int]], int, str]]:
return {
"optimization_explanation": self.optimization_explanation,
"best_runtime": humanize_runtime(self.best_runtime),

View file

@ -25,6 +25,7 @@ from codeflash.code_utils import env_utils
from codeflash.code_utils.code_extractor import add_needed_imports_from_module, extract_code, find_preexisting_objects
from codeflash.code_utils.code_replacer import normalize_code, normalize_node, replace_function_definitions_in_module
from codeflash.code_utils.code_utils import (
cleanup_paths,
file_name_from_test_module_name,
get_run_tmp_file,
module_name_from_file_path,
@ -62,7 +63,7 @@ from codeflash.models.models import (
)
from codeflash.optimization.function_context import get_constrained_function_context_and_helper_functions
from codeflash.result.create_pr import check_create_pr, existing_tests_source_for
from codeflash.result.critic import performance_gain, quantity_of_tests_critic, speedup_critic
from codeflash.result.critic import coverage_critic, performance_gain, quantity_of_tests_critic, speedup_critic
from codeflash.result.explanation import Explanation
from codeflash.telemetry.posthog_cf import ph
from codeflash.verification.concolic_testing import generate_concolic_tests
@ -335,23 +336,28 @@ class Optimizer:
function_to_optimize=function_to_optimize, function_to_tests=function_to_all_tests
)
baseline_result = self.establish_original_code_baseline(
baseline_result = self.establish_original_code_baseline( # this needs better typing
function_name=function_to_optimize_qualified_name,
function_file_path=function_to_optimize.file_path,
code_context=code_context,
)
console.rule()
if not is_successful(baseline_result):
for generated_test_path in generated_test_paths:
generated_test_path.unlink(missing_ok=True)
for generated_perf_test_path in generated_perf_test_paths:
generated_perf_test_path.unlink(missing_ok=True)
paths_to_cleanup = (
generated_test_paths + generated_perf_test_paths + list(instrumented_unittests_created_for_function)
)
for instrumented_path in instrumented_unittests_created_for_function:
instrumented_path.unlink(missing_ok=True)
if not is_successful(baseline_result):
cleanup_paths(paths_to_cleanup)
return Failure(baseline_result.failure())
original_code_baseline, test_functions_to_remove = baseline_result.unwrap()
if isinstance(original_code_baseline, OriginalCodeBaseline) and not coverage_critic(
original_code_baseline.coverage_results, self.args.test_framework
):
cleanup_paths(paths_to_cleanup)
return Failure("The threshold for test coverage was not met.")
best_optimization = None
for u, candidates in enumerate([optimizations_set.control, optimizations_set.experiment]):

View file

@ -1,8 +1,9 @@
from __future__ import annotations
from codeflash.cli_cmds.console import logger
from codeflash.code_utils import env_utils
from codeflash.code_utils.config_consts import MIN_IMPROVEMENT_THRESHOLD
from codeflash.models.models import OptimizedCandidateResult
from codeflash.code_utils.config_consts import COVERAGE_THRESHOLD, MIN_IMPROVEMENT_THRESHOLD
from codeflash.models.models import CoverageData, OptimizedCandidateResult
from codeflash.verification.test_results import TestType
@ -53,7 +54,14 @@ def quantity_of_tests_critic(candidate_result: OptimizedCandidateResult) -> bool
elif pass_count >= 2:
return True
# If only one test passed, check if it's a REPLAY_TEST
if pass_count == 1 and report[TestType.REPLAY_TEST]["passed"] == 1:
return True
return bool(pass_count == 1 and report[TestType.REPLAY_TEST]["passed"] == 1)
def coverage_critic(original_code_coverage: CoverageData | None, test_framework: str) -> bool:
    """Decide whether the measured coverage of the original code is acceptable.

    unittest runs always pass (coverage collection is not wired up for that
    framework yet); pytest runs pass only when coverage data is present and
    its percentage meets COVERAGE_THRESHOLD.
    """
    if test_framework == "unittest":
        # Not implemented for unittest, so never block an optimization on it.
        logger.debug("Coverage critic is not implemented for unittest yet.")
        return True
    # Missing/empty coverage data fails the critic; otherwise compare to threshold.
    return bool(original_code_coverage) and original_code_coverage.coverage >= COVERAGE_THRESHOLD

View file

@ -2,6 +2,7 @@ import datetime
import decimal
import enum
import math
import re
import types
from typing import Any
@ -168,7 +169,10 @@ def comparator(orig: Any, new: Any) -> bool:
):
return orig == new
if isinstance(orig, (datetime.datetime, datetime.date, datetime.timedelta, datetime.time, datetime.timezone)):
# re.Pattern can be made better by DFA Minimization and then comparing
if isinstance(
orig, (datetime.datetime, datetime.date, datetime.timedelta, datetime.time, datetime.timezone, re.Pattern)
):
return orig == new
# If the object passed has a user defined __eq__ method, use that

View file

@ -73,10 +73,6 @@ def generate_concolic_tests(
ph("cli-optimize-concolic-tests", {"num_tests": num_discovered_concolic_tests})
else:
(
logger.warning(
"Error running CrossHair Cover" f"{': ' + cover_result.stderr if cover_result.stderr else '.'}"
)
)
logger.debug(f"Error running CrossHair Cover {': ' + cover_result.stderr if cover_result.stderr else '.'}")
console.rule()
return function_to_concolic_tests, concolic_test_suite_code

View file

@ -478,7 +478,7 @@ def parse_test_results(
results = merge_test_results(test_results_xml, test_results_bin_file, test_config.test_framework)
all_args = False
if coverage_file and source_file and code_context and function_name:
if coverage_file and coverage_file.exists() and source_file and code_context and function_name:
all_args = True
coverage = CoverageData.load_from_coverage_file(
coverage_file_path=coverage_file,

View file

@ -1,3 +1,3 @@
# These version placeholders will be replaced by poetry-dynamic-versioning during `poetry build`.
__version__ = "0.8.4"
__version_tuple__ = (0, 8, 4)
__version__ = "0.9.0"
__version_tuple__ = (0, 9, 0)

View file

@ -7,7 +7,7 @@ license = "BSL-1.1"
authors = ["CodeFlash Inc. <contact@codeflash.ai>"]
homepage = "https://codeflash.ai"
readme = "README.md"
packages = [{ include = "codeflash" }]
packages = [{ include = "codeflash", format = ["sdist"] }]
keywords = [
"codeflash",
"performance",
@ -174,6 +174,7 @@ ignore = [
"TD002",
"TD003",
"TD004",
"PLR2004"
]
[tool.ruff.lint.flake8-type-checking]

View file

@ -5,17 +5,48 @@ from pathlib import Path
import pytest
from codeflash.code_utils.code_utils import (
cleanup_paths,
file_name_from_test_module_name,
file_path_from_module_name,
get_all_function_names,
get_imports_from_file,
get_qualified_name,
get_run_tmp_file,
is_class_defined_in_file,
module_name_from_file_path,
path_belongs_to_site_packages,
file_name_from_test_module_name,
)
@pytest.fixture
def multiple_existing_and_non_existing_files(tmp_path: Path) -> list[Path]:
    """Return three files that exist on disk followed by two that do not."""
    present: list[Path] = []
    for i in range(3):
        path = tmp_path / f"existing_file{i}.txt"
        path.touch()
        present.append(path)
    absent = [tmp_path / f"non_existing_file{i}.txt" for i in range(2)]
    return present + absent
def test_get_qualified_name_valid() -> None:
    """get_qualified_name strips the module prefix plus the joining dot."""
    assert get_qualified_name("codeflash", "codeflash.utils.module") == "utils.module"
def test_get_qualified_name_invalid_prefix() -> None:
    """A qualified name outside the module must raise ValueError."""
    with pytest.raises(ValueError, match="does not start with codeflash"):
        get_qualified_name("codeflash", "otherflash.utils.module")
def test_get_qualified_name_same_name() -> None:
    """Passing the module itself as the qualified name must raise ValueError."""
    with pytest.raises(ValueError, match="is the same as codeflash"):
        get_qualified_name("codeflash", "codeflash")
# tests for module_name_from_file_path
def test_module_name_from_file_path() -> None:
project_root_path = Path("/Users/codeflashuser/PycharmProjects/codeflash")
@ -49,7 +80,6 @@ def test_module_name_from_file_path_with_root_as_file() -> None:
assert module_name == "code_utils"
# tests for get_imports_from_file
def test_get_imports_from_file_with_file_path(tmp_path: Path) -> None:
test_file = tmp_path / "test_file.py"
test_file.write_text("import os\nfrom sys import path\n")
@ -61,8 +91,6 @@ def test_get_imports_from_file_with_file_path(tmp_path: Path) -> None:
assert imports[0].names[0].name == "os"
assert imports[1].module == "sys"
assert imports[1].names[0].name == "path"
def test_get_imports_from_file_with_file_string() -> None:
file_string = "import os\nfrom sys import path\n"
@ -74,7 +102,6 @@ def test_get_imports_from_file_with_file_string() -> None:
assert imports[1].module == "sys"
assert imports[1].names[0].name == "path"
def test_get_imports_from_file_with_file_ast() -> None:
file_string = "import os\nfrom sys import path\n"
file_ast = ast.parse(file_string)
@ -87,8 +114,7 @@ def test_get_imports_from_file_with_file_ast() -> None:
assert imports[1].module == "sys"
assert imports[1].names[0].name == "path"
def test_get_imports_from_file_with_syntax_error(caplog) -> None:
def test_get_imports_from_file_with_syntax_error(caplog: pytest.LogCaptureFixture) -> None:
file_string = "import os\nfrom sys import path\ninvalid syntax"
imports = get_imports_from_file(file_string=file_string)
@ -147,8 +173,7 @@ async def bar():
assert success is True
assert function_names == ["foo", "bar"]
def test_get_all_function_names_with_syntax_error(caplog) -> None:
def test_get_all_function_names_with_syntax_error(caplog: pytest.LogCaptureFixture) -> None:
code = """
def foo():
pass
@ -209,25 +234,21 @@ def test_get_run_tmp_file_reuses_temp_directory() -> None:
assert tmp_file_path1.parent.name.startswith("codeflash_")
assert tmp_file_path1.parent.exists()
# tests for path_belongs_to_site_packages
def test_path_belongs_to_site_packages_with_site_package_path(monkeypatch) -> None:
def test_path_belongs_to_site_packages_with_site_package_path(monkeypatch: pytest.MonkeyPatch) -> None:
site_packages = [Path("/usr/local/lib/python3.9/site-packages")]
monkeypatch.setattr(site, "getsitepackages", lambda: site_packages)
file_path = Path("/usr/local/lib/python3.9/site-packages/some_package")
assert path_belongs_to_site_packages(file_path) is True
def test_path_belongs_to_site_packages_with_non_site_package_path(monkeypatch) -> None:
def test_path_belongs_to_site_packages_with_non_site_package_path(monkeypatch: pytest.MonkeyPatch) -> None:
site_packages = [Path("/usr/local/lib/python3.9/site-packages")]
monkeypatch.setattr(site, "getsitepackages", lambda: site_packages)
file_path = Path("/usr/local/lib/python3.9/other_directory/some_package")
assert path_belongs_to_site_packages(file_path) is False
def test_path_belongs_to_site_packages_with_relative_path(monkeypatch) -> None:
def test_path_belongs_to_site_packages_with_relative_path(monkeypatch: pytest.MonkeyPatch) -> None:
site_packages = [Path("/usr/local/lib/python3.9/site-packages")]
monkeypatch.setattr(site, "getsitepackages", lambda: site_packages)
@ -273,7 +294,7 @@ def test_is_class_defined_in_file_with_non_existing_file() -> None:
@pytest.fixture
def base_dir(tmp_path):
def base_dir(tmp_path: Path) -> Path:
base_dir = tmp_path / "project"
base_dir.mkdir(parents=True, exist_ok=True)
(base_dir / "test_module.py").touch()
@ -282,26 +303,32 @@ def base_dir(tmp_path):
return base_dir
def test_existing_module(base_dir):
def test_existing_module(base_dir: Path) -> None:
result = file_name_from_test_module_name("test_module", base_dir)
assert result == base_dir / "test_module.py"
def test_existing_submodule(base_dir):
def test_existing_submodule(base_dir: Path) -> None:
result = file_name_from_test_module_name("subdir.test_submodule", base_dir)
assert result == base_dir / "subdir" / "test_submodule.py"
def test_non_existing_module(base_dir):
def test_non_existing_module(base_dir: Path) -> None:
result = file_name_from_test_module_name("non_existing_module", base_dir)
assert result is None
def test_partial_module_name(base_dir):
def test_partial_module_name(base_dir: Path) -> None:
result = file_name_from_test_module_name("subdir.test_submodule.TestClass", base_dir)
assert result == base_dir / "subdir" / "test_submodule.py"
def test_partial_module_name2(base_dir):
def test_partial_module_name2(base_dir: Path) -> None:
result = file_name_from_test_module_name("subdir.test_submodule.TestClass.TestClass2", base_dir)
assert result == base_dir / "subdir" / "test_submodule.py"
def test_cleanup_paths(multiple_existing_and_non_existing_files: list[Path]) -> None:
    """cleanup_paths removes existing files and tolerates already-missing ones."""
    targets = multiple_existing_and_non_existing_files
    cleanup_paths(targets)
    assert all(not target.exists() for target in targets)

View file

@ -1,6 +1,7 @@
import dataclasses
import datetime
import decimal
import re
from enum import Enum, Flag, IntFlag, auto
import pydantic
@ -189,6 +190,14 @@ def test_standard_python_library_objects():
assert comparator(a, b)
assert not comparator(a, c)
a: re.Pattern = re.compile("a")
b: re.Pattern = re.compile("a")
c: re.Pattern = re.compile("b")
d: re.Pattern = re.compile("a", re.IGNORECASE)
assert comparator(a, b)
assert not comparator(a, c)
assert not comparator(a, d)
def test_numpy():
try:
@ -218,7 +227,6 @@ def test_numpy():
j = np.float64(1.0)
k = np.float64(1.0)
assert not comparator(h, j)
print(comparator(j, k))
assert comparator(j, k)
l = np.int32(1)

View file

@ -1,12 +1,20 @@
import os
from pathlib import Path
from unittest.mock import Mock
from codeflash.code_utils.env_utils import get_pr_number
from codeflash.models.models import OptimizedCandidateResult
from codeflash.result.critic import performance_gain, quantity_of_tests_critic, speedup_critic
from codeflash.models.models import (
CodeOptimizationContext,
CoverageData,
CoverageStatus,
FunctionCoverage,
OptimizedCandidateResult,
)
from codeflash.result.critic import coverage_critic, performance_gain, quantity_of_tests_critic, speedup_critic
from codeflash.verification.test_results import FunctionTestInvocation, InvocationId, TestResults, TestType
def test_performance_gain():
def test_performance_gain() -> None:
assert performance_gain(original_runtime_ns=1000, optimized_runtime_ns=0) == 0.0
assert performance_gain(original_runtime_ns=1000, optimized_runtime_ns=500) == 1.0
@ -18,7 +26,7 @@ def test_performance_gain():
assert performance_gain(original_runtime_ns=1000, optimized_runtime_ns=1100) == -0.09090909090909091
def test_speedup_critic():
def test_speedup_critic() -> None:
original_code_runtime = 1000
best_runtime_until_now = 1000
candidate_result = OptimizedCandidateResult(
@ -58,7 +66,7 @@ def test_speedup_critic():
assert speedup_critic(candidate_result, original_code_runtime, best_runtime_until_now) # 6% improvement
def test_generated_test_critic():
def test_generated_test_critic() -> None:
test_1 = FunctionTestInvocation(
id=InvocationId(
test_module_path="",
@ -67,7 +75,7 @@ def test_generated_test_critic():
function_getting_tested="sorter",
iteration_id="",
),
file_name="test_1",
file_name=Path("test_1"),
did_pass=True,
runtime=0,
test_framework="pytest",
@ -85,7 +93,7 @@ def test_generated_test_critic():
function_getting_tested="sorter",
iteration_id="",
),
file_name="test_2",
file_name=Path("test_2"),
did_pass=True,
runtime=0,
test_framework="pytest",
@ -103,7 +111,7 @@ def test_generated_test_critic():
function_getting_tested="sorter",
iteration_id="",
),
file_name="test_3",
file_name=Path("test_3"),
did_pass=True,
runtime=0,
test_framework="pytest",
@ -121,7 +129,7 @@ def test_generated_test_critic():
function_getting_tested="sorter",
iteration_id="",
),
file_name="test_4",
file_name=Path("test_4"),
did_pass=False,
runtime=0,
test_framework="pytest",
@ -139,7 +147,7 @@ def test_generated_test_critic():
function_getting_tested="sorter",
iteration_id="",
),
file_name="test_5",
file_name=Path("test_5"),
did_pass=True,
runtime=0,
test_framework="pytest",
@ -157,7 +165,7 @@ def test_generated_test_critic():
function_getting_tested="sorter",
iteration_id="",
),
file_name="test_6",
file_name=Path("test_6"),
did_pass=True,
runtime=0,
test_framework="pytest",
@ -313,3 +321,92 @@ def test_generated_test_critic():
assert quantity_of_tests_critic(candidate_result)
del os.environ["CODEFLASH_PR_NUMBER"]
def test_coverage_critic() -> None:
    """coverage_critic passes pytest coverage at or above 50% and always passes unittest."""
    mock_code_context = Mock(spec=CodeOptimizationContext)

    def build(total, func, hit, missed, hit_br, missed_br) -> CoverageData:
        # Helper: assemble a CoverageData record varying only the coverage numbers
        # and line/branch hit lists between test cases.
        return CoverageData(
            file_path=Path("test_file.py"),
            coverage=total,
            function_name="test_function",
            functions_being_tested=["function1", "function2"],
            graph={},
            code_context=mock_code_context,
            main_func_coverage=FunctionCoverage(
                name="test_function",
                coverage=func,
                executed_lines=hit,
                unexecuted_lines=missed,
                executed_branches=hit_br,
                unexecuted_branches=missed_br,
            ),
            dependent_func_coverage=None,
            status=CoverageStatus.PARSED_SUCCESSFULLY,
        )

    # Well above the threshold: accepted.
    assert coverage_critic(build(100.0, 100.0, [10], [2], [[5]], [[1]]), "pytest") is True
    # Exactly at the 50% threshold: still accepted (>= comparison).
    assert coverage_critic(build(50.0, 50.0, [10], [2], [[5]], [[1]]), "pytest") is True
    # Below the threshold: rejected.
    assert coverage_critic(build(30.0, 0.0, [], [10], [], [[5]]), "pytest") is False
    # unittest framework is not checked yet, so even 0% coverage passes.
    assert coverage_critic(build(0, 0, [10], [2], [[5]], [[1]]), "unittest") is True