fix for conftest issue

This commit is contained in:
aseembits93 2025-06-13 16:08:07 -07:00
parent 5651629c8b
commit 771ba90932

View file

@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Optional, TypeVar
import isort
import libcst as cst
import libcst.matchers as m
from libcst.metadata import PositionProvider
from codeflash.cli_cmds.console import logger
from codeflash.code_utils.code_extractor import add_global_assignments, add_needed_imports_from_module
@ -37,6 +38,34 @@ def normalize_code(code: str) -> str:
return ast.unparse(normalize_node(ast.parse(code)))
class AddRequestArgument(cst.CSTTransformer):
METADATA_DEPENDENCIES = (PositionProvider,)
def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef: # noqa: ARG002
args = updated_node.params.params
arg_names = {arg.name.value for arg in args}
# Skip if 'request' is already present
if "request" in arg_names:
return updated_node
# Create a new 'request' param
request_param = cst.Param(name=cst.Name("request"))
# Add 'request' as the first argument (after 'self' or 'cls' if needed)
if args:
first_arg = args[0].name.value
if first_arg in {"self", "cls"}:
new_params = [args[0], request_param] + list(args[1:]) # noqa: RUF005
else:
new_params = [request_param] + list(args) # noqa: RUF005
else:
new_params = [request_param]
new_param_list = updated_node.params.with_changes(params=new_params)
return updated_node.with_changes(params=new_param_list)
class PytestMarkAdder(cst.CSTTransformer):
"""Transformer that adds pytest marks to test functions."""
@ -139,8 +168,10 @@ class AutouseFixtureModifier(cst.CSTTransformer):
def disable_autouse(test_path: Path) -> str:
file_content = test_path.read_text(encoding="utf-8")
module = cst.parse_module(file_content)
add_request_argument = AddRequestArgument()
disable_autouse_fixture = AutouseFixtureModifier()
modified_module = module.visit(disable_autouse_fixture)
modified_module = module.visit(add_request_argument)
modified_module = modified_module.visit(disable_autouse_fixture)
test_path.write_text(modified_module.code, encoding="utf-8")
return file_content