new tests

This commit is contained in:
aseembits93 2025-06-13 17:27:45 -07:00
parent 771ba90932
commit da9df78062
2 changed files with 547 additions and 54 deletions

View file

@ -7,7 +7,6 @@ from typing import TYPE_CHECKING, Optional, TypeVar
import isort
import libcst as cst
import libcst.matchers as m
from libcst.metadata import PositionProvider
from codeflash.cli_cmds.console import logger
@ -41,29 +40,50 @@ def normalize_code(code: str) -> str:
class AddRequestArgument(cst.CSTTransformer):
METADATA_DEPENDENCIES = (PositionProvider,)
def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef: # noqa: ARG002
args = updated_node.params.params
arg_names = {arg.name.value for arg in args}
def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef:
# Matcher for '@fixture' or '@pytest.fixture'
for decorator in original_node.decorators:
dec = decorator.decorator
# Skip if 'request' is already present
if "request" in arg_names:
return updated_node
if isinstance(dec, cst.Call):
func_name = ""
if isinstance(dec.func, cst.Attribute) and isinstance(dec.func.value, cst.Name):
if dec.func.attr.value == "fixture" and dec.func.value.value == "pytest":
func_name = "pytest.fixture"
elif isinstance(dec.func, cst.Name) and dec.func.value == "fixture":
func_name = "fixture"
# Create a new 'request' param
request_param = cst.Param(name=cst.Name("request"))
if func_name:
for arg in dec.args:
if (
arg.keyword
and arg.keyword.value == "autouse"
and isinstance(arg.value, cst.Name)
and arg.value.value == "True"
):
args = updated_node.params.params
arg_names = {arg.name.value for arg in args}
# Add 'request' as the first argument (after 'self' or 'cls' if needed)
if args:
first_arg = args[0].name.value
if first_arg in {"self", "cls"}:
new_params = [args[0], request_param] + list(args[1:]) # noqa: RUF005
else:
new_params = [request_param] + list(args) # noqa: RUF005
else:
new_params = [request_param]
# Skip if 'request' is already present
if "request" in arg_names:
return updated_node
new_param_list = updated_node.params.with_changes(params=new_params)
return updated_node.with_changes(params=new_param_list)
# Create a new 'request' param
request_param = cst.Param(name=cst.Name("request"))
# Add 'request' as the first argument (after 'self' or 'cls' if needed)
if args:
first_arg = args[0].name.value
if first_arg in {"self", "cls"}:
new_params = [args[0], request_param] + list(args[1:]) # noqa: RUF005
else:
new_params = [request_param] + list(args) # noqa: RUF005
else:
new_params = [request_param]
new_param_list = updated_node.params.with_changes(params=new_params)
return updated_node.with_changes(params=new_param_list)
return updated_node
class PytestMarkAdder(cst.CSTTransformer):
@ -135,33 +155,41 @@ class PytestMarkAdder(cst.CSTTransformer):
class AutouseFixtureModifier(cst.CSTTransformer):
def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef:
# Matcher for '@fixture' or '@pytest.fixture'
fixture_decorator_func = m.Name("fixture") | m.Attribute(value=m.Name("pytest"), attr=m.Name("fixture"))
for decorator in original_node.decorators:
if m.matches(
decorator,
m.Decorator(
decorator=m.Call(
func=fixture_decorator_func, args=[m.Arg(value=m.Name("True"), keyword=m.Name("autouse"))]
)
),
):
# Found a matching fixture with autouse=True
dec = decorator.decorator
# 1. The original body of the function will become the 'else' block.
# updated_node.body is an IndentedBlock, which is what cst.Else expects.
else_block = cst.Else(body=updated_node.body)
if isinstance(dec, cst.Call):
func_name = ""
if isinstance(dec.func, cst.Attribute) and isinstance(dec.func.value, cst.Name):
if dec.func.attr.value == "fixture" and dec.func.value.value == "pytest":
func_name = "pytest.fixture"
elif isinstance(dec.func, cst.Name) and dec.func.value == "fixture":
func_name = "fixture"
# 2. Create the new 'if' block that will exit the fixture early.
if_test = cst.parse_expression('request.node.get_closest_marker("codeflash_no_autouse")')
yield_statement = cst.parse_statement("yield")
if_body = cst.IndentedBlock(body=[yield_statement])
if func_name:
for arg in dec.args:
if (
arg.keyword
and arg.keyword.value == "autouse"
and isinstance(arg.value, cst.Name)
and arg.value.value == "True"
):
# Found a matching fixture with autouse=True
# 3. Construct the full if/else statement.
new_if_statement = cst.If(test=if_test, body=if_body, orelse=else_block)
# 1. The original body of the function will become the 'else' block.
# updated_node.body is an IndentedBlock, which is what cst.Else expects.
else_block = cst.Else(body=updated_node.body)
# 4. Replace the entire function's body with our new single statement.
return updated_node.with_changes(body=cst.IndentedBlock(body=[new_if_statement]))
# 2. Create the new 'if' block that will exit the fixture early.
if_test = cst.parse_expression('request.node.get_closest_marker("codeflash_no_autouse")')
yield_statement = cst.parse_statement("yield")
if_body = cst.IndentedBlock(body=[yield_statement])
# 3. Construct the full if/else statement.
new_if_statement = cst.If(test=if_test, body=if_body, orelse=else_block)
# 4. Replace the entire function's body with our new single statement.
return updated_node.with_changes(body=cst.IndentedBlock(body=[new_if_statement]))
return updated_node

