Merge branch 'main' into update-codeflash-yaml-concurrency
This commit is contained in:
commit
33b2eadc8b
10 changed files with 128 additions and 69 deletions
|
|
@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, TypeVar
|
|||
|
||||
import libcst as cst
|
||||
|
||||
from codeflash.cli_cmds.console import logger
|
||||
from codeflash.cli_cmds.console import console, logger
|
||||
from codeflash.code_utils.code_extractor import add_needed_imports_from_module
|
||||
from codeflash.models.models import FunctionParent
|
||||
|
||||
|
|
@ -184,6 +184,7 @@ def replace_functions_in_file(
|
|||
if visitor.optim_body is None:
|
||||
msg = f"Unable to find function {function_name} in optimized code. Returning unchanged source code."
|
||||
logger.error(msg)
|
||||
console.rule()
|
||||
return source_code
|
||||
|
||||
transformer = OptimFunctionReplacer(
|
||||
|
|
|
|||
|
|
@ -1,7 +1,14 @@
|
|||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# os-independent newline
|
||||
# important for any user-facing output or files we write
|
||||
# make sure to use this in f-strings e.g. f"some string{LF}"
|
||||
# you can use "[^f]\".*\{LF\}\" to find any lines in your code that use this without the f-string
|
||||
LF: str = os.linesep
|
||||
|
||||
|
||||
SAFE_SYS_EXECUTABLE: str = Path(sys.executable).as_posix()
|
||||
|
||||
IS_POSIX = os.name != "nt"
|
||||
|
|
@ -14,8 +14,9 @@ import jedi
|
|||
from pydantic.dataclasses import dataclass
|
||||
from pytest import ExitCode
|
||||
|
||||
from codeflash.cli_cmds.console import logger
|
||||
from codeflash.cli_cmds.console import console, logger
|
||||
from codeflash.code_utils.code_utils import get_run_tmp_file, module_name_from_file_path
|
||||
from codeflash.code_utils.compat import SAFE_SYS_EXECUTABLE
|
||||
from codeflash.models.models import CodePosition, FunctionCalledInTest, TestsInFile
|
||||
from codeflash.verification.test_results import TestType
|
||||
|
||||
|
|
@ -51,7 +52,7 @@ def discover_tests_pytest(
|
|||
tmp_pickle_path = get_run_tmp_file("collected_tests.pkl")
|
||||
subprocess.run(
|
||||
[
|
||||
sys.executable,
|
||||
SAFE_SYS_EXECUTABLE,
|
||||
Path(__file__).parent / "pytest_new_process_discovery.py",
|
||||
str(project_root),
|
||||
str(tests_root),
|
||||
|
|
@ -76,6 +77,7 @@ def discover_tests_pytest(
|
|||
logger.warning(f"Failed to collect tests. Pytest Exit code: {exitcode}={ExitCode(exitcode).name}")
|
||||
else:
|
||||
logger.warning(f"Failed to collect tests. Pytest Exit code: {exitcode}")
|
||||
console.rule()
|
||||
else:
|
||||
logger.debug(f"Pytest collection exit code: {exitcode}")
|
||||
if pytest_rootdir is not None:
|
||||
|
|
|
|||
|
|
@ -3,9 +3,10 @@ from __future__ import annotations
|
|||
import ast
|
||||
import os
|
||||
import random
|
||||
import warnings
|
||||
from _ast import AsyncFunctionDef, ClassDef, FunctionDef
|
||||
from collections import defaultdict
|
||||
from functools import lru_cache
|
||||
from functools import cache
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
|
|
@ -14,7 +15,7 @@ import libcst as cst
|
|||
from pydantic.dataclasses import dataclass
|
||||
|
||||
from codeflash.api.cfapi import get_blocklisted_functions
|
||||
from codeflash.cli_cmds.console import console, logger
|
||||
from codeflash.cli_cmds.console import DEBUG_MODE, console, logger
|
||||
from codeflash.code_utils.code_utils import (
|
||||
is_class_defined_in_file,
|
||||
module_name_from_file_path,
|
||||
|
|
@ -159,46 +160,52 @@ def get_functions_to_optimize(
|
|||
sum([bool(optimize_all), bool(replay_test), bool(file)]) <= 1
|
||||
), "Only one of optimize_all, replay_test, or file should be provided"
|
||||
functions: dict[str, list[FunctionToOptimize]]
|
||||
if optimize_all:
|
||||
logger.info("Finding all functions in the module '%s'…", optimize_all)
|
||||
functions = get_all_files_and_functions(Path(optimize_all))
|
||||
elif replay_test is not None:
|
||||
functions = get_all_replay_test_functions(
|
||||
replay_test=replay_test, test_cfg=test_cfg, project_root_path=project_root
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter(action="ignore", category=SyntaxWarning)
|
||||
if optimize_all:
|
||||
logger.info("Finding all functions in the module '%s'…", optimize_all)
|
||||
console.rule()
|
||||
functions = get_all_files_and_functions(Path(optimize_all))
|
||||
elif replay_test is not None:
|
||||
functions = get_all_replay_test_functions(
|
||||
replay_test=replay_test, test_cfg=test_cfg, project_root_path=project_root
|
||||
)
|
||||
elif file is not None:
|
||||
logger.info("Finding all functions in the file '%s'…", file)
|
||||
console.rule()
|
||||
functions = find_all_functions_in_file(file)
|
||||
if only_get_this_function is not None:
|
||||
split_function = only_get_this_function.split(".")
|
||||
if len(split_function) > 2:
|
||||
msg = "Function name should be in the format 'function_name' or 'class_name.function_name'"
|
||||
raise ValueError(msg)
|
||||
if len(split_function) == 2:
|
||||
class_name, only_function_name = split_function
|
||||
else:
|
||||
class_name = None
|
||||
only_function_name = split_function[0]
|
||||
found_function = None
|
||||
for fn in functions.get(file, []):
|
||||
if only_function_name == fn.function_name and (
|
||||
class_name is None or class_name == fn.top_level_parent_name
|
||||
):
|
||||
found_function = fn
|
||||
if found_function is None:
|
||||
msg = f"Function {only_function_name} not found in file {file} or the function does not have a 'return' statement or is a property"
|
||||
raise ValueError(msg)
|
||||
functions[file] = [found_function]
|
||||
else:
|
||||
logger.info("Finding all functions modified in the current git diff ...")
|
||||
ph("cli-optimizing-git-diff")
|
||||
functions = get_functions_within_git_diff()
|
||||
filtered_modified_functions, functions_count = filter_functions(
|
||||
functions, test_cfg.tests_root, ignore_paths, project_root, module_root
|
||||
)
|
||||
elif file is not None:
|
||||
logger.info("Finding all functions in the file '%s'…", file)
|
||||
console.rule()
|
||||
functions = find_all_functions_in_file(file)
|
||||
if only_get_this_function is not None:
|
||||
split_function = only_get_this_function.split(".")
|
||||
if len(split_function) > 2:
|
||||
msg = "Function name should be in the format 'function_name' or 'class_name.function_name'"
|
||||
raise ValueError(msg)
|
||||
if len(split_function) == 2:
|
||||
class_name, only_function_name = split_function
|
||||
else:
|
||||
class_name = None
|
||||
only_function_name = split_function[0]
|
||||
found_function = None
|
||||
for fn in functions.get(file, []):
|
||||
if only_function_name == fn.function_name and (
|
||||
class_name is None or class_name == fn.top_level_parent_name
|
||||
):
|
||||
found_function = fn
|
||||
if found_function is None:
|
||||
msg = f"Function {only_function_name} not found in file {file} or the function does not have a 'return' statement or is a property"
|
||||
raise ValueError(msg)
|
||||
functions[file] = [found_function]
|
||||
else:
|
||||
logger.info("Finding all functions modified in the current git diff ...")
|
||||
ph("cli-optimizing-git-diff")
|
||||
functions = get_functions_within_git_diff()
|
||||
filtered_modified_functions, functions_count = filter_functions(
|
||||
functions, test_cfg.tests_root, ignore_paths, project_root, module_root
|
||||
)
|
||||
logger.info(f"Found {functions_count} function{'s' if functions_count > 1 else ''} to optimize")
|
||||
return filtered_modified_functions, functions_count
|
||||
filtered_modified_functions, functions_count = filter_functions(
|
||||
functions, test_cfg.tests_root, ignore_paths, project_root, module_root
|
||||
)
|
||||
logger.info(f"Found {functions_count} function{'s' if functions_count > 1 else ''} to optimize")
|
||||
return filtered_modified_functions, functions_count
|
||||
|
||||
|
||||
def get_functions_within_git_diff() -> dict[str, list[FunctionToOptimize]]:
|
||||
|
|
@ -245,7 +252,8 @@ def find_all_functions_in_file(file_path: Path) -> dict[Path, list[FunctionToOpt
|
|||
try:
|
||||
ast_module = ast.parse(f.read())
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
if DEBUG_MODE:
|
||||
logger.exception(e)
|
||||
return functions
|
||||
function_name_visitor = FunctionWithReturnStatement(file_path)
|
||||
function_name_visitor.visit(ast_module)
|
||||
|
|
@ -309,7 +317,7 @@ def is_git_repo(file_path: str) -> bool:
|
|||
return False
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
@cache
|
||||
def ignored_submodule_paths(module_root: str) -> list[str]:
|
||||
if is_git_repo(module_root):
|
||||
git_repo = git.Repo(module_root, search_parent_directories=True)
|
||||
|
|
@ -473,6 +481,7 @@ def filter_functions(
|
|||
log_string: str
|
||||
if log_string := "\n".join([k for k, v in log_info.items() if v > 0]):
|
||||
logger.info(f"Ignoring: {log_string}")
|
||||
console.rule()
|
||||
return {Path(k): v for k, v in filtered_modified_functions.items() if v}, functions_count
|
||||
|
||||
|
||||
|
|
@ -492,11 +501,10 @@ def filter_files_optimized(file_path: Path, tests_root: Path, ignore_paths: list
|
|||
return False
|
||||
if submodule_paths is None:
|
||||
submodule_paths = ignored_submodule_paths(module_root)
|
||||
if file_path in submodule_paths or any(
|
||||
file_path.is_relative_to(submodule_path) for submodule_path in submodule_paths
|
||||
):
|
||||
return False
|
||||
return True
|
||||
return not (
|
||||
file_path in submodule_paths
|
||||
or any(file_path.is_relative_to(submodule_path) for submodule_path in submodule_paths)
|
||||
)
|
||||
|
||||
|
||||
def function_has_return_statement(function_node: FunctionDef | AsyncFunctionDef) -> bool:
|
||||
|
|
|
|||
|
|
@ -29,10 +29,27 @@ def belongs_to_function(name: Name, function_name: str) -> bool:
|
|||
"""Check if the given name belongs to the specified function"""
|
||||
if name.full_name and name.full_name.startswith(name.module_name):
|
||||
subname: str = name.full_name.replace(name.module_name, "", 1)
|
||||
else:
|
||||
# The name is defined inside the function or is the function itself
|
||||
if f".{function_name}." in subname or f".{function_name}" == subname:
|
||||
return True
|
||||
return bool(name_in_listcomp_in_function(name, function_name))
|
||||
return False
|
||||
|
||||
|
||||
def name_in_listcomp_in_function(name: Name, function_name: str) -> bool:
|
||||
"""Check if the given name is in a list comprehension in the specified function
|
||||
Special case because jedi has a bug https://github.com/davidhalter/jedi/issues/1944
|
||||
"""
|
||||
try:
|
||||
parent_node = name._name.parent_context.tree_node.parent
|
||||
if hasattr(parent_node, "type") and parent_node.type == "testlist_comp":
|
||||
while parent_node := parent_node.parent:
|
||||
if parent_node.type == "funcdef":
|
||||
return parent_node.name.value == function_name
|
||||
return False
|
||||
except Exception:
|
||||
# don't want to handle conformance with 3rd party library private attribute access exception types
|
||||
return False
|
||||
# The name is defined inside the function or is the function itself
|
||||
return f".{function_name}." in subname or f".{function_name}" == subname
|
||||
|
||||
|
||||
def get_type_annotation_context(
|
||||
|
|
|
|||
|
|
@ -486,6 +486,7 @@ class Optimizer:
|
|||
logger.warning(
|
||||
"No functions were replaced in the optimized code. Skipping optimization candidate."
|
||||
)
|
||||
console.rule()
|
||||
continue
|
||||
except (ValueError, SyntaxError, cst.ParserSyntaxError, AttributeError) as e:
|
||||
logger.error(e)
|
||||
|
|
|
|||
|
|
@ -89,8 +89,10 @@ def check_create_pr(
|
|||
)
|
||||
else:
|
||||
logger.info("Creating a new PR with the optimized code...")
|
||||
console.rule()
|
||||
owner, repo = get_repo_owner_and_name(git_repo, git_remote)
|
||||
logger.info(f"Pushing to {git_remote} - Owner: {owner}, Repo: {repo}")
|
||||
console.rule()
|
||||
if not check_and_push_branch(git_repo, wait_for_push=True):
|
||||
logger.warning("⏭️ Branch is not pushed, skipping PR creation...")
|
||||
return
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ import tempfile
|
|||
from argparse import Namespace
|
||||
from pathlib import Path
|
||||
|
||||
from codeflash.cli_cmds.console import logger
|
||||
from codeflash.cli_cmds.console import console, logger
|
||||
from codeflash.code_utils.static_analysis import has_typed_parameters
|
||||
from codeflash.discovery.discover_unit_tests import discover_unit_tests
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
|
|
@ -22,6 +22,7 @@ def generate_concolic_tests(
|
|||
concolic_test_suite_code = ""
|
||||
if test_cfg.concolic_test_root_dir and has_typed_parameters(function_to_optimize_ast, function_to_optimize.parents):
|
||||
logger.info("Generating concolic opcode coverage tests for the original code…")
|
||||
console.rule()
|
||||
cover_result = subprocess.run(
|
||||
[
|
||||
"crosshair",
|
||||
|
|
@ -63,6 +64,7 @@ def generate_concolic_tests(
|
|||
f"Created {num_discovered_concolic_tests} "
|
||||
f"concolic unit test case{'s' if num_discovered_concolic_tests != 1 else ''} "
|
||||
)
|
||||
console.rule()
|
||||
ph("cli-optimize-concolic-tests", {"num_tests": num_discovered_concolic_tests})
|
||||
|
||||
else:
|
||||
|
|
@ -71,4 +73,5 @@ def generate_concolic_tests(
|
|||
"Error running CrossHair Cover" f"{': ' + cover_result.stderr if cover_result.stderr else '.'}"
|
||||
)
|
||||
)
|
||||
console.rule()
|
||||
return function_to_concolic_tests, concolic_test_suite_code
|
||||
|
|
|
|||
|
|
@ -3,22 +3,20 @@ from __future__ import annotations
|
|||
import os
|
||||
import shlex
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from codeflash.cli_cmds.console import logger
|
||||
from codeflash.cli_cmds.console import console, logger
|
||||
from codeflash.code_utils.code_utils import get_run_tmp_file
|
||||
from codeflash.code_utils.compat import SAFE_SYS_EXECUTABLE, IS_POSIX
|
||||
from codeflash.code_utils.config_consts import TOTAL_LOOPING_TIME
|
||||
from codeflash.code_utils.coverage_utils import prepare_coverage_files
|
||||
from codeflash.models.models import CodeOptimizationContext, TestFiles
|
||||
from codeflash.models.models import TestFiles
|
||||
from codeflash.verification.test_results import TestType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from codeflash.models.models import TestFiles
|
||||
|
||||
is_posix = os.name != "nt"
|
||||
|
||||
|
||||
def execute_test_subprocess(
|
||||
cmd_list: list[str], cwd: Path | None, env: dict[str, str] | None, timeout: int = 600
|
||||
|
|
@ -53,7 +51,7 @@ def run_tests(
|
|||
)
|
||||
else:
|
||||
test_files.append(str(file.instrumented_file_path))
|
||||
pytest_cmd_list = shlex.split(pytest_cmd, posix=is_posix)
|
||||
pytest_cmd_list = shlex.split(pytest_cmd, posix=IS_POSIX)
|
||||
|
||||
common_pytest_args = [
|
||||
"--capture=tee-sys",
|
||||
|
|
@ -62,6 +60,7 @@ def run_tests(
|
|||
f"--codeflash_seconds={pytest_target_runtime_seconds}",
|
||||
"--codeflash_loops_scope=session",
|
||||
]
|
||||
|
||||
pytest_test_env = test_env.copy()
|
||||
pytest_test_env["PYTEST_PLUGINS"] = "codeflash.verification.pytest_plugin"
|
||||
|
||||
|
|
@ -73,7 +72,7 @@ def run_tests(
|
|||
coverage_args = ["--codeflash_min_loops=1", "--codeflash_max_loops=1"]
|
||||
|
||||
cov_erase = execute_test_subprocess(
|
||||
shlex.split(f"{sys.executable} -m coverage erase"), cwd=cwd, env=pytest_test_env
|
||||
shlex.split(f"{SAFE_SYS_EXECUTABLE} -m coverage erase"), cwd=cwd, env=pytest_test_env
|
||||
) # this cleanup is necessary to avoid coverage data from previous runs, if there are any, then the current run will be appended to the previous data, which skews the results
|
||||
logger.debug(cov_erase)
|
||||
|
||||
|
|
@ -84,7 +83,7 @@ def run_tests(
|
|||
]
|
||||
|
||||
cov_run = execute_test_subprocess(
|
||||
shlex.split(f"{sys.executable} -m coverage run --rcfile={coveragercfile} -m pytest")
|
||||
shlex.split(f"{SAFE_SYS_EXECUTABLE} -m coverage run --rcfile={coveragercfile.as_posix()} -m pytest")
|
||||
+ files
|
||||
+ common_pytest_args
|
||||
+ coverage_args
|
||||
|
|
@ -95,13 +94,18 @@ def run_tests(
|
|||
logger.debug(cov_run)
|
||||
|
||||
cov_report = execute_test_subprocess(
|
||||
shlex.split(f"{sys.executable} -m coverage json --rcfile={coveragercfile}"),
|
||||
shlex.split(f"{SAFE_SYS_EXECUTABLE} -m coverage json --rcfile={coveragercfile.as_posix()}"),
|
||||
cwd=cwd,
|
||||
env=pytest_test_env,
|
||||
) # this will generate a json file with the coverage data
|
||||
logger.debug(cov_report)
|
||||
if "No data to report." in cov_report.stdout:
|
||||
logger.warning("No coverage data to report. Check if the tests are running correctly.")
|
||||
console.rule()
|
||||
coverage_out_file = None
|
||||
|
||||
result_file_path = get_run_tmp_file(Path("pytest_results.xml"))
|
||||
result_args = [f"--junitxml={result_file_path}", "-o", "junit_logging=all"]
|
||||
result_args = [f"--junitxml={result_file_path.as_posix()}", "-o", "junit_logging=all"]
|
||||
|
||||
results = execute_test_subprocess(
|
||||
pytest_cmd_list
|
||||
|
|
@ -115,7 +119,7 @@ def run_tests(
|
|||
)
|
||||
elif test_framework == "unittest":
|
||||
result_file_path = get_run_tmp_file(Path("unittest_results.xml"))
|
||||
unittest_cmd_list = [sys.executable, "-m", "xmlrunner"]
|
||||
unittest_cmd_list = [SAFE_SYS_EXECUTABLE, "-m", "xmlrunner"]
|
||||
log_level = ["-v"] if verbose else []
|
||||
files = [str(file.instrumented_file_path) for file in test_paths.test_files]
|
||||
output_file = ["--output-file", str(result_file_path)]
|
||||
|
|
@ -125,6 +129,7 @@ def run_tests(
|
|||
)
|
||||
|
||||
else:
|
||||
raise ValueError("Invalid test framework -- I only support Pytest and Unittest currently.")
|
||||
msg = "Invalid test framework -- I only support Pytest and Unittest currently."
|
||||
raise ValueError(msg)
|
||||
|
||||
return result_file_path, results, coverage_out_file if enable_coverage else None
|
||||
|
|
|
|||
|
|
@ -3,12 +3,11 @@ from argparse import Namespace
|
|||
from dataclasses import dataclass
|
||||
|
||||
import pytest
|
||||
from returns.pipeline import is_successful
|
||||
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.models.models import FunctionParent
|
||||
from codeflash.optimization.function_context import get_function_variables_definitions
|
||||
from codeflash.optimization.optimizer import Optimizer
|
||||
from returns.pipeline import is_successful
|
||||
|
||||
|
||||
def calculate_something(data):
|
||||
|
|
@ -123,6 +122,10 @@ def simple_function_with_one_dep_ann(data: MyData):
|
|||
return calculate_something_ann(data)
|
||||
|
||||
|
||||
def list_comprehension_dependency(data: MyData):
|
||||
return [calculate_something(data) for x in range(10)]
|
||||
|
||||
|
||||
def test_simple_dependencies_ann() -> None:
|
||||
file_path = pathlib.Path(__file__).resolve()
|
||||
helper_functions = get_function_variables_definitions(
|
||||
|
|
@ -305,3 +308,13 @@ def test_recursive_function_context() -> None:
|
|||
return self.recursive(num) + num_1
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def test_list_comprehension_dependency() -> None:
|
||||
file_path = pathlib.Path(__file__).resolve()
|
||||
helper_functions = get_function_variables_definitions(
|
||||
FunctionToOptimize("list_comprehension_dependency", str(file_path), []), str(file_path.parent.resolve())
|
||||
)[0]
|
||||
assert len(helper_functions) == 2
|
||||
assert helper_functions[0].jedi_definition.full_name == "test_function_dependencies.MyData"
|
||||
assert helper_functions[1].jedi_definition.full_name == "test_function_dependencies.calculate_something"
|
||||
|
|
|
|||
Loading…
Reference in a new issue