fix for conftest issue
This commit is contained in:
parent
5651629c8b
commit
771ba90932
1 changed files with 32 additions and 1 deletions
|
|
@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Optional, TypeVar
|
|||
import isort
|
||||
import libcst as cst
|
||||
import libcst.matchers as m
|
||||
from libcst.metadata import PositionProvider
|
||||
|
||||
from codeflash.cli_cmds.console import logger
|
||||
from codeflash.code_utils.code_extractor import add_global_assignments, add_needed_imports_from_module
|
||||
|
|
@ -37,6 +38,34 @@ def normalize_code(code: str) -> str:
|
|||
return ast.unparse(normalize_node(ast.parse(code)))
|
||||
|
||||
|
||||
class AddRequestArgument(cst.CSTTransformer):
|
||||
METADATA_DEPENDENCIES = (PositionProvider,)
|
||||
|
||||
def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef: # noqa: ARG002
|
||||
args = updated_node.params.params
|
||||
arg_names = {arg.name.value for arg in args}
|
||||
|
||||
# Skip if 'request' is already present
|
||||
if "request" in arg_names:
|
||||
return updated_node
|
||||
|
||||
# Create a new 'request' param
|
||||
request_param = cst.Param(name=cst.Name("request"))
|
||||
|
||||
# Add 'request' as the first argument (after 'self' or 'cls' if needed)
|
||||
if args:
|
||||
first_arg = args[0].name.value
|
||||
if first_arg in {"self", "cls"}:
|
||||
new_params = [args[0], request_param] + list(args[1:]) # noqa: RUF005
|
||||
else:
|
||||
new_params = [request_param] + list(args) # noqa: RUF005
|
||||
else:
|
||||
new_params = [request_param]
|
||||
|
||||
new_param_list = updated_node.params.with_changes(params=new_params)
|
||||
return updated_node.with_changes(params=new_param_list)
|
||||
|
||||
|
||||
class PytestMarkAdder(cst.CSTTransformer):
|
||||
"""Transformer that adds pytest marks to test functions."""
|
||||
|
||||
|
|
@ -139,8 +168,10 @@ class AutouseFixtureModifier(cst.CSTTransformer):
|
|||
def disable_autouse(test_path: Path) -> str:
|
||||
file_content = test_path.read_text(encoding="utf-8")
|
||||
module = cst.parse_module(file_content)
|
||||
add_request_argument = AddRequestArgument()
|
||||
disable_autouse_fixture = AutouseFixtureModifier()
|
||||
modified_module = module.visit(disable_autouse_fixture)
|
||||
modified_module = module.visit(add_request_argument)
|
||||
modified_module = modified_module.visit(disable_autouse_fixture)
|
||||
test_path.write_text(modified_module.code, encoding="utf-8")
|
||||
return file_content
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue