Merge branch 'main' into multi-language
This commit is contained in:
commit
84d0b1cf09
45 changed files with 156 additions and 194 deletions
|
|
@ -298,7 +298,7 @@ class AiServiceClient:
|
|||
line_profiler_results: str,
|
||||
n_candidates: int,
|
||||
experiment_metadata: ExperimentMetadata | None = None,
|
||||
is_numerical_code: bool | None = None, # noqa: FBT001
|
||||
is_numerical_code: bool | None = None,
|
||||
language: str = "python",
|
||||
language_version: str | None = None,
|
||||
) -> list[OptimizedCandidate]:
|
||||
|
|
@ -814,7 +814,7 @@ class AiServiceClient:
|
|||
error = response.json()["error"]
|
||||
logger.error(f"Error generating tests: {response.status_code} - {error}")
|
||||
ph("cli-testgen-error-response", {"response_status_code": response.status_code, "error": error})
|
||||
return None # noqa: TRY300
|
||||
return None
|
||||
except Exception:
|
||||
logger.error(f"Error generating tests: {response.status_code} - {response.text}")
|
||||
ph("cli-testgen-error-response", {"response_status_code": response.status_code, "error": response.text})
|
||||
|
|
@ -830,7 +830,6 @@ class AiServiceClient:
|
|||
function_trace_id: str,
|
||||
coverage_message: str,
|
||||
replay_tests: str,
|
||||
concolic_tests: str, # noqa: ARG002
|
||||
calling_fn_details: str,
|
||||
language: str = "python",
|
||||
) -> OptimizationReviewResult:
|
||||
|
|
|
|||
|
|
@ -81,7 +81,7 @@ def make_cfapi_request(
|
|||
else:
|
||||
response = requests.get(url, headers=cfapi_headers, params=params, timeout=60)
|
||||
response.raise_for_status()
|
||||
return response # noqa: TRY300
|
||||
return response
|
||||
except requests.exceptions.HTTPError:
|
||||
# response may be either a string or JSON, so we handle both cases
|
||||
error_message = ""
|
||||
|
|
@ -102,7 +102,7 @@ def make_cfapi_request(
|
|||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_user_id(api_key: Optional[str] = None) -> Optional[str]: # noqa: PLR0911
|
||||
def get_user_id(api_key: Optional[str] = None) -> Optional[str]:
|
||||
"""Retrieve the user's userid by making a request to the /cfapi/cli-get-user endpoint.
|
||||
|
||||
:param api_key: The API key to use. If None, uses get_codeflash_api_key().
|
||||
|
|
@ -396,7 +396,7 @@ def get_blocklisted_functions() -> dict[str, set[str]] | dict[str, Any]:
|
|||
|
||||
def is_function_being_optimized_again(
|
||||
owner: str, repo: str, pr_number: int, code_contexts: list[dict[str, str]]
|
||||
) -> Any: # noqa: ANN401
|
||||
) -> Any:
|
||||
"""Check if the function being optimized is being optimized again."""
|
||||
response = make_cfapi_request(
|
||||
"/is-already-optimized",
|
||||
|
|
|
|||
|
|
@ -108,7 +108,7 @@ class CodeflashTrace:
|
|||
func_id = (func.__module__, func.__name__)
|
||||
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs) -> Any: # noqa: ANN002, ANN003, ANN401
|
||||
def wrapper(*args, **kwargs) -> Any: # noqa: ANN002, ANN003
|
||||
# Initialize thread-local active functions set if it doesn't exist
|
||||
if not hasattr(self._thread_local, "active_functions"):
|
||||
self._thread_local.active_functions = set()
|
||||
|
|
|
|||
|
|
@ -53,7 +53,7 @@ class AddDecoratorTransformer(cst.CSTTransformer):
|
|||
|
||||
return updated_node
|
||||
|
||||
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: # noqa: ARG002
|
||||
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
|
||||
# Create import statement for codeflash_trace
|
||||
if not self.added_codeflash_trace:
|
||||
return updated_node
|
||||
|
|
|
|||
|
|
@ -200,7 +200,7 @@ class CodeFlashBenchmarkPlugin:
|
|||
|
||||
# Pytest hooks
|
||||
@pytest.hookimpl
|
||||
def pytest_sessionfinish(self, session, exitstatus) -> None: # noqa: ANN001, ARG002
|
||||
def pytest_sessionfinish(self, session, exitstatus) -> None: # noqa: ANN001
|
||||
"""Execute after whole test run is completed."""
|
||||
# Write any remaining benchmark timings to the database
|
||||
codeflash_trace.close()
|
||||
|
|
@ -218,7 +218,7 @@ class CodeFlashBenchmarkPlugin:
|
|||
skip_no_benchmark = pytest.mark.skip(reason="Test requires benchmark fixture")
|
||||
for item in items:
|
||||
# Check for direct benchmark fixture usage
|
||||
has_fixture = hasattr(item, "fixturenames") and "benchmark" in item.fixturenames
|
||||
has_fixture = hasattr(item, "fixturenames") and "benchmark" in item.fixturenames # ty:ignore[unsupported-operator]
|
||||
|
||||
# Check for @pytest.mark.benchmark marker
|
||||
has_marker = False
|
||||
|
|
@ -249,7 +249,7 @@ class CodeFlashBenchmarkPlugin:
|
|||
self._run_benchmark(func)
|
||||
return wrapped_func
|
||||
|
||||
def _run_benchmark(self, func, *args, **kwargs): # noqa: ANN001, ANN002, ANN003, ANN202
|
||||
def _run_benchmark(self, func, *args, **kwargs): # noqa: ANN002, ANN003, ANN202
|
||||
"""Actual benchmark implementation."""
|
||||
node_path = getattr(self.request.node, "path", None) or getattr(self.request.node, "fspath", None)
|
||||
if node_path is None:
|
||||
|
|
|
|||
|
|
@ -43,7 +43,7 @@ def inquirer_wrapper(func: Callable[..., str | bool], *args: str | bool, **kwarg
|
|||
return func(*new_args, **new_kwargs)
|
||||
|
||||
|
||||
def split_string_to_cli_width(string: str, is_confirm: bool = False) -> list[str]: # noqa: FBT001, FBT002
|
||||
def split_string_to_cli_width(string: str, is_confirm: bool = False) -> list[str]:
|
||||
cli_width, _ = shutil.get_terminal_size()
|
||||
# split string to lines that accommodate "[?] " prefix
|
||||
cli_width -= len("[?] ")
|
||||
|
|
|
|||
|
|
@ -687,7 +687,7 @@ def create_empty_pyproject_toml(pyproject_toml_path: Path) -> None:
|
|||
apologize_and_exit()
|
||||
|
||||
|
||||
def install_github_actions(override_formatter_check: bool = False) -> None: # noqa: FBT001, FBT002
|
||||
def install_github_actions(override_formatter_check: bool = False) -> None:
|
||||
try:
|
||||
config, _config_file_path = parse_config_file(override_formatter_check=override_formatter_check)
|
||||
|
||||
|
|
@ -1121,11 +1121,12 @@ def install_github_actions(override_formatter_check: bool = False) -> None: # n
|
|||
apologize_and_exit()
|
||||
|
||||
|
||||
def determine_dependency_manager(pyproject_data: dict[str, Any]) -> DependencyManager: # noqa: PLR0911
|
||||
def determine_dependency_manager(pyproject_data: dict[str, Any]) -> DependencyManager:
|
||||
"""Determine which dependency manager is being used based on pyproject.toml contents."""
|
||||
if (Path.cwd() / "poetry.lock").exists():
|
||||
cwd = Path.cwd()
|
||||
if (cwd / "poetry.lock").exists():
|
||||
return DependencyManager.POETRY
|
||||
if (Path.cwd() / "uv.lock").exists():
|
||||
if (cwd / "uv.lock").exists():
|
||||
return DependencyManager.UV
|
||||
if "tool" not in pyproject_data:
|
||||
return DependencyManager.PIP
|
||||
|
|
@ -1325,10 +1326,7 @@ def collect_repo_files_for_workflow(git_root: Path) -> dict[str, Any]:
|
|||
|
||||
|
||||
def generate_dynamic_workflow_content(
|
||||
optimize_yml_content: str,
|
||||
config: tuple[dict[str, Any], Path],
|
||||
git_root: Path,
|
||||
benchmark_mode: bool = False, # noqa: FBT001, FBT002
|
||||
optimize_yml_content: str, config: tuple[dict[str, Any], Path], git_root: Path, benchmark_mode: bool = False
|
||||
) -> str:
|
||||
"""Generate workflow content with dynamic steps from AI service, falling back to static template.
|
||||
|
||||
|
|
@ -1460,10 +1458,7 @@ def generate_dynamic_workflow_content(
|
|||
|
||||
|
||||
def customize_codeflash_yaml_content(
|
||||
optimize_yml_content: str,
|
||||
config: tuple[dict[str, Any], Path],
|
||||
git_root: Path,
|
||||
benchmark_mode: bool = False, # noqa: FBT001, FBT002
|
||||
optimize_yml_content: str, config: tuple[dict[str, Any], Path], git_root: Path, benchmark_mode: bool = False
|
||||
) -> str:
|
||||
module_path = str(Path(config["module_root"]).relative_to(git_root) / "**")
|
||||
optimize_yml_content = optimize_yml_content.replace("{{ codeflash_module_path }}", module_path)
|
||||
|
|
|
|||
|
|
@ -45,14 +45,14 @@ class GlobalFunctionCollector(cst.CSTVisitor):
|
|||
self.scope_depth += 1
|
||||
return True
|
||||
|
||||
def leave_FunctionDef(self, original_node: cst.FunctionDef) -> None: # noqa: ARG002
|
||||
def leave_FunctionDef(self, original_node: cst.FunctionDef) -> None:
|
||||
self.scope_depth -= 1
|
||||
|
||||
def visit_ClassDef(self, node: cst.ClassDef) -> Optional[bool]: # noqa: ARG002
|
||||
def visit_ClassDef(self, node: cst.ClassDef) -> Optional[bool]:
|
||||
self.scope_depth += 1
|
||||
return True
|
||||
|
||||
def leave_ClassDef(self, original_node: cst.ClassDef) -> None: # noqa: ARG002
|
||||
def leave_ClassDef(self, original_node: cst.ClassDef) -> None:
|
||||
self.scope_depth -= 1
|
||||
|
||||
|
||||
|
|
@ -66,7 +66,7 @@ class GlobalFunctionTransformer(cst.CSTTransformer):
|
|||
self.processed_functions: set[str] = set()
|
||||
self.scope_depth = 0
|
||||
|
||||
def visit_FunctionDef(self, node: cst.FunctionDef) -> None: # noqa: ARG002
|
||||
def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
|
||||
self.scope_depth += 1
|
||||
|
||||
def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef:
|
||||
|
|
@ -81,14 +81,14 @@ class GlobalFunctionTransformer(cst.CSTTransformer):
|
|||
return self.new_functions[name]
|
||||
return updated_node
|
||||
|
||||
def visit_ClassDef(self, node: cst.ClassDef) -> None: # noqa: ARG002
|
||||
def visit_ClassDef(self, node: cst.ClassDef) -> None:
|
||||
self.scope_depth += 1
|
||||
|
||||
def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef: # noqa: ARG002
|
||||
def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef:
|
||||
self.scope_depth -= 1
|
||||
return updated_node
|
||||
|
||||
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: # noqa: ARG002
|
||||
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
|
||||
# Add any new functions that weren't in the original file
|
||||
new_statements = list(updated_node.body)
|
||||
|
||||
|
|
@ -142,28 +142,28 @@ class GlobalAssignmentCollector(cst.CSTVisitor):
|
|||
self.scope_depth = 0
|
||||
self.if_else_depth = 0
|
||||
|
||||
def visit_FunctionDef(self, node: cst.FunctionDef) -> Optional[bool]: # noqa: ARG002
|
||||
def visit_FunctionDef(self, node: cst.FunctionDef) -> Optional[bool]:
|
||||
self.scope_depth += 1
|
||||
return True
|
||||
|
||||
def leave_FunctionDef(self, original_node: cst.FunctionDef) -> None: # noqa: ARG002
|
||||
def leave_FunctionDef(self, original_node: cst.FunctionDef) -> None:
|
||||
self.scope_depth -= 1
|
||||
|
||||
def visit_ClassDef(self, node: cst.ClassDef) -> Optional[bool]: # noqa: ARG002
|
||||
def visit_ClassDef(self, node: cst.ClassDef) -> Optional[bool]:
|
||||
self.scope_depth += 1
|
||||
return True
|
||||
|
||||
def leave_ClassDef(self, original_node: cst.ClassDef) -> None: # noqa: ARG002
|
||||
def leave_ClassDef(self, original_node: cst.ClassDef) -> None:
|
||||
self.scope_depth -= 1
|
||||
|
||||
def visit_If(self, node: cst.If) -> Optional[bool]: # noqa: ARG002
|
||||
def visit_If(self, node: cst.If) -> Optional[bool]:
|
||||
self.if_else_depth += 1
|
||||
return True
|
||||
|
||||
def leave_If(self, original_node: cst.If) -> None: # noqa: ARG002
|
||||
def leave_If(self, original_node: cst.If) -> None:
|
||||
self.if_else_depth -= 1
|
||||
|
||||
def visit_Else(self, node: cst.Else) -> Optional[bool]: # noqa: ARG002
|
||||
def visit_Else(self, node: cst.Else) -> Optional[bool]:
|
||||
# Else blocks are already counted as part of the if statement
|
||||
return True
|
||||
|
||||
|
|
@ -232,24 +232,24 @@ class GlobalAssignmentTransformer(cst.CSTTransformer):
|
|||
self.scope_depth = 0
|
||||
self.if_else_depth = 0
|
||||
|
||||
def visit_FunctionDef(self, node: cst.FunctionDef) -> None: # noqa: ARG002
|
||||
def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
|
||||
self.scope_depth += 1
|
||||
|
||||
def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef: # noqa: ARG002
|
||||
def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef:
|
||||
self.scope_depth -= 1
|
||||
return updated_node
|
||||
|
||||
def visit_ClassDef(self, node: cst.ClassDef) -> None: # noqa: ARG002
|
||||
def visit_ClassDef(self, node: cst.ClassDef) -> None:
|
||||
self.scope_depth += 1
|
||||
|
||||
def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef: # noqa: ARG002
|
||||
def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef:
|
||||
self.scope_depth -= 1
|
||||
return updated_node
|
||||
|
||||
def visit_If(self, node: cst.If) -> None: # noqa: ARG002
|
||||
def visit_If(self, node: cst.If) -> None:
|
||||
self.if_else_depth += 1
|
||||
|
||||
def leave_If(self, original_node: cst.If, updated_node: cst.If) -> cst.If: # noqa: ARG002
|
||||
def leave_If(self, original_node: cst.If, updated_node: cst.If) -> cst.If:
|
||||
self.if_else_depth -= 1
|
||||
return updated_node
|
||||
|
||||
|
|
@ -284,7 +284,7 @@ class GlobalAssignmentTransformer(cst.CSTTransformer):
|
|||
|
||||
return updated_node
|
||||
|
||||
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: # noqa: ARG002
|
||||
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
|
||||
# Add any new assignments that weren't in the original file
|
||||
new_statements = list(updated_node.body)
|
||||
|
||||
|
|
@ -371,7 +371,7 @@ class GlobalStatementTransformer(cst.CSTTransformer):
|
|||
super().__init__()
|
||||
self.global_statements = global_statements
|
||||
|
||||
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: # noqa: ARG002
|
||||
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
|
||||
if not self.global_statements:
|
||||
return updated_node
|
||||
|
||||
|
|
@ -397,20 +397,20 @@ class GlobalStatementCollector(cst.CSTVisitor):
|
|||
self.global_statements = []
|
||||
self.in_function_or_class = False
|
||||
|
||||
def visit_ClassDef(self, node: cst.ClassDef) -> bool: # noqa: ARG002
|
||||
def visit_ClassDef(self, node: cst.ClassDef) -> bool:
|
||||
# Don't visit inside classes
|
||||
self.in_function_or_class = True
|
||||
return False
|
||||
|
||||
def leave_ClassDef(self, original_node: cst.ClassDef) -> None: # noqa: ARG002
|
||||
def leave_ClassDef(self, original_node: cst.ClassDef) -> None:
|
||||
self.in_function_or_class = False
|
||||
|
||||
def visit_FunctionDef(self, node: cst.FunctionDef) -> bool: # noqa: ARG002
|
||||
def visit_FunctionDef(self, node: cst.FunctionDef) -> bool:
|
||||
# Don't visit inside functions
|
||||
self.in_function_or_class = True
|
||||
return False
|
||||
|
||||
def leave_FunctionDef(self, original_node: cst.FunctionDef) -> None: # noqa: ARG002
|
||||
def leave_FunctionDef(self, original_node: cst.FunctionDef) -> None:
|
||||
self.in_function_or_class = False
|
||||
|
||||
def visit_SimpleStatementLine(self, node: cst.SimpleStatementLine) -> None:
|
||||
|
|
@ -491,16 +491,16 @@ class DottedImportCollector(cst.CSTVisitor):
|
|||
self.depth = 0
|
||||
self._collect_imports_from_block(node)
|
||||
|
||||
def visit_FunctionDef(self, node: cst.FunctionDef) -> None: # noqa: ARG002
|
||||
def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
|
||||
self.depth += 1
|
||||
|
||||
def leave_FunctionDef(self, node: cst.FunctionDef) -> None: # noqa: ARG002
|
||||
def leave_FunctionDef(self, node: cst.FunctionDef) -> None:
|
||||
self.depth -= 1
|
||||
|
||||
def visit_ClassDef(self, node: cst.ClassDef) -> None: # noqa: ARG002
|
||||
def visit_ClassDef(self, node: cst.ClassDef) -> None:
|
||||
self.depth += 1
|
||||
|
||||
def leave_ClassDef(self, node: cst.ClassDef) -> None: # noqa: ARG002
|
||||
def leave_ClassDef(self, node: cst.ClassDef) -> None:
|
||||
self.depth -= 1
|
||||
|
||||
def visit_If(self, node: cst.If) -> None:
|
||||
|
|
@ -530,9 +530,7 @@ def find_last_import_line(target_code: str) -> int:
|
|||
|
||||
class FutureAliasedImportTransformer(cst.CSTTransformer):
|
||||
def leave_ImportFrom(
|
||||
self,
|
||||
original_node: cst.ImportFrom, # noqa: ARG002
|
||||
updated_node: cst.ImportFrom,
|
||||
self, original_node: cst.ImportFrom, updated_node: cst.ImportFrom
|
||||
) -> cst.BaseSmallStatement | cst.FlattenSentinel[cst.BaseSmallStatement] | cst.RemovalSentinel:
|
||||
import libcst.matchers as m
|
||||
|
||||
|
|
@ -677,7 +675,7 @@ def resolve_star_import(module_name: str, project_root: Path) -> set[str]:
|
|||
if not name.startswith("_"):
|
||||
public_names.add(name)
|
||||
|
||||
return public_names # noqa: TRY300
|
||||
return public_names
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Error resolving star import for {module_name}: {e}")
|
||||
|
|
@ -1166,7 +1164,7 @@ class FunctionCallFinder(ast.NodeVisitor):
|
|||
|
||||
return False
|
||||
|
||||
def _get_call_name(self, func_node) -> Optional[str]: # noqa: ANN001
|
||||
def _get_call_name(self, func_node) -> Optional[str]:
|
||||
"""Extract the name being called from a function node."""
|
||||
# Fast path short-circuit for ast.Name nodes
|
||||
if isinstance(func_node, ast.Name):
|
||||
|
|
@ -1559,10 +1557,7 @@ def is_numerical_code(code_string: str, function_name: str | None = None) -> boo
|
|||
|
||||
# If numba is not installed and all modules used require numba for optimization,
|
||||
# return False since we can't optimize this code
|
||||
if not has_numba and modules_used.issubset(NUMBA_REQUIRED_MODULES): # noqa : SIM103
|
||||
return False
|
||||
|
||||
return True
|
||||
return not (not has_numba and modules_used.issubset(NUMBA_REQUIRED_MODULES))
|
||||
|
||||
|
||||
def get_opt_review_metrics(
|
||||
|
|
|
|||
|
|
@ -166,7 +166,7 @@ def filter_args(addopts_args: list[str]) -> list[str]:
|
|||
return filtered_args
|
||||
|
||||
|
||||
def modify_addopts(config_file: Path) -> tuple[str, bool]: # noqa : PLR0911
|
||||
def modify_addopts(config_file: Path) -> tuple[str, bool]:
|
||||
file_type = config_file.suffix.lower()
|
||||
filename = config_file.name
|
||||
config = None
|
||||
|
|
|
|||
|
|
@ -46,7 +46,7 @@ def extract_test_context_from_env() -> tuple[str, str | None, str]:
|
|||
|
||||
def codeflash_behavior_async(func: F) -> F:
|
||||
@wraps(func)
|
||||
async def async_wrapper(*args: Any, **kwargs: Any) -> Any: # noqa: ANN401
|
||||
async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
loop = asyncio.get_running_loop()
|
||||
function_name = func.__name__
|
||||
line_id = os.environ["CODEFLASH_CURRENT_LINE_ID"]
|
||||
|
|
@ -122,7 +122,7 @@ def codeflash_behavior_async(func: F) -> F:
|
|||
|
||||
def codeflash_performance_async(func: F) -> F:
|
||||
@wraps(func)
|
||||
async def async_wrapper(*args: Any, **kwargs: Any) -> Any: # noqa: ANN401
|
||||
async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
loop = asyncio.get_running_loop()
|
||||
function_name = func.__name__
|
||||
line_id = os.environ["CODEFLASH_CURRENT_LINE_ID"]
|
||||
|
|
@ -172,7 +172,7 @@ def codeflash_concurrency_async(func: F) -> F:
|
|||
"""Measures concurrent vs sequential execution performance for async functions."""
|
||||
|
||||
@wraps(func)
|
||||
async def async_wrapper(*args: Any, **kwargs: Any) -> Any: # noqa: ANN401
|
||||
async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
function_name = func.__name__
|
||||
concurrency_factor = int(os.environ.get("CODEFLASH_CONCURRENCY_FACTOR", "10"))
|
||||
|
||||
|
|
|
|||
|
|
@ -91,7 +91,7 @@ EFFORT_VALUES: dict[str, dict[EffortLevel, Any]] = {
|
|||
}
|
||||
|
||||
|
||||
def get_effort_value(key: EffortKeys, effort: Union[EffortLevel, str]) -> Any: # noqa: ANN401
|
||||
def get_effort_value(key: EffortKeys, effort: Union[EffortLevel, str]) -> Any:
|
||||
key_str = key.value
|
||||
|
||||
if isinstance(effort, str):
|
||||
|
|
|
|||
|
|
@ -86,8 +86,7 @@ def find_conftest_files(test_paths: list[Path]) -> list[Path]:
|
|||
|
||||
# TODO for claude: There should be different functions to parse it per language, which should be chosen during runtime
|
||||
def parse_config_file(
|
||||
config_file_path: Path | None = None,
|
||||
override_formatter_check: bool = False, # noqa: FBT001, FBT002
|
||||
config_file_path: Path | None = None, override_formatter_check: bool = False
|
||||
) -> tuple[dict[str, Any], Path]:
|
||||
# First try package.json for JS/TS projects
|
||||
package_json_path = find_package_json(config_file_path)
|
||||
|
|
|
|||
|
|
@ -15,8 +15,8 @@ from codeflash.languages import current_language, is_python
|
|||
|
||||
def normalize_code(
|
||||
code: str,
|
||||
remove_docstrings: bool = True, # noqa: FBT001, FBT002
|
||||
return_ast_dump: bool = False, # noqa: FBT001, FBT002
|
||||
remove_docstrings: bool = True,
|
||||
return_ast_dump: bool = False,
|
||||
language: str | None = None,
|
||||
) -> str:
|
||||
"""Normalize code by parsing, cleaning, and normalizing variable names.
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ from codeflash.code_utils.shell_utils import read_api_key_from_shell_config, sav
|
|||
from codeflash.lsp.helpers import is_LSP_enabled
|
||||
|
||||
|
||||
def check_formatter_installed(formatter_cmds: list[str], exit_on_failure: bool = True) -> bool: # noqa
|
||||
def check_formatter_installed(formatter_cmds: list[str], exit_on_failure: bool = True) -> bool:
|
||||
if not formatter_cmds or formatter_cmds[0] == "disabled":
|
||||
return True
|
||||
first_cmd = formatter_cmds[0]
|
||||
|
|
|
|||
|
|
@ -40,11 +40,7 @@ def generate_unified_diff(original: str, modified: str, from_file: str, to_file:
|
|||
|
||||
|
||||
def apply_formatter_cmds(
|
||||
cmds: list[str],
|
||||
path: Path,
|
||||
test_dir_str: Optional[str],
|
||||
print_status: bool, # noqa
|
||||
exit_on_failure: bool = True, # noqa
|
||||
cmds: list[str], path: Path, test_dir_str: Optional[str], print_status: bool, exit_on_failure: bool = True
|
||||
) -> tuple[Path, str, bool]:
|
||||
if not path.exists():
|
||||
msg = f"File {path} does not exist. Cannot apply formatter commands."
|
||||
|
|
@ -111,9 +107,9 @@ def format_code(
|
|||
formatter_cmds: list[str],
|
||||
path: Union[str, Path],
|
||||
optimized_code: str = "",
|
||||
check_diff: bool = False, # noqa
|
||||
print_status: bool = True, # noqa
|
||||
exit_on_failure: bool = True, # noqa
|
||||
check_diff: bool = False,
|
||||
print_status: bool = True,
|
||||
exit_on_failure: bool = True,
|
||||
) -> str:
|
||||
if is_LSP_enabled():
|
||||
exit_on_failure = False
|
||||
|
|
@ -174,7 +170,7 @@ def format_code(
|
|||
return formatted_code
|
||||
|
||||
|
||||
def sort_imports(code: str, **kwargs: Any) -> str: # noqa : ANN401
|
||||
def sort_imports(code: str, **kwargs: Any) -> str:
|
||||
try:
|
||||
# Deduplicate and sort imports, modify the code in memory, not on disk
|
||||
sorted_code = isort.code(code, **kwargs)
|
||||
|
|
|
|||
|
|
@ -89,7 +89,7 @@ class InjectPerfOnly(ast.NodeTransformer):
|
|||
# it's much more efficient to visit nodes manually. We'll only descend into expressions/statements.
|
||||
|
||||
# Helper for manual walk
|
||||
def iter_ast_calls(node): # noqa: ANN202, ANN001
|
||||
def iter_ast_calls(node): # noqa: ANN202
|
||||
# Generator to yield each ast.Call in test_node, preserves node identity
|
||||
stack = [node]
|
||||
while stack:
|
||||
|
|
@ -690,15 +690,14 @@ def detect_frameworks_from_code(code: str) -> dict[str, str]:
|
|||
frameworks["tensorflow"] = alias.asname if alias.asname else module_name
|
||||
elif module_name == "jax":
|
||||
frameworks["jax"] = alias.asname if alias.asname else module_name
|
||||
elif isinstance(node, ast.ImportFrom): # noqa: SIM102
|
||||
if node.module:
|
||||
module_name = node.module.split(".")[0]
|
||||
if module_name == "torch" and "torch" not in frameworks:
|
||||
frameworks["torch"] = module_name
|
||||
elif module_name == "tensorflow" and "tensorflow" not in frameworks:
|
||||
frameworks["tensorflow"] = module_name
|
||||
elif module_name == "jax" and "jax" not in frameworks:
|
||||
frameworks["jax"] = module_name
|
||||
elif isinstance(node, ast.ImportFrom) and node.module:
|
||||
module_name = node.module.split(".")[0]
|
||||
if module_name == "torch" and "torch" not in frameworks:
|
||||
frameworks["torch"] = module_name
|
||||
elif module_name == "tensorflow" and "tensorflow" not in frameworks:
|
||||
frameworks["tensorflow"] = module_name
|
||||
elif module_name == "jax" and "jax" not in frameworks:
|
||||
frameworks["jax"] = module_name
|
||||
|
||||
return frameworks
|
||||
|
||||
|
|
@ -910,8 +909,7 @@ def _create_device_sync_precompute_statements(used_frameworks: dict[str, str] |
|
|||
|
||||
|
||||
def _create_device_sync_statements(
|
||||
used_frameworks: dict[str, str] | None,
|
||||
for_return_value: bool = False, # noqa: FBT001, FBT002
|
||||
used_frameworks: dict[str, str] | None, for_return_value: bool = False
|
||||
) -> list[ast.stmt]:
|
||||
"""Create AST statements for device synchronization using pre-computed conditions.
|
||||
|
||||
|
|
@ -1450,7 +1448,7 @@ class AsyncDecoratorAdder(cst.CSTTransformer):
|
|||
# Track when we enter a class
|
||||
self.context_stack.append(node.name.value)
|
||||
|
||||
def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef: # noqa: ARG002
|
||||
def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef:
|
||||
# Pop the context when we leave a class
|
||||
self.context_stack.pop()
|
||||
return updated_node
|
||||
|
|
@ -1530,7 +1528,7 @@ class AsyncDecoratorImportAdder(cst.CSTTransformer):
|
|||
if import_alias.name.value == decorator_name:
|
||||
self.has_import = True
|
||||
|
||||
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: # noqa: ARG002
|
||||
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
|
||||
# If the import is already there, don't add it again
|
||||
if self.has_import:
|
||||
return updated_node
|
||||
|
|
|
|||
|
|
@ -204,7 +204,7 @@ class LineProfilerDecoratorAdder(cst.CSTTransformer):
|
|||
# Track when we enter a class
|
||||
self.context_stack.append(node.name.value)
|
||||
|
||||
def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef: # noqa: ARG002
|
||||
def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef:
|
||||
# Pop the context when we leave a class
|
||||
self.context_stack.pop()
|
||||
return updated_node
|
||||
|
|
@ -268,7 +268,7 @@ class ProfileEnableTransformer(cst.CSTTransformer):
|
|||
|
||||
return updated_node
|
||||
|
||||
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: # noqa: ARG002
|
||||
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
|
||||
if not self.found_import:
|
||||
return updated_node
|
||||
|
||||
|
|
@ -332,11 +332,11 @@ def add_profile_enable(original_code: str, line_profile_output_file: str) -> str
|
|||
|
||||
|
||||
class ImportAdder(cst.CSTTransformer):
|
||||
def __init__(self, import_statement) -> None: # noqa: ANN001
|
||||
def __init__(self, import_statement) -> None:
|
||||
self.import_statement = import_statement
|
||||
self.has_import = False
|
||||
|
||||
def leave_Module(self, original_node, updated_node): # noqa: ANN001, ANN201, ARG002
|
||||
def leave_Module(self, original_node, updated_node): # noqa: ANN201
|
||||
# If the import is already there, don't add it again
|
||||
if self.has_import:
|
||||
return updated_node
|
||||
|
|
@ -347,7 +347,7 @@ class ImportAdder(cst.CSTTransformer):
|
|||
# Add the import to the module's body
|
||||
return updated_node.with_changes(body=[import_node, *list(updated_node.body)])
|
||||
|
||||
def visit_ImportFrom(self, node) -> None: # noqa: ANN001
|
||||
def visit_ImportFrom(self, node) -> None:
|
||||
# Check if the profile is already imported from line_profiler
|
||||
if node.module and node.module.value == "line_profiler":
|
||||
for import_alias in node.names:
|
||||
|
|
|
|||
|
|
@ -702,7 +702,7 @@ def _wait_for_manual_code_input(oauth: OAuthHandler) -> None:
|
|||
if not oauth.is_complete:
|
||||
oauth.manual_code = code.strip()
|
||||
oauth.is_complete = True
|
||||
except Exception: # noqa: S110
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -242,9 +242,9 @@ def get_cross_platform_subprocess_run_args(
|
|||
cwd: Path | str | None = None,
|
||||
env: Mapping[str, str] | None = None,
|
||||
timeout: Optional[float] = None,
|
||||
check: bool = False, # noqa: FBT001, FBT002
|
||||
text: bool = True, # noqa: FBT001, FBT002
|
||||
capture_output: bool = True, # noqa: FBT001, FBT002 (only for non-Windows)
|
||||
check: bool = False,
|
||||
text: bool = True,
|
||||
capture_output: bool = True,
|
||||
) -> dict[str, str]:
|
||||
run_args = {"cwd": cwd, "env": env, "text": text, "timeout": timeout, "check": check}
|
||||
if sys.platform == "win32":
|
||||
|
|
|
|||
|
|
@ -39,7 +39,7 @@ def get_latest_version_from_pypi() -> str | None:
|
|||
|
||||
return latest_version
|
||||
logger.debug(f"Failed to fetch version from PyPI: {response.status_code}")
|
||||
return None # noqa: TRY300
|
||||
return None
|
||||
except requests.RequestException as e:
|
||||
logger.debug(f"Network error fetching version from PyPI: {e}")
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -363,7 +363,7 @@ def extract_code_markdown_context_from_files(
|
|||
helpers_of_fto: dict[Path, set[FunctionSource]],
|
||||
helpers_of_helpers: dict[Path, set[FunctionSource]],
|
||||
project_root_path: Path,
|
||||
remove_docstrings: bool = False, # noqa: FBT001, FBT002
|
||||
remove_docstrings: bool = False,
|
||||
code_context_type: CodeContextType = CodeContextType.READ_ONLY,
|
||||
) -> CodeStringsMarkdown:
|
||||
"""Extract code context from files containing target functions and their helpers, formatting them as markdown.
|
||||
|
|
@ -1008,7 +1008,7 @@ def parse_code_and_prune_cst(
|
|||
code_context_type: CodeContextType,
|
||||
target_functions: set[str],
|
||||
helpers_of_helper_functions: set[str] = set(), # noqa: B006
|
||||
remove_docstrings: bool = False, # noqa: FBT001, FBT002
|
||||
remove_docstrings: bool = False,
|
||||
) -> str:
|
||||
"""Create a read-only version of the code by parsing and filtering the code to keep only class contextual information, and other module scoped variables."""
|
||||
module = cst.parse_module(code)
|
||||
|
|
@ -1049,7 +1049,7 @@ def parse_code_and_prune_cst(
|
|||
return ""
|
||||
|
||||
|
||||
def prune_cst_for_read_writable_code( # noqa: PLR0911
|
||||
def prune_cst_for_read_writable_code(
|
||||
node: cst.CSTNode, target_functions: set[str], defs_with_usages: dict[str, UsageInfo], prefix: str = ""
|
||||
) -> tuple[cst.CSTNode | None, bool]:
|
||||
"""Recursively filter the node and its children to build the read-writable codeblock. This contains nodes that lead to target functions.
|
||||
|
|
@ -1167,7 +1167,7 @@ def prune_cst_for_read_writable_code( # noqa: PLR0911
|
|||
return (node.with_changes(**updates) if updates else node), True
|
||||
|
||||
|
||||
def prune_cst_for_code_hashing( # noqa: PLR0911
|
||||
def prune_cst_for_code_hashing(
|
||||
node: cst.CSTNode, target_functions: set[str], prefix: str = ""
|
||||
) -> tuple[cst.CSTNode | None, bool]:
|
||||
"""Recursively filter the node and its children to build the read-writable codeblock. This contains nodes that lead to target functions.
|
||||
|
|
@ -1256,14 +1256,14 @@ def prune_cst_for_code_hashing( # noqa: PLR0911
|
|||
return (node.with_changes(**updates) if updates else node), True
|
||||
|
||||
|
||||
def prune_cst_for_context( # noqa: PLR0911
|
||||
def prune_cst_for_context(
|
||||
node: cst.CSTNode,
|
||||
target_functions: set[str],
|
||||
helpers_of_helper_functions: set[str],
|
||||
prefix: str = "",
|
||||
remove_docstrings: bool = False, # noqa: FBT001, FBT002
|
||||
include_target_in_output: bool = False, # noqa: FBT001, FBT002
|
||||
include_init_dunder: bool = False, # noqa: FBT001, FBT002
|
||||
remove_docstrings: bool = False,
|
||||
include_target_in_output: bool = False,
|
||||
include_init_dunder: bool = False,
|
||||
) -> tuple[cst.CSTNode | None, bool]:
|
||||
"""Recursively filter the node for code context extraction.
|
||||
|
||||
|
|
|
|||
|
|
@ -209,7 +209,7 @@ class DependencyCollector(cst.CSTVisitor):
|
|||
self._extract_names_from_annotation(node.value)
|
||||
# No need to check the attribute name itself as it's likely not a top-level definition
|
||||
|
||||
def leave_FunctionDef(self, original_node: cst.FunctionDef) -> None: # noqa: ARG002
|
||||
def leave_FunctionDef(self, original_node: cst.FunctionDef) -> None:
|
||||
self.function_depth -= 1
|
||||
|
||||
if self.function_depth == 0 and self.class_depth == 0:
|
||||
|
|
@ -238,7 +238,7 @@ class DependencyCollector(cst.CSTVisitor):
|
|||
|
||||
self.class_depth += 1
|
||||
|
||||
def leave_ClassDef(self, original_node: cst.ClassDef) -> None: # noqa: ARG002
|
||||
def leave_ClassDef(self, original_node: cst.ClassDef) -> None:
|
||||
self.class_depth -= 1
|
||||
|
||||
if self.class_depth == 0:
|
||||
|
|
@ -261,7 +261,7 @@ class DependencyCollector(cst.CSTVisitor):
|
|||
# Use the first tracked name as the current top-level name (for dependency tracking)
|
||||
self.current_top_level_name = tracked_names[0]
|
||||
|
||||
def leave_Assign(self, original_node: cst.Assign) -> None: # noqa: ARG002
|
||||
def leave_Assign(self, original_node: cst.Assign) -> None:
|
||||
if self.processing_variable:
|
||||
self.processing_variable = False
|
||||
self.current_variable_names.clear()
|
||||
|
|
@ -371,7 +371,7 @@ class QualifiedFunctionUsageMarker:
|
|||
self.mark_as_used_recursively(dep)
|
||||
|
||||
|
||||
def remove_unused_definitions_recursively( # noqa: PLR0911
|
||||
def remove_unused_definitions_recursively(
|
||||
node: cst.CSTNode, definitions: dict[str, UsageInfo]
|
||||
) -> tuple[cst.CSTNode | None, bool]:
|
||||
"""Recursively filter the node to remove unused definitions.
|
||||
|
|
@ -554,7 +554,7 @@ def remove_unused_definitions_by_function_names(code: str, qualified_function_na
|
|||
# Apply the recursive removal transformation
|
||||
modified_module, _ = remove_unused_definitions_recursively(module, defs_with_usages)
|
||||
|
||||
return modified_module.code if modified_module else "" # noqa: TRY300
|
||||
return modified_module.code if modified_module else ""
|
||||
except Exception as e:
|
||||
# If any other error occurs during processing, return the original code
|
||||
logger.debug(f"Error processing code to remove unused definitions: {type(e).__name__}: {e}")
|
||||
|
|
|
|||
|
|
@ -62,7 +62,7 @@ class ReturnStatementVisitor(cst.CSTVisitor):
|
|||
super().__init__()
|
||||
self.has_return_statement: bool = False
|
||||
|
||||
def visit_Return(self, node: cst.Return) -> None: # noqa: ARG002
|
||||
def visit_Return(self, node: cst.Return) -> None:
|
||||
self.has_return_statement = True
|
||||
|
||||
|
||||
|
|
@ -352,7 +352,7 @@ def get_functions_to_optimize(
|
|||
return filtered_modified_functions, functions_count, trace_file_path
|
||||
|
||||
|
||||
def get_functions_within_git_diff(uncommitted_changes: bool) -> dict[str, list[FunctionToOptimize]]: # noqa: FBT001
|
||||
def get_functions_within_git_diff(uncommitted_changes: bool) -> dict[str, list[FunctionToOptimize]]:
|
||||
modified_lines: dict[str, list[int]] = get_git_diff(uncommitted_changes=uncommitted_changes)
|
||||
return get_functions_within_lines(modified_lines)
|
||||
|
||||
|
|
@ -602,7 +602,7 @@ def get_all_replay_test_functions(
|
|||
def is_git_repo(file_path: str) -> bool:
|
||||
try:
|
||||
git.Repo(file_path, search_parent_directories=True)
|
||||
return True # noqa: TRY300
|
||||
return True
|
||||
except git.InvalidGitRepositoryError:
|
||||
return False
|
||||
|
||||
|
|
|
|||
|
|
@ -409,7 +409,7 @@ def provide_api_key(params: ProvideApiKeyParams) -> dict[str, str]:
|
|||
_init()
|
||||
if not is_successful(result):
|
||||
return {"status": "error", "message": result.failure()}
|
||||
return {"status": "success", "message": "Api key saved successfully", "user_id": user_id} # noqa: TRY300
|
||||
return {"status": "success", "message": "Api key saved successfully", "user_id": user_id}
|
||||
except Exception:
|
||||
return {"status": "error", "message": "something went wrong while saving the api key"}
|
||||
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ def abort_if_cancelled(cancel_event: threading.Event) -> None:
|
|||
raise RuntimeError("cancelled")
|
||||
|
||||
|
||||
def sync_perform_optimization(server: CodeflashLanguageServer, cancel_event: threading.Event, params) -> dict[str, str]: # noqa
|
||||
def sync_perform_optimization(server: CodeflashLanguageServer, cancel_event: threading.Event, params) -> dict[str, str]:
|
||||
server.show_message_log(f"Starting optimization for function: {params.functionName}", "Info")
|
||||
should_run_experiment, code_context, original_helper_code = server.current_optimization_init_result
|
||||
function_optimizer = server.optimizer.current_function_optimizer
|
||||
|
|
|
|||
|
|
@ -46,7 +46,7 @@ def report_to_markdown_table(report: dict[TestType, dict[str, int]], title: str)
|
|||
return table
|
||||
|
||||
|
||||
def simplify_worktree_paths(msg: str, highlight: bool = True) -> str: # noqa: FBT001, FBT002
|
||||
def simplify_worktree_paths(msg: str, highlight: bool = True) -> str:
|
||||
path_in_msg = worktree_path_regex.search(msg)
|
||||
if path_in_msg:
|
||||
# Use Path.name to handle both Unix and Windows path separators
|
||||
|
|
|
|||
|
|
@ -85,11 +85,7 @@ supported_lsp_log_levels = ("info", "debug")
|
|||
|
||||
|
||||
def enhanced_log(
|
||||
msg: str | Any, # noqa: ANN401
|
||||
actual_log_fn: Callable[[str, Any, Any], None],
|
||||
level: str,
|
||||
*args: Any, # noqa: ANN401
|
||||
**kwargs: Any, # noqa: ANN401
|
||||
msg: str | Any, actual_log_fn: Callable[[str, Any, Any], None], level: str, *args: Any, **kwargs: Any
|
||||
) -> None:
|
||||
if not isinstance(msg, str):
|
||||
actual_log_fn(msg, *args, **kwargs)
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@ class LspMessage:
|
|||
takes_time: bool = False
|
||||
message_id: Optional[str] = None
|
||||
|
||||
def _loop_through(self, obj: Any) -> Any: # noqa: ANN401
|
||||
def _loop_through(self, obj: Any) -> Any:
|
||||
if isinstance(obj, list):
|
||||
return [self._loop_through(i) for i in obj]
|
||||
if isinstance(obj, dict):
|
||||
|
|
|
|||
|
|
@ -30,21 +30,21 @@ def main() -> None:
|
|||
if args.config_file and Path.exists(args.config_file):
|
||||
pyproject_config, _ = parse_config_file(args.config_file)
|
||||
disable_telemetry = pyproject_config.get("disable_telemetry", False)
|
||||
init_sentry(not disable_telemetry, exclude_errors=True)
|
||||
posthog_cf.initialize_posthog(not disable_telemetry)
|
||||
init_sentry(enabled=not disable_telemetry, exclude_errors=True)
|
||||
posthog_cf.initialize_posthog(enabled=not disable_telemetry)
|
||||
args.func()
|
||||
elif args.verify_setup:
|
||||
args = process_pyproject_config(args)
|
||||
init_sentry(not args.disable_telemetry, exclude_errors=True)
|
||||
posthog_cf.initialize_posthog(not args.disable_telemetry)
|
||||
init_sentry(enabled=not args.disable_telemetry, exclude_errors=True)
|
||||
posthog_cf.initialize_posthog(enabled=not args.disable_telemetry)
|
||||
ask_run_end_to_end_test(args)
|
||||
else:
|
||||
args = process_pyproject_config(args)
|
||||
if not env_utils.check_formatter_installed(args.formatter_cmds):
|
||||
return
|
||||
args.previous_checkpoint_functions = ask_should_use_checkpoint_get_functions(args)
|
||||
init_sentry(not args.disable_telemetry, exclude_errors=True)
|
||||
posthog_cf.initialize_posthog(not args.disable_telemetry)
|
||||
init_sentry(enabled=not args.disable_telemetry, exclude_errors=True)
|
||||
posthog_cf.initialize_posthog(enabled=not args.disable_telemetry)
|
||||
|
||||
from codeflash.optimization import optimizer
|
||||
|
||||
|
|
|
|||
|
|
@ -41,6 +41,6 @@ def belongs_to_function_qualified(name: Name, qualified_function_name: str) -> b
|
|||
return False
|
||||
if (name := name.parent()) and name.type == "function":
|
||||
return get_qualified_name(name.module_name, name.full_name) == qualified_function_name
|
||||
return False # noqa: TRY300
|
||||
return False
|
||||
except ValueError:
|
||||
return False
|
||||
|
|
|
|||
|
|
@ -2704,8 +2704,6 @@ class FunctionOptimizer:
|
|||
pytest_cmd=self.test_cfg.pytest_cmd,
|
||||
pytest_timeout=INDIVIDUAL_TESTCASE_TIMEOUT,
|
||||
pytest_target_runtime_seconds=testing_time,
|
||||
pytest_min_loops=1,
|
||||
pytest_max_loops=1,
|
||||
test_framework=self.test_cfg.test_framework,
|
||||
js_project_root=self.test_cfg.js_project_root,
|
||||
line_profiler_output_file=line_profiler_output_file,
|
||||
|
|
|
|||
|
|
@ -94,7 +94,7 @@ class PicklePatcher:
|
|||
obj: object,
|
||||
path: list[str] | None = None, # noqa: ARG004
|
||||
protocol: int | None = None,
|
||||
**kwargs: Any, # noqa: ANN401
|
||||
**kwargs: Any,
|
||||
) -> tuple[bool, bytes | str]:
|
||||
"""Try to pickle an object using pickle first, then dill. If both fail, create a placeholder.
|
||||
|
||||
|
|
@ -123,7 +123,7 @@ class PicklePatcher:
|
|||
return False, str(e)
|
||||
|
||||
@staticmethod
|
||||
def _recursive_pickle( # noqa: PLR0911
|
||||
def _recursive_pickle(
|
||||
obj: object,
|
||||
max_depth: int,
|
||||
path: list[str] | None = None,
|
||||
|
|
@ -192,7 +192,7 @@ class PicklePatcher:
|
|||
error_msg: str, # noqa: ARG004
|
||||
path: list[str],
|
||||
protocol: int | None = None,
|
||||
**kwargs: Any, # noqa: ANN401
|
||||
**kwargs: Any,
|
||||
) -> bytes:
|
||||
"""Handle pickling for dictionary objects.
|
||||
|
||||
|
|
@ -258,7 +258,7 @@ class PicklePatcher:
|
|||
error_msg: str, # noqa: ARG004
|
||||
path: list[str],
|
||||
protocol: int | None = None,
|
||||
**kwargs: Any, # noqa: ANN401
|
||||
**kwargs: Any,
|
||||
) -> bytes:
|
||||
"""Handle pickling for sequence types (list, tuple, set).
|
||||
|
||||
|
|
@ -311,12 +311,7 @@ class PicklePatcher:
|
|||
|
||||
@staticmethod
|
||||
def _handle_object(
|
||||
obj: object,
|
||||
max_depth: int,
|
||||
error_msg: str,
|
||||
path: list[str],
|
||||
protocol: int | None = None,
|
||||
**kwargs: Any, # noqa: ANN401
|
||||
obj: object, max_depth: int, error_msg: str, path: list[str], protocol: int | None = None, **kwargs: Any
|
||||
) -> bytes:
|
||||
"""Handle pickling for custom objects with __dict__.
|
||||
|
||||
|
|
@ -366,7 +361,7 @@ class PicklePatcher:
|
|||
if success:
|
||||
return result
|
||||
# Fall through to placeholder creation
|
||||
except Exception: # noqa: S110
|
||||
except Exception:
|
||||
pass # Fall through to placeholder creation
|
||||
|
||||
# If we get here, just use a placeholder
|
||||
|
|
|
|||
|
|
@ -31,7 +31,7 @@ class PicklePlaceholder:
|
|||
self.__dict__["error_msg"] = error_msg
|
||||
self.__dict__["path"] = path if path is not None else []
|
||||
|
||||
def __getattr__(self, name) -> Any: # noqa: ANN001, ANN401
|
||||
def __getattr__(self, name) -> Any:
|
||||
"""Raise a custom error when any attribute is accessed."""
|
||||
path_str = ".".join(self.__dict__["path"]) if self.__dict__["path"] else "root object"
|
||||
msg = (
|
||||
|
|
@ -40,11 +40,11 @@ class PicklePlaceholder:
|
|||
)
|
||||
raise PicklePlaceholderAccessError(msg)
|
||||
|
||||
def __setattr__(self, name: str, value: Any) -> None: # noqa: ANN401
|
||||
def __setattr__(self, name: str, value: Any) -> None:
|
||||
"""Prevent setting attributes."""
|
||||
self.__getattr__(name) # This will raise our custom error
|
||||
|
||||
def __call__(self, *args: Any, **kwargs: Any) -> Any: # noqa: ANN401, ARG002
|
||||
def __call__(self, *args: Any, **kwargs: Any) -> Any:
|
||||
"""Raise a custom error when the object is called."""
|
||||
path_str = ".".join(self.__dict__["path"]) if self.__dict__["path"] else "root object"
|
||||
msg = (
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ from codeflash.version import __version__
|
|||
_posthog = None
|
||||
|
||||
|
||||
def initialize_posthog(enabled: bool = True) -> None: # noqa: FBT001, FBT002
|
||||
def initialize_posthog(*, enabled: bool = True) -> None:
|
||||
"""Enable or disable PostHog.
|
||||
|
||||
:param enabled: Whether to enable PostHog.
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ import sentry_sdk
|
|||
from sentry_sdk.integrations.logging import LoggingIntegration
|
||||
|
||||
|
||||
def init_sentry(enabled: bool = False, exclude_errors: bool = False) -> None: # noqa: FBT001, FBT002
|
||||
def init_sentry(*, enabled: bool = False, exclude_errors: bool = False) -> None:
|
||||
if enabled:
|
||||
sentry_logging = LoggingIntegration(
|
||||
level=logging.INFO, # Capture info and above as breadcrumbs
|
||||
|
|
|
|||
|
|
@ -232,8 +232,8 @@ def main(args: Namespace | None = None) -> ArgumentParser:
|
|||
|
||||
args = process_pyproject_config(args)
|
||||
args.previous_checkpoint_functions = None
|
||||
init_sentry(not args.disable_telemetry, exclude_errors=True)
|
||||
posthog_cf.initialize_posthog(not args.disable_telemetry)
|
||||
init_sentry(enabled=not args.disable_telemetry, exclude_errors=True)
|
||||
posthog_cf.initialize_posthog(enabled=not args.disable_telemetry)
|
||||
|
||||
from codeflash.optimization import optimizer
|
||||
|
||||
|
|
|
|||
|
|
@ -76,7 +76,8 @@ class Tracer:
|
|||
config: dict,
|
||||
result_pickle_file_path: Path,
|
||||
functions: list[str] | None = None,
|
||||
disable: bool = False, # noqa: FBT001, FBT002
|
||||
*,
|
||||
disable: bool = False,
|
||||
project_root: Path | None = None,
|
||||
max_function_count: int = 256,
|
||||
timeout: int | None = None, # seconds
|
||||
|
|
@ -309,7 +310,7 @@ class Tracer:
|
|||
with self.result_pickle_file_path.open("wb") as file:
|
||||
pickle.dump(pickle_data, file)
|
||||
|
||||
def tracer_logic(self, frame: FrameType, event: str) -> None: # noqa: PLR0911
|
||||
def tracer_logic(self, frame: FrameType, event: str) -> None:
|
||||
if event != "call":
|
||||
return
|
||||
if None is not self.timeout and (time.time() - self.start_time) > self.timeout:
|
||||
|
|
@ -494,7 +495,7 @@ class Tracer:
|
|||
class_name = arguments["self"].__class__.__name__
|
||||
elif "cls" in arguments and hasattr(arguments["cls"], "__name__"):
|
||||
class_name = arguments["cls"].__name__
|
||||
except Exception: # noqa: S110
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
fn = (fcode.co_filename, fcode.co_firstlineno, fcode.co_name, class_name)
|
||||
|
|
@ -505,7 +506,7 @@ class Tracer:
|
|||
timings[fn] = cc, ns + 1, tt, ct, callers
|
||||
else:
|
||||
timings[fn] = 0, 0, 0, 0, {}
|
||||
return 1 # noqa: TRY300
|
||||
return 1
|
||||
except Exception:
|
||||
# Handle any errors gracefully
|
||||
return 0
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ def path_belongs_to_site_packages(file_path: Path) -> bool:
|
|||
def is_git_repo(file_path: str) -> bool:
|
||||
try:
|
||||
git.Repo(file_path, search_parent_directories=True)
|
||||
return True # noqa: TRY300
|
||||
return True
|
||||
except git.InvalidGitRepositoryError:
|
||||
return False
|
||||
|
||||
|
|
|
|||
|
|
@ -99,7 +99,7 @@ def get_test_info_from_stack(tests_root: str) -> tuple[str, str | None, str, str
|
|||
return test_module_name, test_class_name, test_name, line_id
|
||||
|
||||
|
||||
def codeflash_capture(function_name: str, tmp_dir_path: str, tests_root: str, is_fto: bool = False) -> Callable: # noqa: FBT001, FBT002
|
||||
def codeflash_capture(function_name: str, tmp_dir_path: str, tests_root: str, is_fto: bool = False) -> Callable:
|
||||
"""Define a decorator to instrument the init function, collect test info, and capture the instance state."""
|
||||
|
||||
def decorator(wrapped: Callable) -> Callable:
|
||||
|
|
|
|||
|
|
@ -93,7 +93,7 @@ def _get_wrapped_exception(exc: BaseException) -> Optional[BaseException]: # no
|
|||
return _extract_exception_from_message(str(exc))
|
||||
|
||||
|
||||
def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001, ANN401, FBT002, PLR0911
|
||||
def comparator(orig: Any, new: Any, superset_obj=False) -> bool:
|
||||
"""Compare two objects for equality recursively. If superset_obj is True, the new object is allowed to have more keys than the original object. However, the existing keys/values must be equivalent."""
|
||||
try:
|
||||
# Handle exceptions specially - before type check to allow wrapper comparison
|
||||
|
|
@ -118,7 +118,7 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
|
|||
|
||||
# Check if new wraps something that matches orig
|
||||
wrapped_new = _get_wrapped_exception(new)
|
||||
if wrapped_new is not None and comparator(orig, wrapped_new, superset_obj): # noqa: SIM103
|
||||
if wrapped_new is not None and comparator(orig, wrapped_new, superset_obj):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
|
@ -242,7 +242,7 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
|
|||
continue
|
||||
if key not in new_keys or not comparator(orig_keys[key], new_keys[key], superset_obj):
|
||||
return False
|
||||
return True # noqa: TRY300
|
||||
return True
|
||||
|
||||
except sqlalchemy.exc.NoInspectionAvailable:
|
||||
pass
|
||||
|
|
@ -366,12 +366,12 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
|
|||
try:
|
||||
if HAS_NUMPY and np.isnan(orig):
|
||||
return np.isnan(new)
|
||||
except Exception: # noqa: S110
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
if HAS_NUMPY and np.isinf(orig):
|
||||
return np.isinf(new)
|
||||
except Exception: # noqa: S110
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if HAS_TORCH:
|
||||
|
|
@ -480,7 +480,7 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
|
|||
try:
|
||||
if hasattr(orig, "__eq__") and str(type(orig.__eq__)) == "<class 'method'>":
|
||||
return orig == new
|
||||
except Exception: # noqa: S110
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# For class objects
|
||||
|
|
@ -512,7 +512,7 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
|
|||
# TODO : Add other types here
|
||||
logger.warning(f"Unknown comparator input type: {type(orig)}")
|
||||
sentry_sdk.capture_exception(RuntimeError(f"Unknown comparator input type: {type(orig)}"))
|
||||
return False # noqa: TRY300
|
||||
return False
|
||||
except RecursionError as e:
|
||||
logger.error(f"RecursionError while comparing objects: {e}")
|
||||
sentry_sdk.capture_exception(e)
|
||||
|
|
|
|||
|
|
@ -54,16 +54,11 @@ def instrument_codeflash_capture(
|
|||
|
||||
|
||||
def add_codeflash_capture_to_init(
|
||||
target_classes: set[str],
|
||||
fto_name: str,
|
||||
tmp_dir_path: str,
|
||||
code: str,
|
||||
tests_root: Path,
|
||||
is_fto: bool = False, # noqa: FBT001, FBT002
|
||||
target_classes: set[str], fto_name: str, tmp_dir_path: str, code: str, tests_root: Path, *, is_fto: bool = False
|
||||
) -> str:
|
||||
"""Add codeflash_capture decorator to __init__ function in the specified class."""
|
||||
tree = ast.parse(code)
|
||||
transformer = InitDecorator(target_classes, fto_name, tmp_dir_path, tests_root, is_fto)
|
||||
transformer = InitDecorator(target_classes, fto_name, tmp_dir_path, tests_root, is_fto=is_fto)
|
||||
modified_tree = transformer.visit(tree)
|
||||
if transformer.inserted_decorator:
|
||||
ast.fix_missing_locations(modified_tree)
|
||||
|
|
@ -76,12 +71,7 @@ class InitDecorator(ast.NodeTransformer):
|
|||
"""AST transformer that adds codeflash_capture decorator to specific class's __init__."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
target_classes: set[str],
|
||||
fto_name: str,
|
||||
tmp_dir_path: str,
|
||||
tests_root: Path,
|
||||
is_fto=False, # noqa: ANN001, FBT002
|
||||
self, target_classes: set[str], fto_name: str, tmp_dir_path: str, tests_root: Path, *, is_fto: bool = False
|
||||
) -> None:
|
||||
self.target_classes = target_classes
|
||||
self.fto_name = fto_name
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ import sys
|
|||
import time as _time_module
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Callable, Optional
|
||||
from typing import TYPE_CHECKING, Callable, Optional
|
||||
from unittest import TestCase
|
||||
|
||||
# PyTest Imports
|
||||
|
|
@ -437,7 +437,7 @@ class PytestLoops:
|
|||
"importlib",
|
||||
}
|
||||
|
||||
def _clear_cache_for_object(obj: Any) -> None: # noqa: ANN401
|
||||
def _clear_cache_for_object(obj: obj) -> None:
|
||||
if obj in processed_functions:
|
||||
return
|
||||
processed_functions.add(obj)
|
||||
|
|
@ -469,9 +469,9 @@ class PytestLoops:
|
|||
for _, obj in inspect.getmembers(module):
|
||||
if callable(obj):
|
||||
_clear_cache_for_object(obj)
|
||||
except Exception: # noqa: S110
|
||||
except Exception:
|
||||
pass
|
||||
except Exception: # noqa: S110
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _set_nodeid(self, nodeid: str, count: int) -> str:
|
||||
|
|
@ -581,7 +581,7 @@ class PytestLoops:
|
|||
os.environ["CODEFLASH_TEST_FUNCTION"] = test_function_name
|
||||
|
||||
@pytest.hookimpl(trylast=True)
|
||||
def pytest_runtest_teardown(self, item: pytest.Item) -> None: # noqa: ARG002
|
||||
def pytest_runtest_teardown(self, item: pytest.Item) -> None:
|
||||
"""Clean up test context environment variables after each test."""
|
||||
for var in ["CODEFLASH_TEST_MODULE", "CODEFLASH_TEST_CLASS", "CODEFLASH_TEST_FUNCTION"]:
|
||||
os.environ.pop(var, None)
|
||||
|
|
|
|||
|
|
@ -258,8 +258,8 @@ def run_line_profile_tests(
|
|||
*,
|
||||
pytest_target_runtime_seconds: float = TOTAL_LOOPING_TIME_EFFECTIVE,
|
||||
pytest_timeout: int | None = None,
|
||||
pytest_min_loops: int = 5, # noqa: ARG001
|
||||
pytest_max_loops: int = 100_000, # noqa: ARG001
|
||||
pytest_min_loops: int = 5,
|
||||
pytest_max_loops: int = 100_000,
|
||||
js_project_root: Path | None = None,
|
||||
line_profiler_output_file: Path | None = None,
|
||||
) -> tuple[Path, subprocess.CompletedProcess]:
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ def generate_tests(
|
|||
test_index: int,
|
||||
test_path: Path,
|
||||
test_perf_path: Path,
|
||||
is_numerical_code: bool | None = None, # noqa: FBT001
|
||||
is_numerical_code: bool | None = None,
|
||||
) -> tuple[str, str, str, Path, Path] | None:
|
||||
# TODO: Sometimes this recreates the original Class definition. This overrides and messes up the original
|
||||
# class import. Remove the recreation of the class definition
|
||||
|
|
|
|||
|
|
@ -263,12 +263,13 @@ ignore = [
|
|||
"TD007",
|
||||
"D417",
|
||||
"D401",
|
||||
"S110", # try-except-pass - we do this a lot
|
||||
"ARG002", # Unused method argument
|
||||
# Added for multi-language branch
|
||||
"FBT001", # Boolean positional argument
|
||||
"FBT002", # Boolean default positional argument
|
||||
"ANN401", # typing.Any disallowed
|
||||
"ARG001", # Unused function argument (common in abstract/interface methods)
|
||||
"ARG002", # Unused method argument
|
||||
"TRY300", # Consider moving to else block
|
||||
"TRY401", # Redundant exception in logging.exception
|
||||
"PLR0911", # Too many return statements
|
||||
|
|
@ -277,7 +278,6 @@ ignore = [
|
|||
"SIM102", # Nested if statements
|
||||
"SIM103", # Return negated condition
|
||||
"ANN001", # Missing type annotation
|
||||
"S110", # try-except-pass
|
||||
"PLC0206", # Dictionary items
|
||||
"S314", # XML parsing (acceptable for dev tool)
|
||||
"S608", # SQL injection (internal use only)
|
||||
|
|
|
|||
Loading…
Reference in a new issue