Merge branch 'main' into update-codeflash-yaml-concurrency

This commit is contained in:
Saurabh Misra 2024-12-08 13:25:24 -08:00 committed by GitHub
commit 33b2eadc8b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 128 additions and 69 deletions

View file

@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, TypeVar
import libcst as cst
from codeflash.cli_cmds.console import logger
from codeflash.cli_cmds.console import console, logger
from codeflash.code_utils.code_extractor import add_needed_imports_from_module
from codeflash.models.models import FunctionParent
@ -184,6 +184,7 @@ def replace_functions_in_file(
if visitor.optim_body is None:
msg = f"Unable to find function {function_name} in optimized code. Returning unchanged source code."
logger.error(msg)
console.rule()
return source_code
transformer = OptimFunctionReplacer(

View file

@ -1,7 +1,14 @@
import os
import sys
from pathlib import Path
# Line separator of the running platform (the value of os.linesep).
# Important for any user-facing output or files we write.
# Make sure to use this in f-strings, e.g. f"some string{LF}".
# To find lines in your code that use {LF} inside a plain (non-f) string,
# search with the regex:  [^f]\".*\{LF\}\"
LF: str = os.linesep
# Path of the running interpreter, rendered through Path.as_posix().
SAFE_SYS_EXECUTABLE: str = Path(sys.executable).as_posix()
# True whenever os.name is anything other than "nt".
IS_POSIX: bool = os.name != "nt"

View file

@ -14,8 +14,9 @@ import jedi
from pydantic.dataclasses import dataclass
from pytest import ExitCode
from codeflash.cli_cmds.console import logger
from codeflash.cli_cmds.console import console, logger
from codeflash.code_utils.code_utils import get_run_tmp_file, module_name_from_file_path
from codeflash.code_utils.compat import SAFE_SYS_EXECUTABLE
from codeflash.models.models import CodePosition, FunctionCalledInTest, TestsInFile
from codeflash.verification.test_results import TestType
@ -51,7 +52,7 @@ def discover_tests_pytest(
tmp_pickle_path = get_run_tmp_file("collected_tests.pkl")
subprocess.run(
[
sys.executable,
SAFE_SYS_EXECUTABLE,
Path(__file__).parent / "pytest_new_process_discovery.py",
str(project_root),
str(tests_root),
@ -76,6 +77,7 @@ def discover_tests_pytest(
logger.warning(f"Failed to collect tests. Pytest Exit code: {exitcode}={ExitCode(exitcode).name}")
else:
logger.warning(f"Failed to collect tests. Pytest Exit code: {exitcode}")
console.rule()
else:
logger.debug(f"Pytest collection exit code: {exitcode}")
if pytest_rootdir is not None:

View file

@ -3,9 +3,10 @@ from __future__ import annotations
import ast
import os
import random
import warnings
from _ast import AsyncFunctionDef, ClassDef, FunctionDef
from collections import defaultdict
from functools import lru_cache
from functools import cache
from pathlib import Path
from typing import TYPE_CHECKING, Optional
@ -14,7 +15,7 @@ import libcst as cst
from pydantic.dataclasses import dataclass
from codeflash.api.cfapi import get_blocklisted_functions
from codeflash.cli_cmds.console import console, logger
from codeflash.cli_cmds.console import DEBUG_MODE, console, logger
from codeflash.code_utils.code_utils import (
is_class_defined_in_file,
module_name_from_file_path,
@ -159,46 +160,52 @@ def get_functions_to_optimize(
sum([bool(optimize_all), bool(replay_test), bool(file)]) <= 1
), "Only one of optimize_all, replay_test, or file should be provided"
functions: dict[str, list[FunctionToOptimize]]
if optimize_all:
logger.info("Finding all functions in the module '%s'", optimize_all)
functions = get_all_files_and_functions(Path(optimize_all))
elif replay_test is not None:
functions = get_all_replay_test_functions(
replay_test=replay_test, test_cfg=test_cfg, project_root_path=project_root
with warnings.catch_warnings():
warnings.simplefilter(action="ignore", category=SyntaxWarning)
if optimize_all:
logger.info("Finding all functions in the module '%s'", optimize_all)
console.rule()
functions = get_all_files_and_functions(Path(optimize_all))
elif replay_test is not None:
functions = get_all_replay_test_functions(
replay_test=replay_test, test_cfg=test_cfg, project_root_path=project_root
)
elif file is not None:
logger.info("Finding all functions in the file '%s'", file)
console.rule()
functions = find_all_functions_in_file(file)
if only_get_this_function is not None:
split_function = only_get_this_function.split(".")
if len(split_function) > 2:
msg = "Function name should be in the format 'function_name' or 'class_name.function_name'"
raise ValueError(msg)
if len(split_function) == 2:
class_name, only_function_name = split_function
else:
class_name = None
only_function_name = split_function[0]
found_function = None
for fn in functions.get(file, []):
if only_function_name == fn.function_name and (
class_name is None or class_name == fn.top_level_parent_name
):
found_function = fn
if found_function is None:
msg = f"Function {only_function_name} not found in file {file} or the function does not have a 'return' statement or is a property"
raise ValueError(msg)
functions[file] = [found_function]
else:
logger.info("Finding all functions modified in the current git diff ...")
ph("cli-optimizing-git-diff")
functions = get_functions_within_git_diff()
filtered_modified_functions, functions_count = filter_functions(
functions, test_cfg.tests_root, ignore_paths, project_root, module_root
)
elif file is not None:
logger.info("Finding all functions in the file '%s'", file)
console.rule()
functions = find_all_functions_in_file(file)
if only_get_this_function is not None:
split_function = only_get_this_function.split(".")
if len(split_function) > 2:
msg = "Function name should be in the format 'function_name' or 'class_name.function_name'"
raise ValueError(msg)
if len(split_function) == 2:
class_name, only_function_name = split_function
else:
class_name = None
only_function_name = split_function[0]
found_function = None
for fn in functions.get(file, []):
if only_function_name == fn.function_name and (
class_name is None or class_name == fn.top_level_parent_name
):
found_function = fn
if found_function is None:
msg = f"Function {only_function_name} not found in file {file} or the function does not have a 'return' statement or is a property"
raise ValueError(msg)
functions[file] = [found_function]
else:
logger.info("Finding all functions modified in the current git diff ...")
ph("cli-optimizing-git-diff")
functions = get_functions_within_git_diff()
filtered_modified_functions, functions_count = filter_functions(
functions, test_cfg.tests_root, ignore_paths, project_root, module_root
)
logger.info(f"Found {functions_count} function{'s' if functions_count > 1 else ''} to optimize")
return filtered_modified_functions, functions_count
filtered_modified_functions, functions_count = filter_functions(
functions, test_cfg.tests_root, ignore_paths, project_root, module_root
)
logger.info(f"Found {functions_count} function{'s' if functions_count > 1 else ''} to optimize")
return filtered_modified_functions, functions_count
def get_functions_within_git_diff() -> dict[str, list[FunctionToOptimize]]:
@ -245,7 +252,8 @@ def find_all_functions_in_file(file_path: Path) -> dict[Path, list[FunctionToOpt
try:
ast_module = ast.parse(f.read())
except Exception as e:
logger.exception(e)
if DEBUG_MODE:
logger.exception(e)
return functions
function_name_visitor = FunctionWithReturnStatement(file_path)
function_name_visitor.visit(ast_module)
@ -309,7 +317,7 @@ def is_git_repo(file_path: str) -> bool:
return False
@lru_cache(maxsize=None)
@cache
def ignored_submodule_paths(module_root: str) -> list[str]:
if is_git_repo(module_root):
git_repo = git.Repo(module_root, search_parent_directories=True)
@ -473,6 +481,7 @@ def filter_functions(
log_string: str
if log_string := "\n".join([k for k, v in log_info.items() if v > 0]):
logger.info(f"Ignoring: {log_string}")
console.rule()
return {Path(k): v for k, v in filtered_modified_functions.items() if v}, functions_count
@ -492,11 +501,10 @@ def filter_files_optimized(file_path: Path, tests_root: Path, ignore_paths: list
return False
if submodule_paths is None:
submodule_paths = ignored_submodule_paths(module_root)
if file_path in submodule_paths or any(
file_path.is_relative_to(submodule_path) for submodule_path in submodule_paths
):
return False
return True
return not (
file_path in submodule_paths
or any(file_path.is_relative_to(submodule_path) for submodule_path in submodule_paths)
)
def function_has_return_statement(function_node: FunctionDef | AsyncFunctionDef) -> bool:

View file

@ -29,10 +29,27 @@ def belongs_to_function(name: Name, function_name: str) -> bool:
"""Check if the given name belongs to the specified function"""
if name.full_name and name.full_name.startswith(name.module_name):
subname: str = name.full_name.replace(name.module_name, "", 1)
else:
# The name is defined inside the function or is the function itself
if f".{function_name}." in subname or f".{function_name}" == subname:
return True
return bool(name_in_listcomp_in_function(name, function_name))
return False
def name_in_listcomp_in_function(name: Name, function_name: str) -> bool:
    """Check if the given name is in a list comprehension in the specified function.

    Special case because jedi has a bug:
    https://github.com/davidhalter/jedi/issues/1944

    Args:
        name: Object whose private ``_name.parent_context.tree_node`` chain is
            inspected (annotated as a jedi ``Name``).
        function_name: Name the nearest enclosing function definition must have.

    Returns:
        True only when the parent of the name's context tree node has type
        ``"testlist_comp"`` and the first ``"funcdef"`` ancestor above it is
        named ``function_name``. False in every other case, including when any
        attribute lookup along the way raises.
    """
    try:
        node = name._name.parent_context.tree_node.parent
        # Guard clause: only nodes sitting directly under a "testlist_comp" qualify.
        if getattr(node, "type", None) != "testlist_comp":
            return False
        # Climb towards the root; the first function definition met decides the answer.
        while node := node.parent:
            if node.type == "funcdef":
                return node.name.value == function_name
    except Exception:
        # don't want to handle conformance with 3rd party library private attribute access exception types
        return False
    # Ran out of ancestors without meeting a function definition.
    return False
# The name is defined inside the function or is the function itself
return f".{function_name}." in subname or f".{function_name}" == subname
def get_type_annotation_context(

View file

@ -486,6 +486,7 @@ class Optimizer:
logger.warning(
"No functions were replaced in the optimized code. Skipping optimization candidate."
)
console.rule()
continue
except (ValueError, SyntaxError, cst.ParserSyntaxError, AttributeError) as e:
logger.error(e)

View file

@ -89,8 +89,10 @@ def check_create_pr(
)
else:
logger.info("Creating a new PR with the optimized code...")
console.rule()
owner, repo = get_repo_owner_and_name(git_repo, git_remote)
logger.info(f"Pushing to {git_remote} - Owner: {owner}, Repo: {repo}")
console.rule()
if not check_and_push_branch(git_repo, wait_for_push=True):
logger.warning("⏭️ Branch is not pushed, skipping PR creation...")
return

View file

@ -6,7 +6,7 @@ import tempfile
from argparse import Namespace
from pathlib import Path
from codeflash.cli_cmds.console import logger
from codeflash.cli_cmds.console import console, logger
from codeflash.code_utils.static_analysis import has_typed_parameters
from codeflash.discovery.discover_unit_tests import discover_unit_tests
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
@ -22,6 +22,7 @@ def generate_concolic_tests(
concolic_test_suite_code = ""
if test_cfg.concolic_test_root_dir and has_typed_parameters(function_to_optimize_ast, function_to_optimize.parents):
logger.info("Generating concolic opcode coverage tests for the original code…")
console.rule()
cover_result = subprocess.run(
[
"crosshair",
@ -63,6 +64,7 @@ def generate_concolic_tests(
f"Created {num_discovered_concolic_tests} "
f"concolic unit test case{'s' if num_discovered_concolic_tests != 1 else ''} "
)
console.rule()
ph("cli-optimize-concolic-tests", {"num_tests": num_discovered_concolic_tests})
else:
@ -71,4 +73,5 @@ def generate_concolic_tests(
"Error running CrossHair Cover" f"{': ' + cover_result.stderr if cover_result.stderr else '.'}"
)
)
console.rule()
return function_to_concolic_tests, concolic_test_suite_code

View file

@ -3,22 +3,20 @@ from __future__ import annotations
import os
import shlex
import subprocess
import sys
from pathlib import Path
from typing import TYPE_CHECKING
from codeflash.cli_cmds.console import logger
from codeflash.cli_cmds.console import console, logger
from codeflash.code_utils.code_utils import get_run_tmp_file
from codeflash.code_utils.compat import SAFE_SYS_EXECUTABLE, IS_POSIX
from codeflash.code_utils.config_consts import TOTAL_LOOPING_TIME
from codeflash.code_utils.coverage_utils import prepare_coverage_files
from codeflash.models.models import CodeOptimizationContext, TestFiles
from codeflash.models.models import TestFiles
from codeflash.verification.test_results import TestType
if TYPE_CHECKING:
from codeflash.models.models import TestFiles
is_posix = os.name != "nt"
def execute_test_subprocess(
cmd_list: list[str], cwd: Path | None, env: dict[str, str] | None, timeout: int = 600
@ -53,7 +51,7 @@ def run_tests(
)
else:
test_files.append(str(file.instrumented_file_path))
pytest_cmd_list = shlex.split(pytest_cmd, posix=is_posix)
pytest_cmd_list = shlex.split(pytest_cmd, posix=IS_POSIX)
common_pytest_args = [
"--capture=tee-sys",
@ -62,6 +60,7 @@ def run_tests(
f"--codeflash_seconds={pytest_target_runtime_seconds}",
"--codeflash_loops_scope=session",
]
pytest_test_env = test_env.copy()
pytest_test_env["PYTEST_PLUGINS"] = "codeflash.verification.pytest_plugin"
@ -73,7 +72,7 @@ def run_tests(
coverage_args = ["--codeflash_min_loops=1", "--codeflash_max_loops=1"]
cov_erase = execute_test_subprocess(
shlex.split(f"{sys.executable} -m coverage erase"), cwd=cwd, env=pytest_test_env
shlex.split(f"{SAFE_SYS_EXECUTABLE} -m coverage erase"), cwd=cwd, env=pytest_test_env
) # this cleanup is necessary to avoid coverage data from previous runs, if there are any, then the current run will be appended to the previous data, which skews the results
logger.debug(cov_erase)
@ -84,7 +83,7 @@ def run_tests(
]
cov_run = execute_test_subprocess(
shlex.split(f"{sys.executable} -m coverage run --rcfile={coveragercfile} -m pytest")
shlex.split(f"{SAFE_SYS_EXECUTABLE} -m coverage run --rcfile={coveragercfile.as_posix()} -m pytest")
+ files
+ common_pytest_args
+ coverage_args
@ -95,13 +94,18 @@ def run_tests(
logger.debug(cov_run)
cov_report = execute_test_subprocess(
shlex.split(f"{sys.executable} -m coverage json --rcfile={coveragercfile}"),
shlex.split(f"{SAFE_SYS_EXECUTABLE} -m coverage json --rcfile={coveragercfile.as_posix()}"),
cwd=cwd,
env=pytest_test_env,
) # this will generate a json file with the coverage data
logger.debug(cov_report)
if "No data to report." in cov_report.stdout:
logger.warning("No coverage data to report. Check if the tests are running correctly.")
console.rule()
coverage_out_file = None
result_file_path = get_run_tmp_file(Path("pytest_results.xml"))
result_args = [f"--junitxml={result_file_path}", "-o", "junit_logging=all"]
result_args = [f"--junitxml={result_file_path.as_posix()}", "-o", "junit_logging=all"]
results = execute_test_subprocess(
pytest_cmd_list
@ -115,7 +119,7 @@ def run_tests(
)
elif test_framework == "unittest":
result_file_path = get_run_tmp_file(Path("unittest_results.xml"))
unittest_cmd_list = [sys.executable, "-m", "xmlrunner"]
unittest_cmd_list = [SAFE_SYS_EXECUTABLE, "-m", "xmlrunner"]
log_level = ["-v"] if verbose else []
files = [str(file.instrumented_file_path) for file in test_paths.test_files]
output_file = ["--output-file", str(result_file_path)]
@ -125,6 +129,7 @@ def run_tests(
)
else:
raise ValueError("Invalid test framework -- I only support Pytest and Unittest currently.")
msg = "Invalid test framework -- I only support Pytest and Unittest currently."
raise ValueError(msg)
return result_file_path, results, coverage_out_file if enable_coverage else None

View file

@ -3,12 +3,11 @@ from argparse import Namespace
from dataclasses import dataclass
import pytest
from returns.pipeline import is_successful
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import FunctionParent
from codeflash.optimization.function_context import get_function_variables_definitions
from codeflash.optimization.optimizer import Optimizer
from returns.pipeline import is_successful
def calculate_something(data):
@ -123,6 +122,10 @@ def simple_function_with_one_dep_ann(data: MyData):
return calculate_something_ann(data)
def list_comprehension_dependency(data: MyData):
    # Test fixture: calls calculate_something(data) from inside a list
    # comprehension, once per value of range(10); the loop variable is unused.
    # Used as the target function of test_list_comprehension_dependency, so the
    # call must stay inside the comprehension.
    return [calculate_something(data) for x in range(10)]
def test_simple_dependencies_ann() -> None:
file_path = pathlib.Path(__file__).resolve()
helper_functions = get_function_variables_definitions(
@ -305,3 +308,13 @@ def test_recursive_function_context() -> None:
return self.recursive(num) + num_1
"""
)
def test_list_comprehension_dependency() -> None:
    """Helpers referenced by ``list_comprehension_dependency`` are both reported.

    Runs ``get_function_variables_definitions`` against this very file and
    expects exactly two helpers, in order: ``MyData`` and ``calculate_something``.
    """
    this_file = pathlib.Path(__file__).resolve()
    target = FunctionToOptimize("list_comprehension_dependency", str(this_file), [])
    project_root = str(this_file.parent.resolve())

    # Only the first element of the returned value holds the helper list.
    helper_functions = get_function_variables_definitions(target, project_root)[0]

    assert len(helper_functions) == 2
    first_helper, second_helper = helper_functions[0], helper_functions[1]
    assert first_helper.jedi_definition.full_name == "test_function_dependencies.MyData"
    assert second_helper.jedi_definition.full_name == "test_function_dependencies.calculate_something"