Merge branch 'main' into multi-language

This commit is contained in:
Kevin Turcios 2026-01-29 13:05:22 -05:00
commit 84d0b1cf09
45 changed files with 156 additions and 194 deletions

View file

@ -298,7 +298,7 @@ class AiServiceClient:
line_profiler_results: str,
n_candidates: int,
experiment_metadata: ExperimentMetadata | None = None,
is_numerical_code: bool | None = None, # noqa: FBT001
is_numerical_code: bool | None = None,
language: str = "python",
language_version: str | None = None,
) -> list[OptimizedCandidate]:
@ -814,7 +814,7 @@ class AiServiceClient:
error = response.json()["error"]
logger.error(f"Error generating tests: {response.status_code} - {error}")
ph("cli-testgen-error-response", {"response_status_code": response.status_code, "error": error})
return None # noqa: TRY300
return None
except Exception:
logger.error(f"Error generating tests: {response.status_code} - {response.text}")
ph("cli-testgen-error-response", {"response_status_code": response.status_code, "error": response.text})
@ -830,7 +830,6 @@ class AiServiceClient:
function_trace_id: str,
coverage_message: str,
replay_tests: str,
concolic_tests: str, # noqa: ARG002
calling_fn_details: str,
language: str = "python",
) -> OptimizationReviewResult:

View file

