mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
wrapped
This commit is contained in:
parent
7a6b204922
commit
1d3aeb997d
12 changed files with 35 additions and 40 deletions
|
|
@ -722,7 +722,6 @@ class AiServiceClient:
|
|||
function_trace_id: str,
|
||||
coverage_message: str,
|
||||
replay_tests: str,
|
||||
concolic_tests: str,
|
||||
calling_fn_details: str,
|
||||
) -> OptimizationReviewResult:
|
||||
"""Compute the optimization review of current Pull Request.
|
||||
|
|
|
|||
|
|
@ -200,7 +200,7 @@ class CodeFlashBenchmarkPlugin:
|
|||
|
||||
# Pytest hooks
|
||||
@pytest.hookimpl
|
||||
def pytest_sessionfinish(self, session, exitstatus) -> None:
|
||||
def pytest_sessionfinish(self, session, exitstatus) -> None: # noqa: ANN001
|
||||
"""Execute after whole test run is completed."""
|
||||
# Write any remaining benchmark timings to the database
|
||||
codeflash_trace.close()
|
||||
|
|
@ -218,7 +218,7 @@ class CodeFlashBenchmarkPlugin:
|
|||
skip_no_benchmark = pytest.mark.skip(reason="Test requires benchmark fixture")
|
||||
for item in items:
|
||||
# Check for direct benchmark fixture usage
|
||||
has_fixture = hasattr(item, "fixturenames") and "benchmark" in item.fixturenames
|
||||
has_fixture = hasattr(item, "fixturenames") and "benchmark" in item.fixturenames # ty:ignore[unsupported-operator]
|
||||
|
||||
# Check for @pytest.mark.benchmark marker
|
||||
has_marker = False
|
||||
|
|
@ -236,7 +236,7 @@ class CodeFlashBenchmarkPlugin:
|
|||
def __init__(self, request: pytest.FixtureRequest) -> None:
|
||||
self.request = request
|
||||
|
||||
def __call__(self, func, *args, **kwargs): # type: ignore # noqa: ANN002, ANN003, ANN204, PGH003
|
||||
def __call__(self, func, *args, **kwargs): # noqa: ANN001, ANN002, ANN003, ANN204
|
||||
"""Handle both direct function calls and decorator usage."""
|
||||
if args or kwargs:
|
||||
# Used as benchmark(func, *args, **kwargs)
|
||||
|
|
|
|||
|
|
@ -44,14 +44,14 @@ class GlobalFunctionCollector(cst.CSTVisitor):
|
|||
self.scope_depth += 1
|
||||
return True
|
||||
|
||||
def leave_FunctionDef(self, original_node: cst.FunctionDef) -> None: # noqa: ARG002
|
||||
def leave_FunctionDef(self, original_node: cst.FunctionDef) -> None:
|
||||
self.scope_depth -= 1
|
||||
|
||||
def visit_ClassDef(self, node: cst.ClassDef) -> Optional[bool]: # noqa: ARG002
|
||||
def visit_ClassDef(self, node: cst.ClassDef) -> Optional[bool]:
|
||||
self.scope_depth += 1
|
||||
return True
|
||||
|
||||
def leave_ClassDef(self, original_node: cst.ClassDef) -> None: # noqa: ARG002
|
||||
def leave_ClassDef(self, original_node: cst.ClassDef) -> None:
|
||||
self.scope_depth -= 1
|
||||
|
||||
|
||||
|
|
@ -65,7 +65,7 @@ class GlobalFunctionTransformer(cst.CSTTransformer):
|
|||
self.processed_functions: set[str] = set()
|
||||
self.scope_depth = 0
|
||||
|
||||
def visit_FunctionDef(self, node: cst.FunctionDef) -> None: # noqa: ARG002
|
||||
def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
|
||||
self.scope_depth += 1
|
||||
|
||||
def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef:
|
||||
|
|
@ -80,14 +80,14 @@ class GlobalFunctionTransformer(cst.CSTTransformer):
|
|||
return self.new_functions[name]
|
||||
return updated_node
|
||||
|
||||
def visit_ClassDef(self, node: cst.ClassDef) -> None: # noqa: ARG002
|
||||
def visit_ClassDef(self, node: cst.ClassDef) -> None:
|
||||
self.scope_depth += 1
|
||||
|
||||
def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef: # noqa: ARG002
|
||||
def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef:
|
||||
self.scope_depth -= 1
|
||||
return updated_node
|
||||
|
||||
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: # noqa: ARG002
|
||||
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
|
||||
# Add any new functions that weren't in the original file
|
||||
new_statements = list(updated_node.body)
|
||||
|
||||
|
|
@ -370,7 +370,7 @@ class GlobalStatementTransformer(cst.CSTTransformer):
|
|||
super().__init__()
|
||||
self.global_statements = global_statements
|
||||
|
||||
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: # noqa: ARG002
|
||||
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
|
||||
if not self.global_statements:
|
||||
return updated_node
|
||||
|
||||
|
|
@ -1553,10 +1553,7 @@ def is_numerical_code(code_string: str, function_name: str | None = None) -> boo
|
|||
|
||||
# If numba is not installed and all modules used require numba for optimization,
|
||||
# return False since we can't optimize this code
|
||||
if not has_numba and modules_used.issubset(NUMBA_REQUIRED_MODULES):
|
||||
return False
|
||||
|
||||
return True
|
||||
return not (not has_numba and modules_used.issubset(NUMBA_REQUIRED_MODULES))
|
||||
|
||||
|
||||
def get_opt_review_metrics(
|
||||
|
|
|
|||
|
|
@ -690,15 +690,14 @@ def detect_frameworks_from_code(code: str) -> dict[str, str]:
|
|||
frameworks["tensorflow"] = alias.asname if alias.asname else module_name
|
||||
elif module_name == "jax":
|
||||
frameworks["jax"] = alias.asname if alias.asname else module_name
|
||||
elif isinstance(node, ast.ImportFrom):
|
||||
if node.module:
|
||||
module_name = node.module.split(".")[0]
|
||||
if module_name == "torch" and "torch" not in frameworks:
|
||||
frameworks["torch"] = module_name
|
||||
elif module_name == "tensorflow" and "tensorflow" not in frameworks:
|
||||
frameworks["tensorflow"] = module_name
|
||||
elif module_name == "jax" and "jax" not in frameworks:
|
||||
frameworks["jax"] = module_name
|
||||
elif isinstance(node, ast.ImportFrom) and node.module:
|
||||
module_name = node.module.split(".")[0]
|
||||
if module_name == "torch" and "torch" not in frameworks:
|
||||
frameworks["torch"] = module_name
|
||||
elif module_name == "tensorflow" and "tensorflow" not in frameworks:
|
||||
frameworks["tensorflow"] = module_name
|
||||
elif module_name == "jax" and "jax" not in frameworks:
|
||||
frameworks["jax"] = module_name
|
||||
|
||||
return frameworks
|
||||
|
||||
|
|
|
|||
|
|
@ -30,21 +30,21 @@ def main() -> None:
|
|||
if args.config_file and Path.exists(args.config_file):
|
||||
pyproject_config, _ = parse_config_file(args.config_file)
|
||||
disable_telemetry = pyproject_config.get("disable_telemetry", False)
|
||||
init_sentry(not disable_telemetry, exclude_errors=True)
|
||||
posthog_cf.initialize_posthog(not disable_telemetry)
|
||||
init_sentry(enabled=not disable_telemetry, exclude_errors=True)
|
||||
posthog_cf.initialize_posthog(enabled=not disable_telemetry)
|
||||
args.func()
|
||||
elif args.verify_setup:
|
||||
args = process_pyproject_config(args)
|
||||
init_sentry(not args.disable_telemetry, exclude_errors=True)
|
||||
posthog_cf.initialize_posthog(not args.disable_telemetry)
|
||||
init_sentry(enabled=not args.disable_telemetry, exclude_errors=True)
|
||||
posthog_cf.initialize_posthog(enabled=not args.disable_telemetry)
|
||||
ask_run_end_to_end_test(args)
|
||||
else:
|
||||
args = process_pyproject_config(args)
|
||||
if not env_utils.check_formatter_installed(args.formatter_cmds):
|
||||
return
|
||||
args.previous_checkpoint_functions = ask_should_use_checkpoint_get_functions(args)
|
||||
init_sentry(not args.disable_telemetry, exclude_errors=True)
|
||||
posthog_cf.initialize_posthog(not args.disable_telemetry)
|
||||
init_sentry(enabled=not args.disable_telemetry, exclude_errors=True)
|
||||
posthog_cf.initialize_posthog(enabled=not args.disable_telemetry)
|
||||
|
||||
from codeflash.optimization import optimizer
|
||||
|
||||
|
|
|
|||
|
|
@ -2486,8 +2486,6 @@ class FunctionOptimizer:
|
|||
pytest_cmd=self.test_cfg.pytest_cmd,
|
||||
pytest_timeout=INDIVIDUAL_TESTCASE_TIMEOUT,
|
||||
pytest_target_runtime_seconds=testing_time,
|
||||
pytest_min_loops=1,
|
||||
pytest_max_loops=1,
|
||||
test_framework=self.test_cfg.test_framework,
|
||||
)
|
||||
elif testing_type == TestingMode.PERFORMANCE:
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ from codeflash.version import __version__
|
|||
_posthog = None
|
||||
|
||||
|
||||
def initialize_posthog(enabled: bool = True) -> None:
|
||||
def initialize_posthog(*, enabled: bool = True) -> None:
|
||||
"""Enable or disable PostHog.
|
||||
|
||||
:param enabled: Whether to enable PostHog.
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ import sentry_sdk
|
|||
from sentry_sdk.integrations.logging import LoggingIntegration
|
||||
|
||||
|
||||
def init_sentry(enabled: bool = False, exclude_errors: bool = False) -> None:
|
||||
def init_sentry(*, enabled: bool = False, exclude_errors: bool = False) -> None:
|
||||
if enabled:
|
||||
sentry_logging = LoggingIntegration(
|
||||
level=logging.INFO, # Capture info and above as breadcrumbs
|
||||
|
|
|
|||
|
|
@ -227,8 +227,8 @@ def main(args: Namespace | None = None) -> ArgumentParser:
|
|||
|
||||
args = process_pyproject_config(args)
|
||||
args.previous_checkpoint_functions = None
|
||||
init_sentry(not args.disable_telemetry, exclude_errors=True)
|
||||
posthog_cf.initialize_posthog(not args.disable_telemetry)
|
||||
init_sentry(enabled=not args.disable_telemetry, exclude_errors=True)
|
||||
posthog_cf.initialize_posthog(enabled=not args.disable_telemetry)
|
||||
|
||||
from codeflash.optimization import optimizer
|
||||
|
||||
|
|
|
|||
|
|
@ -76,6 +76,7 @@ class Tracer:
|
|||
config: dict,
|
||||
result_pickle_file_path: Path,
|
||||
functions: list[str] | None = None,
|
||||
*,
|
||||
disable: bool = False,
|
||||
project_root: Path | None = None,
|
||||
max_function_count: int = 256,
|
||||
|
|
|
|||
|
|
@ -54,11 +54,11 @@ def instrument_codeflash_capture(
|
|||
|
||||
|
||||
def add_codeflash_capture_to_init(
|
||||
target_classes: set[str], fto_name: str, tmp_dir_path: str, code: str, tests_root: Path, is_fto: bool = False
|
||||
target_classes: set[str], fto_name: str, tmp_dir_path: str, code: str, tests_root: Path, *, is_fto: bool = False
|
||||
) -> str:
|
||||
"""Add codeflash_capture decorator to __init__ function in the specified class."""
|
||||
tree = ast.parse(code)
|
||||
transformer = InitDecorator(target_classes, fto_name, tmp_dir_path, tests_root, is_fto)
|
||||
transformer = InitDecorator(target_classes, fto_name, tmp_dir_path, tests_root, is_fto=is_fto)
|
||||
modified_tree = transformer.visit(tree)
|
||||
if transformer.inserted_decorator:
|
||||
ast.fix_missing_locations(modified_tree)
|
||||
|
|
@ -71,7 +71,7 @@ class InitDecorator(ast.NodeTransformer):
|
|||
"""AST transformer that adds codeflash_capture decorator to specific class's __init__."""
|
||||
|
||||
def __init__(
|
||||
self, target_classes: set[str], fto_name: str, tmp_dir_path: str, tests_root: Path, is_fto=False
|
||||
self, target_classes: set[str], fto_name: str, tmp_dir_path: str, tests_root: Path, *, is_fto: bool = False
|
||||
) -> None:
|
||||
self.target_classes = target_classes
|
||||
self.fto_name = fto_name
|
||||
|
|
|
|||
|
|
@ -257,6 +257,7 @@ ignore = [
|
|||
"PLC0415",
|
||||
"UP045",
|
||||
"S110", # try-except-pass - we do this a lot
|
||||
"ARG002",
|
||||
]
|
||||
|
||||
[tool.ruff.lint.flake8-type-checking]
|
||||
|
|
|
|||
Loading…
Reference in a new issue