View file

@ -1,6 +1,6 @@
from __future__ import annotations
import libcst as cst
from codeflash.code_utils.code_replacer import AutouseFixtureModifier, PytestMarkAdder
from codeflash.code_utils.code_replacer import AutouseFixtureModifier, PytestMarkAdder, AddRequestArgument
import dataclasses
import os
from collections import defaultdict
@ -2564,15 +2564,15 @@ from pathlib import Path
class TestIntegration:
"""Integration tests for both transformers working together."""
"""Integration tests for all transformers working together."""
def test_both_transformers_together(self):
"""Test that both transformers can work on the same code."""
def test_all_transformers_together(self):
"""Test that all three transformers can work on the same code."""
source_code = '''
import pytest
@pytest.fixture(autouse=True)
def my_fixture(request):
def my_fixture():
yield "value"
def test_something():
@ -2593,16 +2593,481 @@ def my_fixture(request):
def test_something():
assert True
'''
# First apply AutouseFixtureModifier
# First apply AddRequestArgument
module = cst.parse_module(source_code)
autouse_modifier = AutouseFixtureModifier()
modified_module = module.visit(autouse_modifier)
request_adder = AddRequestArgument()
modified_module = module.visit(request_adder)
# Then apply PytestMarkAdder
# Then apply AutouseFixtureModifier
autouse_modifier = AutouseFixtureModifier()
modified_module = modified_module.visit(autouse_modifier)
# Finally apply PytestMarkAdder
mark_adder = PytestMarkAdder("codeflash_no_autouse")
final_module = modified_module.visit(mark_adder)
code = final_module.code
# Should have both modifications
assert code==expected_code
# Compare complete strings
assert final_module.code == expected_code
def test_transformers_with_existing_request_parameter(self):
"""Test transformers when request parameter already exists."""
source_code = '''
import pytest
@pytest.fixture(autouse=True)
def my_fixture(request):
setup_code()
yield "value"
cleanup_code()
def test_something():
assert True
'''
expected_code = '''
import pytest
@pytest.fixture(autouse=True)
@pytest.mark.codeflash_no_autouse
def my_fixture(request):
if request.node.get_closest_marker("codeflash_no_autouse"):
yield
else:
setup_code()
yield "value"
cleanup_code()
@pytest.mark.codeflash_no_autouse
def test_something():
assert True
'''
# Apply all transformers in sequence
module = cst.parse_module(source_code)
request_adder = AddRequestArgument()
modified_module = module.visit(request_adder)
autouse_modifier = AutouseFixtureModifier()
modified_module = modified_module.visit(autouse_modifier)
mark_adder = PytestMarkAdder("codeflash_no_autouse")
final_module = modified_module.visit(mark_adder)
# Compare complete strings
assert final_module.code == expected_code
def test_transformers_with_self_parameter(self):
"""Test transformers when fixture has self parameter."""
source_code = '''
import pytest
@pytest.fixture(autouse=True)
def my_fixture(self):
yield "value"
def test_something():
assert True
'''
expected_code = '''
import pytest
@pytest.fixture(autouse=True)
@pytest.mark.codeflash_no_autouse
def my_fixture(self, request):
if request.node.get_closest_marker("codeflash_no_autouse"):
yield
else:
yield "value"
@pytest.mark.codeflash_no_autouse
def test_something():
assert True
'''
# Apply all transformers in sequence
module = cst.parse_module(source_code)
request_adder = AddRequestArgument()
modified_module = module.visit(request_adder)
autouse_modifier = AutouseFixtureModifier()
modified_module = modified_module.visit(autouse_modifier)
mark_adder = PytestMarkAdder("codeflash_no_autouse")
final_module = modified_module.visit(mark_adder)
# Compare complete strings
assert final_module.code == expected_code
def test_transformers_with_multiple_fixtures(self):
"""Test transformers with multiple autouse fixtures."""
source_code = '''
import pytest
@pytest.fixture(autouse=True)
def fixture_one():
yield "one"
@pytest.fixture(autouse=True)
def fixture_two(self, param):
yield "two"
@pytest.fixture
def regular_fixture():
return "regular"
def test_something():
assert True
'''
expected_code = '''
import pytest
@pytest.fixture(autouse=True)
@pytest.mark.codeflash_no_autouse
def fixture_one(request):
if request.node.get_closest_marker("codeflash_no_autouse"):
yield
else:
yield "one"
@pytest.fixture(autouse=True)
@pytest.mark.codeflash_no_autouse
def fixture_two(self, request, param):
if request.node.get_closest_marker("codeflash_no_autouse"):
yield
else:
yield "two"
@pytest.fixture
@pytest.mark.codeflash_no_autouse
def regular_fixture():
return "regular"
@pytest.mark.codeflash_no_autouse
def test_something():
assert True
'''
# Apply all transformers in sequence
module = cst.parse_module(source_code)
request_adder = AddRequestArgument()
modified_module = module.visit(request_adder)
autouse_modifier = AutouseFixtureModifier()
modified_module = modified_module.visit(autouse_modifier)
mark_adder = PytestMarkAdder("codeflash_no_autouse")
final_module = modified_module.visit(mark_adder)
# Compare complete strings
assert final_module.code == expected_code
class TestAddRequestArgument:
"""Test cases for AddRequestArgument transformer."""
def test_adds_request_to_autouse_fixture_no_existing_args(self):
"""Test adding request argument to autouse fixture with no existing arguments."""
source_code = '''
@fixture(autouse=True)
def my_fixture():
pass
'''
expected = '''
@fixture(autouse=True)
def my_fixture(request):
pass
'''
module = cst.parse_module(source_code)
transformer = AddRequestArgument()
modified_module = module.visit(transformer)
assert modified_module.code.strip() == expected.strip()
def test_adds_request_to_pytest_fixture_autouse(self):
"""Test adding request argument to pytest.fixture with autouse=True."""
source_code = '''
@pytest.fixture(autouse=True)
def my_fixture():
pass
'''
expected = '''
@pytest.fixture(autouse=True)
def my_fixture(request):
pass
'''
module = cst.parse_module(source_code)
transformer = AddRequestArgument()
modified_module = module.visit(transformer)
assert modified_module.code.strip() == expected.strip()
def test_adds_request_after_self_parameter(self):
"""Test adding request argument after self parameter."""
source_code = '''
@fixture(autouse=True)
def my_fixture(self):
pass
'''
expected = '''
@fixture(autouse=True)
def my_fixture(self, request):
pass
'''
module = cst.parse_module(source_code)
transformer = AddRequestArgument()
modified_module = module.visit(transformer)
assert modified_module.code.strip() == expected.strip()
def test_adds_request_after_cls_parameter(self):
"""Test adding request argument after cls parameter."""
source_code = '''
@fixture(autouse=True)
def my_fixture(cls):
pass
'''
expected = '''
@fixture(autouse=True)
def my_fixture(cls, request):
pass
'''
module = cst.parse_module(source_code)
transformer = AddRequestArgument()
modified_module = module.visit(transformer)
assert modified_module.code.strip() == expected.strip()
def test_adds_request_before_other_parameters(self):
"""Test adding request argument before other parameters (not self/cls)."""
source_code = '''
@fixture(autouse=True)
def my_fixture(param1, param2):
pass
'''
expected = '''
@fixture(autouse=True)
def my_fixture(request, param1, param2):
pass
'''
module = cst.parse_module(source_code)
transformer = AddRequestArgument()
modified_module = module.visit(transformer)
assert modified_module.code.strip() == expected.strip()
def test_adds_request_after_self_with_other_parameters(self):
"""Test adding request argument after self with other parameters."""
source_code = '''
@fixture(autouse=True)
def my_fixture(self, param1, param2):
pass
'''
expected = '''
@fixture(autouse=True)
def my_fixture(self, request, param1, param2):
pass
'''
module = cst.parse_module(source_code)
transformer = AddRequestArgument()
modified_module = module.visit(transformer)
assert modified_module.code.strip() == expected.strip()
def test_skips_when_request_already_present(self):
"""Test that request argument is not added when already present."""
source_code = '''
@fixture(autouse=True)
def my_fixture(request):
pass
'''
expected = '''
@fixture(autouse=True)
def my_fixture(request):
pass
'''
module = cst.parse_module(source_code)
transformer = AddRequestArgument()
modified_module = module.visit(transformer)
assert modified_module.code.strip() == expected.strip()
def test_skips_when_request_present_with_other_args(self):
"""Test that request argument is not added when already present with other args."""
source_code = '''
@fixture(autouse=True)
def my_fixture(self, request, param1):
pass
'''
expected = '''
@fixture(autouse=True)
def my_fixture(self, request, param1):
pass
'''
module = cst.parse_module(source_code)
transformer = AddRequestArgument()
modified_module = module.visit(transformer)
assert modified_module.code.strip() == expected.strip()
def test_ignores_non_autouse_fixture(self):
"""Test that non-autouse fixtures are not modified."""
source_code = '''
@fixture
def my_fixture():
pass
'''
expected = '''
@fixture
def my_fixture():
pass
'''
module = cst.parse_module(source_code)
transformer = AddRequestArgument()
modified_module = module.visit(transformer)
assert modified_module.code.strip() == expected.strip()
def test_ignores_fixture_with_autouse_false(self):
"""Test that fixtures with autouse=False are not modified."""
source_code = '''
@fixture(autouse=False)
def my_fixture():
pass
'''
expected = '''
@fixture(autouse=False)
def my_fixture():
pass
'''
module = cst.parse_module(source_code)
transformer = AddRequestArgument()
modified_module = module.visit(transformer)
assert modified_module.code.strip() == expected.strip()
def test_ignores_regular_function(self):
"""Test that regular functions are not modified."""
source_code = '''
def my_function():
pass
'''
expected = '''
def my_function():
pass
'''
module = cst.parse_module(source_code)
transformer = AddRequestArgument()
modified_module = module.visit(transformer)
assert modified_module.code.strip() == expected.strip()
def test_handles_multiple_autouse_fixtures(self):
"""Test handling multiple autouse fixtures in the same module."""
source_code = '''
@fixture(autouse=True)
def fixture1():
pass
@pytest.fixture(autouse=True)
def fixture2(self):
pass
@fixture(autouse=True)
def fixture3(request):
pass
'''
expected = '''
@fixture(autouse=True)
def fixture1(request):
pass
@pytest.fixture(autouse=True)
def fixture2(self, request):
pass
@fixture(autouse=True)
def fixture3(request):
pass
'''
module = cst.parse_module(source_code)
transformer = AddRequestArgument()
modified_module = module.visit(transformer)
assert modified_module.code.strip() == expected.strip()
def test_handles_fixture_with_other_decorators(self):
"""Test handling fixture with other decorators."""
source_code = '''
@some_decorator
@fixture(autouse=True)
@another_decorator
def my_fixture():
pass
'''
expected = '''
@some_decorator
@fixture(autouse=True)
@another_decorator
def my_fixture(request):
pass
'''
module = cst.parse_module(source_code)
transformer = AddRequestArgument()
modified_module = module.visit(transformer)
assert modified_module.code.strip() == expected.strip()
def test_preserves_function_body_and_docstring(self):
"""Test that function body and docstring are preserved."""
source_code = '''
@fixture(autouse=True)
def my_fixture():
"""This is a docstring."""
x = 1
y = 2
return x + y
'''
expected = '''
@fixture(autouse=True)
def my_fixture(request):
"""This is a docstring."""
x = 1
y = 2
return x + y
'''
module = cst.parse_module(source_code)
transformer = AddRequestArgument()
modified_module = module.visit(transformer)
assert modified_module.code.strip() == expected.strip()
def test_handles_fixture_with_additional_arguments(self):
"""Test handling fixture with additional keyword arguments."""
source_code = '''
@fixture(autouse=True, scope="session")
def my_fixture():
pass
'''
expected = '''
@fixture(autouse=True, scope="session")
def my_fixture(request):
pass
'''
module = cst.parse_module(source_code)
transformer = AddRequestArgument()
modified_module = module.visit(transformer)
assert modified_module.code.strip() == expected.strip()