@ -81,7 +81,7 @@ def make_cfapi_request(
else:
response = requests.get(url, headers=cfapi_headers, params=params, timeout=60)
response.raise_for_status()
return response # noqa: TRY300
return response
except requests.exceptions.HTTPError:
# response may be either a string or JSON, so we handle both cases
error_message = ""
@ -102,7 +102,7 @@ def make_cfapi_request(
@lru_cache(maxsize=1)
def get_user_id(api_key: Optional[str] = None) -> Optional[str]: # noqa: PLR0911
def get_user_id(api_key: Optional[str] = None) -> Optional[str]:
"""Retrieve the user's userid by making a request to the /cfapi/cli-get-user endpoint.
:param api_key: The API key to use. If None, uses get_codeflash_api_key().
@ -396,7 +396,7 @@ def get_blocklisted_functions() -> dict[str, set[str]] | dict[str, Any]:
def is_function_being_optimized_again(
owner: str, repo: str, pr_number: int, code_contexts: list[dict[str, str]]
) -> Any: # noqa: ANN401
) -> Any:
"""Check if the function being optimized is being optimized again."""
response = make_cfapi_request(
"/is-already-optimized",

View file

@ -108,7 +108,7 @@ class CodeflashTrace:
func_id = (func.__module__, func.__name__)
@functools.wraps(func)
def wrapper(*args, **kwargs) -> Any: # noqa: ANN002, ANN003, ANN401
def wrapper(*args, **kwargs) -> Any: # noqa: ANN002, ANN003
# Initialize thread-local active functions set if it doesn't exist
if not hasattr(self._thread_local, "active_functions"):
self._thread_local.active_functions = set()

View file

@ -53,7 +53,7 @@ class AddDecoratorTransformer(cst.CSTTransformer):
return updated_node
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: # noqa: ARG002
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
# Create import statement for codeflash_trace
if not self.added_codeflash_trace:
return updated_node

View file

@ -200,7 +200,7 @@ class CodeFlashBenchmarkPlugin:
# Pytest hooks
@pytest.hookimpl
def pytest_sessionfinish(self, session, exitstatus) -> None: # noqa: ANN001, ARG002
def pytest_sessionfinish(self, session, exitstatus) -> None: # noqa: ANN001
"""Execute after whole test run is completed."""
# Write any remaining benchmark timings to the database
codeflash_trace.close()
@ -218,7 +218,7 @@ class CodeFlashBenchmarkPlugin:
skip_no_benchmark = pytest.mark.skip(reason="Test requires benchmark fixture")
for item in items:
# Check for direct benchmark fixture usage
has_fixture = hasattr(item, "fixturenames") and "benchmark" in item.fixturenames
has_fixture = hasattr(item, "fixturenames") and "benchmark" in item.fixturenames # ty:ignore[unsupported-operator]
# Check for @pytest.mark.benchmark marker
has_marker = False
@ -249,7 +249,7 @@ class CodeFlashBenchmarkPlugin:
self._run_benchmark(func)
return wrapped_func
def _run_benchmark(self, func, *args, **kwargs): # noqa: ANN001, ANN002, ANN003, ANN202
def _run_benchmark(self, func, *args, **kwargs): # noqa: ANN002, ANN003, ANN202
"""Actual benchmark implementation."""
node_path = getattr(self.request.node, "path", None) or getattr(self.request.node, "fspath", None)
if node_path is None:

View file

@ -43,7 +43,7 @@ def inquirer_wrapper(func: Callable[..., str | bool], *args: str | bool, **kwarg
return func(*new_args, **new_kwargs)
def split_string_to_cli_width(string: str, is_confirm: bool = False) -> list[str]: # noqa: FBT001, FBT002
def split_string_to_cli_width(string: str, is_confirm: bool = False) -> list[str]:
cli_width, _ = shutil.get_terminal_size()
# split string to lines that accommodate "[?] " prefix
cli_width -= len("[?] ")

View file

@ -687,7 +687,7 @@ def create_empty_pyproject_toml(pyproject_toml_path: Path) -> None:
apologize_and_exit()
def install_github_actions(override_formatter_check: bool = False) -> None: # noqa: FBT001, FBT002
def install_github_actions(override_formatter_check: bool = False) -> None:
try:
config, _config_file_path = parse_config_file(override_formatter_check=override_formatter_check)
@ -1121,11 +1121,12 @@ def install_github_actions(override_formatter_check: bool = False) -> None: # n
apologize_and_exit()
def determine_dependency_manager(pyproject_data: dict[str, Any]) -> DependencyManager: # noqa: PLR0911
def determine_dependency_manager(pyproject_data: dict[str, Any]) -> DependencyManager:
"""Determine which dependency manager is being used based on pyproject.toml contents."""
if (Path.cwd() / "poetry.lock").exists():
cwd = Path.cwd()
if (cwd / "poetry.lock").exists():
return DependencyManager.POETRY
if (Path.cwd() / "uv.lock").exists():
if (cwd / "uv.lock").exists():
return DependencyManager.UV
if "tool" not in pyproject_data:
return DependencyManager.PIP
@ -1325,10 +1326,7 @@ def collect_repo_files_for_workflow(git_root: Path) -> dict[str, Any]:
def generate_dynamic_workflow_content(
optimize_yml_content: str,
config: tuple[dict[str, Any], Path],
git_root: Path,
benchmark_mode: bool = False, # noqa: FBT001, FBT002
optimize_yml_content: str, config: tuple[dict[str, Any], Path], git_root: Path, benchmark_mode: bool = False
) -> str:
"""Generate workflow content with dynamic steps from AI service, falling back to static template.
@ -1460,10 +1458,7 @@ def generate_dynamic_workflow_content(
def customize_codeflash_yaml_content(
optimize_yml_content: str,
config: tuple[dict[str, Any], Path],
git_root: Path,
benchmark_mode: bool = False, # noqa: FBT001, FBT002
optimize_yml_content: str, config: tuple[dict[str, Any], Path], git_root: Path, benchmark_mode: bool = False
) -> str:
module_path = str(Path(config["module_root"]).relative_to(git_root) / "**")
optimize_yml_content = optimize_yml_content.replace("{{ codeflash_module_path }}", module_path)

View file

@ -45,14 +45,14 @@ class GlobalFunctionCollector(cst.CSTVisitor):
self.scope_depth += 1
return True
def leave_FunctionDef(self, original_node: cst.FunctionDef) -> None: # noqa: ARG002
def leave_FunctionDef(self, original_node: cst.FunctionDef) -> None:
self.scope_depth -= 1
def visit_ClassDef(self, node: cst.ClassDef) -> Optional[bool]: # noqa: ARG002
def visit_ClassDef(self, node: cst.ClassDef) -> Optional[bool]:
self.scope_depth += 1
return True
def leave_ClassDef(self, original_node: cst.ClassDef) -> None: # noqa: ARG002
def leave_ClassDef(self, original_node: cst.ClassDef) -> None:
self.scope_depth -= 1
@ -66,7 +66,7 @@ class GlobalFunctionTransformer(cst.CSTTransformer):
self.processed_functions: set[str] = set()
self.scope_depth = 0
def visit_FunctionDef(self, node: cst.FunctionDef) -> None: # noqa: ARG002
def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
self.scope_depth += 1
def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef:
@ -81,14 +81,14 @@ class GlobalFunctionTransformer(cst.CSTTransformer):
return self.new_functions[name]
return updated_node
def visit_ClassDef(self, node: cst.ClassDef) -> None: # noqa: ARG002
def visit_ClassDef(self, node: cst.ClassDef) -> None:
self.scope_depth += 1
def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef: # noqa: ARG002
def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef:
self.scope_depth -= 1
return updated_node
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: # noqa: ARG002
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
# Add any new functions that weren't in the original file
new_statements = list(updated_node.body)
@ -142,28 +142,28 @@ class GlobalAssignmentCollector(cst.CSTVisitor):
self.scope_depth = 0
self.if_else_depth = 0
def visit_FunctionDef(self, node: cst.FunctionDef) -> Optional[bool]: # noqa: ARG002
def visit_FunctionDef(self, node: cst.FunctionDef) -> Optional[bool]:
self.scope_depth += 1
return True
def leave_FunctionDef(self, original_node: cst.FunctionDef) -> None: # noqa: ARG002
def leave_FunctionDef(self, original_node: cst.FunctionDef) -> None:
self.scope_depth -= 1
def visit_ClassDef(self, node: cst.ClassDef) -> Optional[bool]: # noqa: ARG002
def visit_ClassDef(self, node: cst.ClassDef) -> Optional[bool]:
self.scope_depth += 1
return True
def leave_ClassDef(self, original_node: cst.ClassDef) -> None: # noqa: ARG002
def leave_ClassDef(self, original_node: cst.ClassDef) -> None:
self.scope_depth -= 1
def visit_If(self, node: cst.If) -> Optional[bool]: # noqa: ARG002
def visit_If(self, node: cst.If) -> Optional[bool]:
self.if_else_depth += 1
return True
def leave_If(self, original_node: cst.If) -> None: # noqa: ARG002
def leave_If(self, original_node: cst.If) -> None:
self.if_else_depth -= 1
def visit_Else(self, node: cst.Else) -> Optional[bool]: # noqa: ARG002
def visit_Else(self, node: cst.Else) -> Optional[bool]:
# Else blocks are already counted as part of the if statement
return True
@ -232,24 +232,24 @@ class GlobalAssignmentTransformer(cst.CSTTransformer):
self.scope_depth = 0
self.if_else_depth = 0
def visit_FunctionDef(self, node: cst.FunctionDef) -> None: # noqa: ARG002
def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
self.scope_depth += 1
def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef: # noqa: ARG002
def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef:
self.scope_depth -= 1
return updated_node
def visit_ClassDef(self, node: cst.ClassDef) -> None: # noqa: ARG002
def visit_ClassDef(self, node: cst.ClassDef) -> None:
self.scope_depth += 1
def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef: # noqa: ARG002
def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef:
self.scope_depth -= 1
return updated_node
def visit_If(self, node: cst.If) -> None: # noqa: ARG002
def visit_If(self, node: cst.If) -> None:
self.if_else_depth += 1
def leave_If(self, original_node: cst.If, updated_node: cst.If) -> cst.If: # noqa: ARG002
def leave_If(self, original_node: cst.If, updated_node: cst.If) -> cst.If:
self.if_else_depth -= 1
return updated_node
@ -284,7 +284,7 @@ class GlobalAssignmentTransformer(cst.CSTTransformer):
return updated_node
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: # noqa: ARG002
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
# Add any new assignments that weren't in the original file
new_statements = list(updated_node.body)
@ -371,7 +371,7 @@ class GlobalStatementTransformer(cst.CSTTransformer):
super().__init__()
self.global_statements = global_statements
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: # noqa: ARG002
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
if not self.global_statements:
return updated_node
@ -397,20 +397,20 @@ class GlobalStatementCollector(cst.CSTVisitor):
self.global_statements = []
self.in_function_or_class = False
def visit_ClassDef(self, node: cst.ClassDef) -> bool: # noqa: ARG002
def visit_ClassDef(self, node: cst.ClassDef) -> bool:
# Don't visit inside classes
self.in_function_or_class = True
return False
def leave_ClassDef(self, original_node: cst.ClassDef) -> None: # noqa: ARG002
def leave_ClassDef(self, original_node: cst.ClassDef) -> None:
self.in_function_or_class = False
def visit_FunctionDef(self, node: cst.FunctionDef) -> bool: # noqa: ARG002
def visit_FunctionDef(self, node: cst.FunctionDef) -> bool:
# Don't visit inside functions
self.in_function_or_class = True
return False
def leave_FunctionDef(self, original_node: cst.FunctionDef) -> None: # noqa: ARG002
def leave_FunctionDef(self, original_node: cst.FunctionDef) -> None:
self.in_function_or_class = False
def visit_SimpleStatementLine(self, node: cst.SimpleStatementLine) -> None:
@ -491,16 +491,16 @@ class DottedImportCollector(cst.CSTVisitor):
self.depth = 0
self._collect_imports_from_block(node)
def visit_FunctionDef(self, node: cst.FunctionDef) -> None: # noqa: ARG002
def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
self.depth += 1
def leave_FunctionDef(self, node: cst.FunctionDef) -> None: # noqa: ARG002
def leave_FunctionDef(self, node: cst.FunctionDef) -> None:
self.depth -= 1
def visit_ClassDef(self, node: cst.ClassDef) -> None: # noqa: ARG002
def visit_ClassDef(self, node: cst.ClassDef) -> None:
self.depth += 1
def leave_ClassDef(self, node: cst.ClassDef) -> None: # noqa: ARG002
def leave_ClassDef(self, node: cst.ClassDef) -> None:
self.depth -= 1
def visit_If(self, node: cst.If) -> None:
@ -530,9 +530,7 @@ def find_last_import_line(target_code: str) -> int:
class FutureAliasedImportTransformer(cst.CSTTransformer):
def leave_ImportFrom(
self,
original_node: cst.ImportFrom, # noqa: ARG002
updated_node: cst.ImportFrom,
self, original_node: cst.ImportFrom, updated_node: cst.ImportFrom
) -> cst.BaseSmallStatement | cst.FlattenSentinel[cst.BaseSmallStatement] | cst.RemovalSentinel:
import libcst.matchers as m
@ -677,7 +675,7 @@ def resolve_star_import(module_name: str, project_root: Path) -> set[str]:
if not name.startswith("_"):
public_names.add(name)
return public_names # noqa: TRY300
return public_names
except Exception as e:
logger.warning(f"Error resolving star import for {module_name}: {e}")
@ -1166,7 +1164,7 @@ class FunctionCallFinder(ast.NodeVisitor):
return False
def _get_call_name(self, func_node) -> Optional[str]: # noqa: ANN001
def _get_call_name(self, func_node) -> Optional[str]:
"""Extract the name being called from a function node."""
# Fast path short-circuit for ast.Name nodes
if isinstance(func_node, ast.Name):
@ -1559,10 +1557,7 @@ def is_numerical_code(code_string: str, function_name: str | None = None) -> boo
# If numba is not installed and all modules used require numba for optimization,
# return False since we can't optimize this code
if not has_numba and modules_used.issubset(NUMBA_REQUIRED_MODULES): # noqa : SIM103
return False
return True
return not (not has_numba and modules_used.issubset(NUMBA_REQUIRED_MODULES))
def get_opt_review_metrics(

View file

@ -166,7 +166,7 @@ def filter_args(addopts_args: list[str]) -> list[str]:
return filtered_args
def modify_addopts(config_file: Path) -> tuple[str, bool]: # noqa : PLR0911
def modify_addopts(config_file: Path) -> tuple[str, bool]:
file_type = config_file.suffix.lower()
filename = config_file.name
config = None

View file

@ -46,7 +46,7 @@ def extract_test_context_from_env() -> tuple[str, str | None, str]:
def codeflash_behavior_async(func: F) -> F:
@wraps(func)
async def async_wrapper(*args: Any, **kwargs: Any) -> Any: # noqa: ANN401
async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
loop = asyncio.get_running_loop()
function_name = func.__name__
line_id = os.environ["CODEFLASH_CURRENT_LINE_ID"]
@ -122,7 +122,7 @@ def codeflash_behavior_async(func: F) -> F:
def codeflash_performance_async(func: F) -> F:
@wraps(func)
async def async_wrapper(*args: Any, **kwargs: Any) -> Any: # noqa: ANN401
async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
loop = asyncio.get_running_loop()
function_name = func.__name__
line_id = os.environ["CODEFLASH_CURRENT_LINE_ID"]
@ -172,7 +172,7 @@ def codeflash_concurrency_async(func: F) -> F:
"""Measures concurrent vs sequential execution performance for async functions."""
@wraps(func)
async def async_wrapper(*args: Any, **kwargs: Any) -> Any: # noqa: ANN401
async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
function_name = func.__name__
concurrency_factor = int(os.environ.get("CODEFLASH_CONCURRENCY_FACTOR", "10"))

View file

@ -91,7 +91,7 @@ EFFORT_VALUES: dict[str, dict[EffortLevel, Any]] = {
}
def get_effort_value(key: EffortKeys, effort: Union[EffortLevel, str]) -> Any: # noqa: ANN401
def get_effort_value(key: EffortKeys, effort: Union[EffortLevel, str]) -> Any:
key_str = key.value
if isinstance(effort, str):

View file

@ -86,8 +86,7 @@ def find_conftest_files(test_paths: list[Path]) -> list[Path]:
# TODO for claude: There should be different functions to parse it per language, which should be chosen during runtime
def parse_config_file(
config_file_path: Path | None = None,
override_formatter_check: bool = False, # noqa: FBT001, FBT002
config_file_path: Path | None = None, override_formatter_check: bool = False
) -> tuple[dict[str, Any], Path]:
# First try package.json for JS/TS projects
package_json_path = find_package_json(config_file_path)

View file

@ -15,8 +15,8 @@ from codeflash.languages import current_language, is_python
def normalize_code(
code: str,
remove_docstrings: bool = True, # noqa: FBT001, FBT002
return_ast_dump: bool = False, # noqa: FBT001, FBT002
remove_docstrings: bool = True,
return_ast_dump: bool = False,
language: str | None = None,
) -> str:
"""Normalize code by parsing, cleaning, and normalizing variable names.

View file

@ -16,7 +16,7 @@ from codeflash.code_utils.shell_utils import read_api_key_from_shell_config, sav
from codeflash.lsp.helpers import is_LSP_enabled
def check_formatter_installed(formatter_cmds: list[str], exit_on_failure: bool = True) -> bool: # noqa
def check_formatter_installed(formatter_cmds: list[str], exit_on_failure: bool = True) -> bool:
if not formatter_cmds or formatter_cmds[0] == "disabled":
return True
first_cmd = formatter_cmds[0]

View file

@ -40,11 +40,7 @@ def generate_unified_diff(original: str, modified: str, from_file: str, to_file:
def apply_formatter_cmds(
cmds: list[str],
path: Path,
test_dir_str: Optional[str],
print_status: bool, # noqa
exit_on_failure: bool = True, # noqa
cmds: list[str], path: Path, test_dir_str: Optional[str], print_status: bool, exit_on_failure: bool = True
) -> tuple[Path, str, bool]:
if not path.exists():
msg = f"File {path} does not exist. Cannot apply formatter commands."
@ -111,9 +107,9 @@ def format_code(
formatter_cmds: list[str],
path: Union[str, Path],
optimized_code: str = "",
check_diff: bool = False, # noqa
print_status: bool = True, # noqa
exit_on_failure: bool = True, # noqa
check_diff: bool = False,
print_status: bool = True,
exit_on_failure: bool = True,
) -> str:
if is_LSP_enabled():
exit_on_failure = False
@ -174,7 +170,7 @@ def format_code(
return formatted_code
def sort_imports(code: str, **kwargs: Any) -> str: # noqa : ANN401
def sort_imports(code: str, **kwargs: Any) -> str:
try:
# Deduplicate and sort imports, modify the code in memory, not on disk
sorted_code = isort.code(code, **kwargs)

View file

@ -89,7 +89,7 @@ class InjectPerfOnly(ast.NodeTransformer):
# it's much more efficient to visit nodes manually. We'll only descend into expressions/statements.
# Helper for manual walk
def iter_ast_calls(node): # noqa: ANN202, ANN001
def iter_ast_calls(node): # noqa: ANN202
# Generator to yield each ast.Call in test_node, preserves node identity
stack = [node]
while stack:
@ -690,15 +690,14 @@ def detect_frameworks_from_code(code: str) -> dict[str, str]:
frameworks["tensorflow"] = alias.asname if alias.asname else module_name
elif module_name == "jax":
frameworks["jax"] = alias.asname if alias.asname else module_name
elif isinstance(node, ast.ImportFrom): # noqa: SIM102
if node.module:
module_name = node.module.split(".")[0]
if module_name == "torch" and "torch" not in frameworks:
frameworks["torch"] = module_name
elif module_name == "tensorflow" and "tensorflow" not in frameworks:
frameworks["tensorflow"] = module_name
elif module_name == "jax" and "jax" not in frameworks:
frameworks["jax"] = module_name
elif isinstance(node, ast.ImportFrom) and node.module:
module_name = node.module.split(".")[0]
if module_name == "torch" and "torch" not in frameworks:
frameworks["torch"] = module_name
elif module_name == "tensorflow" and "tensorflow" not in frameworks:
frameworks["tensorflow"] = module_name
elif module_name == "jax" and "jax" not in frameworks:
frameworks["jax"] = module_name
return frameworks
@ -910,8 +909,7 @@ def _create_device_sync_precompute_statements(used_frameworks: dict[str, str] |
def _create_device_sync_statements(
used_frameworks: dict[str, str] | None,
for_return_value: bool = False, # noqa: FBT001, FBT002
used_frameworks: dict[str, str] | None, for_return_value: bool = False
) -> list[ast.stmt]:
"""Create AST statements for device synchronization using pre-computed conditions.
@ -1450,7 +1448,7 @@ class AsyncDecoratorAdder(cst.CSTTransformer):
# Track when we enter a class
self.context_stack.append(node.name.value)
def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef: # noqa: ARG002
def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef:
# Pop the context when we leave a class
self.context_stack.pop()
return updated_node
@ -1530,7 +1528,7 @@ class AsyncDecoratorImportAdder(cst.CSTTransformer):
if import_alias.name.value == decorator_name:
self.has_import = True
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: # noqa: ARG002
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
# If the import is already there, don't add it again
if self.has_import:
return updated_node

View file

@ -204,7 +204,7 @@ class LineProfilerDecoratorAdder(cst.CSTTransformer):
# Track when we enter a class
self.context_stack.append(node.name.value)
def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef: # noqa: ARG002
def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef:
# Pop the context when we leave a class
self.context_stack.pop()
return updated_node
@ -268,7 +268,7 @@ class ProfileEnableTransformer(cst.CSTTransformer):
return updated_node
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: # noqa: ARG002
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
if not self.found_import:
return updated_node
@ -332,11 +332,11 @@ def add_profile_enable(original_code: str, line_profile_output_file: str) -> str
class ImportAdder(cst.CSTTransformer):
def __init__(self, import_statement) -> None: # noqa: ANN001
def __init__(self, import_statement) -> None:
self.import_statement = import_statement
self.has_import = False
def leave_Module(self, original_node, updated_node): # noqa: ANN001, ANN201, ARG002
def leave_Module(self, original_node, updated_node): # noqa: ANN201
# If the import is already there, don't add it again
if self.has_import:
return updated_node
@ -347,7 +347,7 @@ class ImportAdder(cst.CSTTransformer):
# Add the import to the module's body
return updated_node.with_changes(body=[import_node, *list(updated_node.body)])
def visit_ImportFrom(self, node) -> None: # noqa: ANN001
def visit_ImportFrom(self, node) -> None:
# Check if the profile is already imported from line_profiler
if node.module and node.module.value == "line_profiler":
for import_alias in node.names:

View file

@ -702,7 +702,7 @@ def _wait_for_manual_code_input(oauth: OAuthHandler) -> None:
if not oauth.is_complete:
oauth.manual_code = code.strip()
oauth.is_complete = True
except Exception: # noqa: S110
except Exception:
pass

View file

@ -242,9 +242,9 @@ def get_cross_platform_subprocess_run_args(
cwd: Path | str | None = None,
env: Mapping[str, str] | None = None,
timeout: Optional[float] = None,
check: bool = False, # noqa: FBT001, FBT002
text: bool = True, # noqa: FBT001, FBT002
capture_output: bool = True, # noqa: FBT001, FBT002 (only for non-Windows)
check: bool = False,
text: bool = True,
capture_output: bool = True,
) -> dict[str, str]:
run_args = {"cwd": cwd, "env": env, "text": text, "timeout": timeout, "check": check}
if sys.platform == "win32":

View file

@ -39,7 +39,7 @@ def get_latest_version_from_pypi() -> str | None:
return latest_version
logger.debug(f"Failed to fetch version from PyPI: {response.status_code}")
return None # noqa: TRY300
return None
except requests.RequestException as e:
logger.debug(f"Network error fetching version from PyPI: {e}")
return None

View file

@ -363,7 +363,7 @@ def extract_code_markdown_context_from_files(
helpers_of_fto: dict[Path, set[FunctionSource]],
helpers_of_helpers: dict[Path, set[FunctionSource]],
project_root_path: Path,
remove_docstrings: bool = False, # noqa: FBT001, FBT002
remove_docstrings: bool = False,
code_context_type: CodeContextType = CodeContextType.READ_ONLY,
) -> CodeStringsMarkdown:
"""Extract code context from files containing target functions and their helpers, formatting them as markdown.
@ -1008,7 +1008,7 @@ def parse_code_and_prune_cst(
code_context_type: CodeContextType,
target_functions: set[str],
helpers_of_helper_functions: set[str] = set(), # noqa: B006
remove_docstrings: bool = False, # noqa: FBT001, FBT002
remove_docstrings: bool = False,
) -> str:
"""Create a read-only version of the code by parsing and filtering the code to keep only class contextual information, and other module scoped variables."""
module = cst.parse_module(code)
@ -1049,7 +1049,7 @@ def parse_code_and_prune_cst(
return ""
def prune_cst_for_read_writable_code( # noqa: PLR0911
def prune_cst_for_read_writable_code(
node: cst.CSTNode, target_functions: set[str], defs_with_usages: dict[str, UsageInfo], prefix: str = ""
) -> tuple[cst.CSTNode | None, bool]:
"""Recursively filter the node and its children to build the read-writable codeblock. This contains nodes that lead to target functions.
@ -1167,7 +1167,7 @@ def prune_cst_for_read_writable_code( # noqa: PLR0911
return (node.with_changes(**updates) if updates else node), True
def prune_cst_for_code_hashing( # noqa: PLR0911
def prune_cst_for_code_hashing(
node: cst.CSTNode, target_functions: set[str], prefix: str = ""
) -> tuple[cst.CSTNode | None, bool]:
"""Recursively filter the node and its children to build the read-writable codeblock. This contains nodes that lead to target functions.
@ -1256,14 +1256,14 @@ def prune_cst_for_code_hashing( # noqa: PLR0911
return (node.with_changes(**updates) if updates else node), True
def prune_cst_for_context( # noqa: PLR0911
def prune_cst_for_context(
node: cst.CSTNode,
target_functions: set[str],
helpers_of_helper_functions: set[str],
prefix: str = "",
remove_docstrings: bool = False, # noqa: FBT001, FBT002
include_target_in_output: bool = False, # noqa: FBT001, FBT002
include_init_dunder: bool = False, # noqa: FBT001, FBT002
remove_docstrings: bool = False,
include_target_in_output: bool = False,
include_init_dunder: bool = False,
) -> tuple[cst.CSTNode | None, bool]:
"""Recursively filter the node for code context extraction.

View file

@ -209,7 +209,7 @@ class DependencyCollector(cst.CSTVisitor):
self._extract_names_from_annotation(node.value)
# No need to check the attribute name itself as it's likely not a top-level definition
def leave_FunctionDef(self, original_node: cst.FunctionDef) -> None: # noqa: ARG002
def leave_FunctionDef(self, original_node: cst.FunctionDef) -> None:
self.function_depth -= 1
if self.function_depth == 0 and self.class_depth == 0:
@ -238,7 +238,7 @@ class DependencyCollector(cst.CSTVisitor):
self.class_depth += 1
def leave_ClassDef(self, original_node: cst.ClassDef) -> None: # noqa: ARG002
def leave_ClassDef(self, original_node: cst.ClassDef) -> None:
self.class_depth -= 1
if self.class_depth == 0:
@ -261,7 +261,7 @@ class DependencyCollector(cst.CSTVisitor):
# Use the first tracked name as the current top-level name (for dependency tracking)
self.current_top_level_name = tracked_names[0]
def leave_Assign(self, original_node: cst.Assign) -> None: # noqa: ARG002
def leave_Assign(self, original_node: cst.Assign) -> None:
if self.processing_variable:
self.processing_variable = False
self.current_variable_names.clear()
@ -371,7 +371,7 @@ class QualifiedFunctionUsageMarker:
self.mark_as_used_recursively(dep)
def remove_unused_definitions_recursively( # noqa: PLR0911
def remove_unused_definitions_recursively(
node: cst.CSTNode, definitions: dict[str, UsageInfo]
) -> tuple[cst.CSTNode | None, bool]:
"""Recursively filter the node to remove unused definitions.
@ -554,7 +554,7 @@ def remove_unused_definitions_by_function_names(code: str, qualified_function_na
# Apply the recursive removal transformation
modified_module, _ = remove_unused_definitions_recursively(module, defs_with_usages)
return modified_module.code if modified_module else "" # noqa: TRY300
return modified_module.code if modified_module else ""
except Exception as e:
# If any other error occurs during processing, return the original code
logger.debug(f"Error processing code to remove unused definitions: {type(e).__name__}: {e}")

View file

@ -62,7 +62,7 @@ class ReturnStatementVisitor(cst.CSTVisitor):
super().__init__()
self.has_return_statement: bool = False
def visit_Return(self, node: cst.Return) -> None: # noqa: ARG002
def visit_Return(self, node: cst.Return) -> None:
self.has_return_statement = True
@ -352,7 +352,7 @@ def get_functions_to_optimize(
return filtered_modified_functions, functions_count, trace_file_path
def get_functions_within_git_diff(uncommitted_changes: bool) -> dict[str, list[FunctionToOptimize]]: # noqa: FBT001
def get_functions_within_git_diff(uncommitted_changes: bool) -> dict[str, list[FunctionToOptimize]]:
modified_lines: dict[str, list[int]] = get_git_diff(uncommitted_changes=uncommitted_changes)
return get_functions_within_lines(modified_lines)
@ -602,7 +602,7 @@ def get_all_replay_test_functions(
def is_git_repo(file_path: str) -> bool:
try:
git.Repo(file_path, search_parent_directories=True)
return True # noqa: TRY300
return True
except git.InvalidGitRepositoryError:
return False

View file

@ -409,7 +409,7 @@ def provide_api_key(params: ProvideApiKeyParams) -> dict[str, str]:
_init()
if not is_successful(result):
return {"status": "error", "message": result.failure()}
return {"status": "success", "message": "Api key saved successfully", "user_id": user_id} # noqa: TRY300
return {"status": "success", "message": "Api key saved successfully", "user_id": user_id}
except Exception:
return {"status": "error", "message": "something went wrong while saving the api key"}

View file

@ -25,7 +25,7 @@ def abort_if_cancelled(cancel_event: threading.Event) -> None:
raise RuntimeError("cancelled")
def sync_perform_optimization(server: CodeflashLanguageServer, cancel_event: threading.Event, params) -> dict[str, str]: # noqa
def sync_perform_optimization(server: CodeflashLanguageServer, cancel_event: threading.Event, params) -> dict[str, str]:
server.show_message_log(f"Starting optimization for function: {params.functionName}", "Info")
should_run_experiment, code_context, original_helper_code = server.current_optimization_init_result
function_optimizer = server.optimizer.current_function_optimizer

View file

@ -46,7 +46,7 @@ def report_to_markdown_table(report: dict[TestType, dict[str, int]], title: str)
return table
def simplify_worktree_paths(msg: str, highlight: bool = True) -> str: # noqa: FBT001, FBT002
def simplify_worktree_paths(msg: str, highlight: bool = True) -> str:
path_in_msg = worktree_path_regex.search(msg)
if path_in_msg:
# Use Path.name to handle both Unix and Windows path separators

View file

@ -85,11 +85,7 @@ supported_lsp_log_levels = ("info", "debug")
def enhanced_log(
msg: str | Any, # noqa: ANN401
actual_log_fn: Callable[[str, Any, Any], None],
level: str,
*args: Any, # noqa: ANN401
**kwargs: Any, # noqa: ANN401
msg: str | Any, actual_log_fn: Callable[[str, Any, Any], None], level: str, *args: Any, **kwargs: Any
) -> None:
if not isinstance(msg, str):
actual_log_fn(msg, *args, **kwargs)

View file

@ -27,7 +27,7 @@ class LspMessage:
takes_time: bool = False
message_id: Optional[str] = None
def _loop_through(self, obj: Any) -> Any: # noqa: ANN401
def _loop_through(self, obj: Any) -> Any:
if isinstance(obj, list):
return [self._loop_through(i) for i in obj]
if isinstance(obj, dict):

View file

@ -30,21 +30,21 @@ def main() -> None:
if args.config_file and Path.exists(args.config_file):
pyproject_config, _ = parse_config_file(args.config_file)
disable_telemetry = pyproject_config.get("disable_telemetry", False)
init_sentry(not disable_telemetry, exclude_errors=True)
posthog_cf.initialize_posthog(not disable_telemetry)
init_sentry(enabled=not disable_telemetry, exclude_errors=True)
posthog_cf.initialize_posthog(enabled=not disable_telemetry)
args.func()
elif args.verify_setup:
args = process_pyproject_config(args)
init_sentry(not args.disable_telemetry, exclude_errors=True)
posthog_cf.initialize_posthog(not args.disable_telemetry)
init_sentry(enabled=not args.disable_telemetry, exclude_errors=True)
posthog_cf.initialize_posthog(enabled=not args.disable_telemetry)
ask_run_end_to_end_test(args)
else:
args = process_pyproject_config(args)
if not env_utils.check_formatter_installed(args.formatter_cmds):
return
args.previous_checkpoint_functions = ask_should_use_checkpoint_get_functions(args)
init_sentry(not args.disable_telemetry, exclude_errors=True)
posthog_cf.initialize_posthog(not args.disable_telemetry)
init_sentry(enabled=not args.disable_telemetry, exclude_errors=True)
posthog_cf.initialize_posthog(enabled=not args.disable_telemetry)
from codeflash.optimization import optimizer

View file

@ -41,6 +41,6 @@ def belongs_to_function_qualified(name: Name, qualified_function_name: str) -> b
return False
if (name := name.parent()) and name.type == "function":
return get_qualified_name(name.module_name, name.full_name) == qualified_function_name
return False # noqa: TRY300
return False
except ValueError:
return False

View file

@ -2704,8 +2704,6 @@ class FunctionOptimizer:
pytest_cmd=self.test_cfg.pytest_cmd,
pytest_timeout=INDIVIDUAL_TESTCASE_TIMEOUT,
pytest_target_runtime_seconds=testing_time,
pytest_min_loops=1,
pytest_max_loops=1,
test_framework=self.test_cfg.test_framework,
js_project_root=self.test_cfg.js_project_root,
line_profiler_output_file=line_profiler_output_file,

View file

@ -94,7 +94,7 @@ class PicklePatcher:
obj: object,
path: list[str] | None = None, # noqa: ARG004
protocol: int | None = None,
**kwargs: Any, # noqa: ANN401
**kwargs: Any,
) -> tuple[bool, bytes | str]:
"""Try to pickle an object using pickle first, then dill. If both fail, create a placeholder.
@ -123,7 +123,7 @@ class PicklePatcher:
return False, str(e)
@staticmethod
def _recursive_pickle( # noqa: PLR0911
def _recursive_pickle(
obj: object,
max_depth: int,
path: list[str] | None = None,
@ -192,7 +192,7 @@ class PicklePatcher:
error_msg: str, # noqa: ARG004
path: list[str],
protocol: int | None = None,
**kwargs: Any, # noqa: ANN401
**kwargs: Any,
) -> bytes:
"""Handle pickling for dictionary objects.
@ -258,7 +258,7 @@ class PicklePatcher:
error_msg: str, # noqa: ARG004
path: list[str],
protocol: int | None = None,
**kwargs: Any, # noqa: ANN401
**kwargs: Any,
) -> bytes:
"""Handle pickling for sequence types (list, tuple, set).
@ -311,12 +311,7 @@ class PicklePatcher:
@staticmethod
def _handle_object(
obj: object,
max_depth: int,
error_msg: str,
path: list[str],
protocol: int | None = None,
**kwargs: Any, # noqa: ANN401
obj: object, max_depth: int, error_msg: str, path: list[str], protocol: int | None = None, **kwargs: Any
) -> bytes:
"""Handle pickling for custom objects with __dict__.
@ -366,7 +361,7 @@ class PicklePatcher:
if success:
return result
# Fall through to placeholder creation
except Exception: # noqa: S110
except Exception:
pass # Fall through to placeholder creation
# If we get here, just use a placeholder

View file

@ -31,7 +31,7 @@ class PicklePlaceholder:
self.__dict__["error_msg"] = error_msg
self.__dict__["path"] = path if path is not None else []
def __getattr__(self, name) -> Any: # noqa: ANN001, ANN401
def __getattr__(self, name) -> Any:
"""Raise a custom error when any attribute is accessed."""
path_str = ".".join(self.__dict__["path"]) if self.__dict__["path"] else "root object"
msg = (
@ -40,11 +40,11 @@ class PicklePlaceholder:
)
raise PicklePlaceholderAccessError(msg)
def __setattr__(self, name: str, value: Any) -> None: # noqa: ANN401
def __setattr__(self, name: str, value: Any) -> None:
"""Prevent setting attributes."""
self.__getattr__(name) # This will raise our custom error
def __call__(self, *args: Any, **kwargs: Any) -> Any: # noqa: ANN401, ARG002
def __call__(self, *args: Any, **kwargs: Any) -> Any:
"""Raise a custom error when the object is called."""
path_str = ".".join(self.__dict__["path"]) if self.__dict__["path"] else "root object"
msg = (

View file

@ -12,7 +12,7 @@ from codeflash.version import __version__
_posthog = None
def initialize_posthog(enabled: bool = True) -> None: # noqa: FBT001, FBT002
def initialize_posthog(*, enabled: bool = True) -> None:
"""Enable or disable PostHog.
:param enabled: Whether to enable PostHog.

View file

@ -4,7 +4,7 @@ import sentry_sdk
from sentry_sdk.integrations.logging import LoggingIntegration
def init_sentry(enabled: bool = False, exclude_errors: bool = False) -> None: # noqa: FBT001, FBT002
def init_sentry(*, enabled: bool = False, exclude_errors: bool = False) -> None:
if enabled:
sentry_logging = LoggingIntegration(
level=logging.INFO, # Capture info and above as breadcrumbs

View file

@ -232,8 +232,8 @@ def main(args: Namespace | None = None) -> ArgumentParser:
args = process_pyproject_config(args)
args.previous_checkpoint_functions = None
init_sentry(not args.disable_telemetry, exclude_errors=True)
posthog_cf.initialize_posthog(not args.disable_telemetry)
init_sentry(enabled=not args.disable_telemetry, exclude_errors=True)
posthog_cf.initialize_posthog(enabled=not args.disable_telemetry)
from codeflash.optimization import optimizer

View file

@ -76,7 +76,8 @@ class Tracer:
config: dict,
result_pickle_file_path: Path,
functions: list[str] | None = None,
disable: bool = False, # noqa: FBT001, FBT002
*,
disable: bool = False,
project_root: Path | None = None,
max_function_count: int = 256,
timeout: int | None = None, # seconds
@ -309,7 +310,7 @@ class Tracer:
with self.result_pickle_file_path.open("wb") as file:
pickle.dump(pickle_data, file)
def tracer_logic(self, frame: FrameType, event: str) -> None: # noqa: PLR0911
def tracer_logic(self, frame: FrameType, event: str) -> None:
if event != "call":
return
if None is not self.timeout and (time.time() - self.start_time) > self.timeout:
@ -494,7 +495,7 @@ class Tracer:
class_name = arguments["self"].__class__.__name__
elif "cls" in arguments and hasattr(arguments["cls"], "__name__"):
class_name = arguments["cls"].__name__
except Exception: # noqa: S110
except Exception:
pass
fn = (fcode.co_filename, fcode.co_firstlineno, fcode.co_name, class_name)
@ -505,7 +506,7 @@ class Tracer:
timings[fn] = cc, ns + 1, tt, ct, callers
else:
timings[fn] = 0, 0, 0, 0, {}
return 1 # noqa: TRY300
return 1
except Exception:
# Handle any errors gracefully
return 0

View file

@ -28,7 +28,7 @@ def path_belongs_to_site_packages(file_path: Path) -> bool:
def is_git_repo(file_path: str) -> bool:
try:
git.Repo(file_path, search_parent_directories=True)
return True # noqa: TRY300
return True
except git.InvalidGitRepositoryError:
return False

View file

@ -99,7 +99,7 @@ def get_test_info_from_stack(tests_root: str) -> tuple[str, str | None, str, str
return test_module_name, test_class_name, test_name, line_id
def codeflash_capture(function_name: str, tmp_dir_path: str, tests_root: str, is_fto: bool = False) -> Callable: # noqa: FBT001, FBT002
def codeflash_capture(function_name: str, tmp_dir_path: str, tests_root: str, is_fto: bool = False) -> Callable:
"""Define a decorator to instrument the init function, collect test info, and capture the instance state."""
def decorator(wrapped: Callable) -> Callable:

View file

@ -93,7 +93,7 @@ def _get_wrapped_exception(exc: BaseException) -> Optional[BaseException]: # no
return _extract_exception_from_message(str(exc))
def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001, ANN401, FBT002, PLR0911
def comparator(orig: Any, new: Any, superset_obj=False) -> bool:
"""Compare two objects for equality recursively. If superset_obj is True, the new object is allowed to have more keys than the original object. However, the existing keys/values must be equivalent."""
try:
# Handle exceptions specially - before type check to allow wrapper comparison
@ -118,7 +118,7 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
# Check if new wraps something that matches orig
wrapped_new = _get_wrapped_exception(new)
if wrapped_new is not None and comparator(orig, wrapped_new, superset_obj): # noqa: SIM103
if wrapped_new is not None and comparator(orig, wrapped_new, superset_obj):
return True
return False
@ -242,7 +242,7 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
continue
if key not in new_keys or not comparator(orig_keys[key], new_keys[key], superset_obj):
return False
return True # noqa: TRY300
return True
except sqlalchemy.exc.NoInspectionAvailable:
pass
@ -366,12 +366,12 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
try:
if HAS_NUMPY and np.isnan(orig):
return np.isnan(new)
except Exception: # noqa: S110
except Exception:
pass
try:
if HAS_NUMPY and np.isinf(orig):
return np.isinf(new)
except Exception: # noqa: S110
except Exception:
pass
if HAS_TORCH:
@ -480,7 +480,7 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
try:
if hasattr(orig, "__eq__") and str(type(orig.__eq__)) == "<class 'method'>":
return orig == new
except Exception: # noqa: S110
except Exception:
pass
# For class objects
@ -512,7 +512,7 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
# TODO : Add other types here
logger.warning(f"Unknown comparator input type: {type(orig)}")
sentry_sdk.capture_exception(RuntimeError(f"Unknown comparator input type: {type(orig)}"))
return False # noqa: TRY300
return False
except RecursionError as e:
logger.error(f"RecursionError while comparing objects: {e}")
sentry_sdk.capture_exception(e)

View file

@ -54,16 +54,11 @@ def instrument_codeflash_capture(
def add_codeflash_capture_to_init(
target_classes: set[str],
fto_name: str,
tmp_dir_path: str,
code: str,
tests_root: Path,
is_fto: bool = False, # noqa: FBT001, FBT002
target_classes: set[str], fto_name: str, tmp_dir_path: str, code: str, tests_root: Path, *, is_fto: bool = False
) -> str:
"""Add codeflash_capture decorator to __init__ function in the specified class."""
tree = ast.parse(code)
transformer = InitDecorator(target_classes, fto_name, tmp_dir_path, tests_root, is_fto)
transformer = InitDecorator(target_classes, fto_name, tmp_dir_path, tests_root, is_fto=is_fto)
modified_tree = transformer.visit(tree)
if transformer.inserted_decorator:
ast.fix_missing_locations(modified_tree)
@ -76,12 +71,7 @@ class InitDecorator(ast.NodeTransformer):
"""AST transformer that adds codeflash_capture decorator to specific class's __init__."""
def __init__(
self,
target_classes: set[str],
fto_name: str,
tmp_dir_path: str,
tests_root: Path,
is_fto=False, # noqa: ANN001, FBT002
self, target_classes: set[str], fto_name: str, tmp_dir_path: str, tests_root: Path, *, is_fto: bool = False
) -> None:
self.target_classes = target_classes
self.fto_name = fto_name

View file

@ -12,7 +12,7 @@ import sys
import time as _time_module
import warnings
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Optional
from typing import TYPE_CHECKING, Callable, Optional
from unittest import TestCase
# PyTest Imports
@ -437,7 +437,7 @@ class PytestLoops:
"importlib",
}
def _clear_cache_for_object(obj: Any) -> None: # noqa: ANN401
def _clear_cache_for_object(obj: obj) -> None:
if obj in processed_functions:
return
processed_functions.add(obj)
@ -469,9 +469,9 @@ class PytestLoops:
for _, obj in inspect.getmembers(module):
if callable(obj):
_clear_cache_for_object(obj)
except Exception: # noqa: S110
except Exception:
pass
except Exception: # noqa: S110
except Exception:
pass
def _set_nodeid(self, nodeid: str, count: int) -> str:
@ -581,7 +581,7 @@ class PytestLoops:
os.environ["CODEFLASH_TEST_FUNCTION"] = test_function_name
@pytest.hookimpl(trylast=True)
def pytest_runtest_teardown(self, item: pytest.Item) -> None: # noqa: ARG002
def pytest_runtest_teardown(self, item: pytest.Item) -> None:
"""Clean up test context environment variables after each test."""
for var in ["CODEFLASH_TEST_MODULE", "CODEFLASH_TEST_CLASS", "CODEFLASH_TEST_FUNCTION"]:
os.environ.pop(var, None)

View file

@ -258,8 +258,8 @@ def run_line_profile_tests(
*,
pytest_target_runtime_seconds: float = TOTAL_LOOPING_TIME_EFFECTIVE,
pytest_timeout: int | None = None,
pytest_min_loops: int = 5, # noqa: ARG001
pytest_max_loops: int = 100_000, # noqa: ARG001
pytest_min_loops: int = 5,
pytest_max_loops: int = 100_000,
js_project_root: Path | None = None,
line_profiler_output_file: Path | None = None,
) -> tuple[Path, subprocess.CompletedProcess]:

View file

@ -28,7 +28,7 @@ def generate_tests(
test_index: int,
test_path: Path,
test_perf_path: Path,
is_numerical_code: bool | None = None, # noqa: FBT001
is_numerical_code: bool | None = None,
) -> tuple[str, str, str, Path, Path] | None:
# TODO: Sometimes this recreates the original Class definition. This overrides and messes up the original
# class import. Remove the recreation of the class definition

View file

@ -263,12 +263,13 @@ ignore = [
"TD007",
"D417",
"D401",
"S110", # try-except-pass - we do this a lot
"ARG002", # Unused method argument
# Added for multi-language branch
"FBT001", # Boolean positional argument
"FBT002", # Boolean default positional argument
"ANN401", # typing.Any disallowed
"ARG001", # Unused function argument (common in abstract/interface methods)
"ARG002", # Unused method argument
"TRY300", # Consider moving to else block
"TRY401", # Redundant exception in logging.exception
"PLR0911", # Too many return statements
@ -277,7 +278,6 @@ ignore = [
"SIM102", # Nested if statements
"SIM103", # Return negated condition
"ANN001", # Missing type annotation
"S110", # try-except-pass
"PLC0206", # Dictionary items
"S314", # XML parsing (acceptable for dev tool)
"S608", # SQL injection (internal use only)