codeflash-internal/django/aiservice/tests/testgen_postprocessing/test_remove_asserts.py
Kevin Turcios 0444e32f77
fix: CST tree handling and testgen pipeline improvements (#2310)
## Summary
- Fix CST tree corruption issues that caused `'NoneType' object has no attribute 'visit'` errors
- Consolidate testgen postprocessing into a single pipeline with a
tuple-based pattern
- Improve markdown code extraction to prefer filepath-annotated blocks
- Add diagnostic context to optimization failure logs

## Changes
- Handle empty `SimpleStatementLine` and `SimpleStatementSuite` bodies in
`StatementHandler` to prevent a malformed CST
- Add trace_id logging to optimization and import failure paths
- Refactor testgen postprocessing into consolidated pipeline
- Fix code extraction for LLM responses with multiple code blocks

## Test plan
- [x] Added integration tests for full testgen pipeline
- [x] Added tests for markdown extraction with filepath preference
- [x] Existing tests pass

---------

Co-authored-by: codeflash-ai[bot] <148906541+codeflash-ai[bot]@users.noreply.github.com>
2026-01-26 23:57:55 -05:00

890 lines
33 KiB
Python

from libcst import Pass, RemoveFromParent, SimpleStatementLine, SimpleStatementSuite
from libcst import parse_module as parse_module_to_cst
from aiservice.models.functions_to_optimize import FunctionParent, FunctionToOptimize
from testgen.postprocessing.removeassert_transformer import (
RemoveAssertTransformer,
StatementHandler,
remove_asserts_from_test,
)
def test_remove_asserts() -> None:
    """Plain ``assert`` statements that call ``sorter`` become ``codeflash_output`` captures.

    Covers captures on both sides of ``==``, an assignment followed by an assert on the
    assigned name, an assert nested inside a ``with`` block, and hoisting of the target's
    import to just below the leading ``# imports`` comment.
    """
    target = FunctionToOptimize(
        function_name="sorter", file_path="bubble_sort.py", parents=[], starting_line=None, ending_line=None
    )
    source = """# imports
import pytest # used for our unit tests
from code_to_optimize.bubble_sort import sorter
# unit tests
def test_basic_functionality():
    # Simple Unsorted List
    res = sorter([3, 1, 2])
    assert res == [1, 2, 3]
    res = sorter([5, 3, 8, 4, 2])
    assert res == [2, 3, 4, 5, 8]
    # Already Sorted List
    assert sorter([1, 2, 3]) == [1, 2, 3]
    assert sorter([2, 4, 6, 8]) == [2, 4, 6, 8]
    # Reverse Sorted List
    assert sorter([3, 2, 1]) == [1, 2, 3]
    assert sorter([9, 7, 5, 3, 1]) == [1, 3, 5, 7, 9]
    assert [1,2,3,4,5] == sorter([1,2,3,4,5])
    with open("file.txt", "w") as f:
        assert sorter([1, 2, 3, 4, 5]) == [1, 2, 3, 4, 5]
"""
    want = """# imports
from code_to_optimize.bubble_sort import sorter
import pytest # used for our unit tests
# unit tests
def test_basic_functionality():
    # Simple Unsorted List
    codeflash_output = sorter([3, 1, 2]); res = codeflash_output
    codeflash_output = sorter([5, 3, 8, 4, 2]); res = codeflash_output
    # Already Sorted List
    codeflash_output = sorter([1, 2, 3])
    codeflash_output = sorter([2, 4, 6, 8])
    # Reverse Sorted List
    codeflash_output = sorter([3, 2, 1])
    codeflash_output = sorter([9, 7, 5, 3, 1])
    codeflash_output = sorter([1,2,3,4,5])
    with open("file.txt", "w") as f:
        codeflash_output = sorter([1, 2, 3, 4, 5])
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
"""
    transformed = remove_asserts_from_test(
        module=parse_module_to_cst(source),
        function_to_optimize=target,
        helper_function_names=[],
        module_path="code_to_optimize.bubble_sort",
    )
    assert transformed.code == want
def test_remove_asserts_class_functions() -> None:
    """Calls through an instance (``bubble_sort.sorter(...)``) are captured for a method target.

    The target has a ``BubbleSortClass`` parent, and the expected output imports that class
    at the top of the module.
    """
    target = FunctionToOptimize(
        function_name="sorter",
        file_path="bubble_sort.py",
        parents=[FunctionParent(name="BubbleSortClass", type="ClassDef")],
        starting_line=None,
        ending_line=None,
    )
    source = """# imports
import pytest # used for our unit tests
from code_to_optimize.bubble_sort import BubbleSortClass
bubble_sort = BubbleSortClass()
# unit tests
def test_basic_functionality():
    # Simple Unsorted List
    res = bubble_sort.sorter([3, 1, 2])
    assert res == [1, 2, 3]
    res = bubble_sort.sorter([5, 3, 8, 4, 2])
    assert res == [2, 3, 4, 5, 8]
    # Already Sorted List
    assert bubble_sort.sorter([1, 2, 3]) == [1, 2, 3]
    assert bubble_sort.sorter([2, 4, 6, 8]) == [2, 4, 6, 8]"""
    want = """# imports
from code_to_optimize.bubble_sort import BubbleSortClass
import pytest # used for our unit tests
bubble_sort = BubbleSortClass()
# unit tests
def test_basic_functionality():
    # Simple Unsorted List
    codeflash_output = bubble_sort.sorter([3, 1, 2]); res = codeflash_output
    codeflash_output = bubble_sort.sorter([5, 3, 8, 4, 2]); res = codeflash_output
    # Already Sorted List
    codeflash_output = bubble_sort.sorter([1, 2, 3])
    codeflash_output = bubble_sort.sorter([2, 4, 6, 8])
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code."""
    transformed = remove_asserts_from_test(
        module=parse_module_to_cst(source),
        function_to_optimize=target,
        helper_function_names=[],
        module_path="code_to_optimize.bubble_sort",
    )
    assert transformed.code == want
def test_remove_asserts_unittest() -> None:
    """``self.assertEqual`` assertions are removed inside a ``unittest.TestCase`` method.

    A call to the target, whether it appears as an assertion argument or on the right of an
    assignment, is captured into ``codeflash_output``. Assertions that only compare
    already-bound names disappear.
    """
    target = FunctionToOptimize(
        function_name="sorter", file_path="bubble_sort.py", parents=[], starting_line=None, ending_line=None
    )
    source = """import unittest
from code_to_optimize.bubble_sort import sorter
class TestPigLatin(unittest.TestCase):
    def test_sort(self):
        input = [5, 4, 3, 2, 1, 0]
        self.assertEqual(sorter(input), [0, 1, 2, 3, 4, 5])
        input = [5.0, 4.0, 3.0, 2.0, 1.0, 0.0]
        output = sorter(input)
        self.assertEqual(output, [0.0, 1.0, 2.0, 3.0, 4.0, 5.0])
        input = list(reversed(range(5000)))
        output = sorter(input)
        self.assertEqual(output, list(range(5000)))
        self.assertEqual(output, sorter(input))
        with open("file.txt", "w") as f:
            self.assertEqual(sorter(input), list(range(5000)))
"""
    want = """from code_to_optimize.bubble_sort import sorter
import unittest
class TestPigLatin(unittest.TestCase):
    def test_sort(self):
        input = [5, 4, 3, 2, 1, 0]
        codeflash_output = sorter(input)
        input = [5.0, 4.0, 3.0, 2.0, 1.0, 0.0]
        codeflash_output = sorter(input); output = codeflash_output
        input = list(reversed(range(5000)))
        codeflash_output = sorter(input); output = codeflash_output
        codeflash_output = sorter(input)
        with open("file.txt", "w") as f:
            codeflash_output = sorter(input)
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
"""
    transformed = remove_asserts_from_test(
        module=parse_module_to_cst(source),
        function_to_optimize=target,
        helper_function_names=[],
        module_path="code_to_optimize.bubble_sort",
    )
    assert transformed.code == want
def test_remove_single_arg_asserts_unittest() -> None:
    """Single-argument unittest assertions (``assertTrue``/``False``/``IsNone``/``IsNotNone``).

    When the lone argument calls the target, it is captured. ``self.assertTrue(output)`` on an
    already-bound name is dropped.
    """
    target = FunctionToOptimize(
        function_name="sorter", file_path="bubble_sort.py", parents=[], starting_line=None, ending_line=None
    )
    source = """import unittest
from code_to_optimize.bubble_sort import sorter
class TestBubbleSort(unittest.TestCase):
    def test_sort(self):
        input = [5, 4, 3, 2, 1]
        output = sorter(input)
        self.assertTrue(output)
        self.assertFalse(sorter([]))
        input = [3, 2, 1]
        self.assertTrue(sorter(input))
        self.assertIsNone(sorter([]))
        self.assertIsNotNone(sorter([1]))
"""
    want = """from code_to_optimize.bubble_sort import sorter
import unittest
class TestBubbleSort(unittest.TestCase):
    def test_sort(self):
        input = [5, 4, 3, 2, 1]
        codeflash_output = sorter(input); output = codeflash_output
        codeflash_output = sorter([])
        input = [3, 2, 1]
        codeflash_output = sorter(input)
        codeflash_output = sorter([])
        codeflash_output = sorter([1])
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
"""
    transformed = remove_asserts_from_test(
        module=parse_module_to_cst(source),
        function_to_optimize=target,
        helper_function_names=[],
        module_path="code_to_optimize.bubble_sort",
    )
    assert transformed.code == want
def test_remove_irrelevant_assert() -> None:
    """Asserts that never call the target are dropped, not captured.

    Only ``hdbscan.relative_validity_()`` is captured. ``assert result == 5`` and
    ``assert 1 == True`` vanish. The parent class ``HDBSCAN`` is imported from ``module_path``.
    """
    target = FunctionToOptimize(
        function_name="relative_validity_",
        file_path="/tmp/path",
        parents=[FunctionParent(name="HDBSCAN", type="ClassDef")],
    )
    source = """def test_relative_validity_no_tree():
    hdbscan = HDBSCAN()
    assert hdbscan.relative_validity_() == 0.5
    result = 5
    assert result == 5
    assert 1 == True"""
    want = """from code_to_optimize.bubble_sort import HDBSCAN
def test_relative_validity_no_tree():
    hdbscan = HDBSCAN()
    codeflash_output = hdbscan.relative_validity_()
    result = 5
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code."""
    transformed = remove_asserts_from_test(
        module=parse_module_to_cst(source),
        function_to_optimize=target,
        helper_function_names=[],
        module_path="code_to_optimize.bubble_sort",
    )
    assert transformed.code == want
def test_remove_asserts_one_line_functions() -> None:
    """A test whose entire body is a one-line, semicolon-separated suite is rewritten in place.

    ``y = sorter(...)`` becomes a capture followed by ``y = codeflash_output``. The assert on
    ``y`` is dropped, and the surrounding ``x``/``z`` statements are kept.
    """
    target = FunctionToOptimize(
        function_name="sorter", file_path="bubble_sort.py", parents=[], starting_line=None, ending_line=None
    )
    source = """def test_short_function(): x = 5; y = sorter([3,2,1]); assert y == [1,2,3]; z = 10"""
    want = """from code_to_optimize.bubble_sort import sorter
def test_short_function(): x = 5; codeflash_output = sorter([3,2,1]); y = codeflash_output; z = 10
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code."""
    transformed = remove_asserts_from_test(
        module=parse_module_to_cst(source),
        function_to_optimize=target,
        helper_function_names=[],
        module_path="code_to_optimize.bubble_sort",
    )
    assert transformed.code == want
def test_remove_asserts_one_line_statements() -> None:
    """Semicolon-joined statement lines inside an indented body are each rewritten."""
    target = FunctionToOptimize(
        function_name="sorter", file_path="bubble_sort.py", parents=[], starting_line=None, ending_line=None
    )
    source = """def test_short_function():
    x = 5; y = sorter([3,2,1]); assert y == [1,2,3]; z = 10
    a = sorter([5,4,3]); assert a == [3,4,5]"""
    want = """from code_to_optimize.bubble_sort import sorter
def test_short_function():
    x = 5; codeflash_output = sorter([3,2,1]); y = codeflash_output; z = 10
    codeflash_output = sorter([5,4,3]); a = codeflash_output
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code."""
    transformed = remove_asserts_from_test(
        module=parse_module_to_cst(source),
        function_to_optimize=target,
        helper_function_names=[],
        module_path="code_to_optimize.bubble_sort",
    )
    assert transformed.code == want
def test_remove_asserts_mixed_statement_types() -> None:
    """One body mixing a semicolon-joined line with regular statement lines.

    NOTE(review): the second half of the fixture already assigns to ``codeflash_output`` and
    then asserts on ``result``, which the fixture never defines. The expected output keeps
    the assignment and drops that assert. Confirm the fixture was meant to look like this.
    """
    target = FunctionToOptimize(
        function_name="sorter", file_path="bubble_sort.py", parents=[], starting_line=None, ending_line=None
    )
    source = """def test_mixed_statements():
    # First part has one-line statements
    x = 5; y = sorter([3,2,1]); assert y == [1,2,3]
    # Second part has regular statements
    codeflash_output = sorter([5,4,3])
    assert result == [3,4,5]"""
    want = """from code_to_optimize.bubble_sort import sorter
def test_mixed_statements():
    # First part has one-line statements
    x = 5; codeflash_output = sorter([3,2,1]); y = codeflash_output
    # Second part has regular statements
    codeflash_output = sorter([5,4,3])
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code."""
    transformed = remove_asserts_from_test(
        module=parse_module_to_cst(source),
        function_to_optimize=target,
        helper_function_names=[],
        module_path="code_to_optimize.bubble_sort",
    )
    assert transformed.code == want
def test_function_with_pass() -> None:
    """A nested class made of ``pass`` bodies and one-line methods passes through unchanged.

    Only the trailing assert on ``is_coroutine_callable(...)`` is rewritten into a capture.
    """
    target = FunctionToOptimize(
        function_name="is_coroutine_callable",
        file_path="some_file.py",
        parents=[],
        starting_line=None,
        ending_line=None,
    )
    source = """def test_large_callable_object():
    class LargeClass:
        async def __call__(self):
            pass
        def method1(self): pass
        def method2(self): pass
    # Add more methods if needed to simulate a large class
    assert is_coroutine_callable(LargeClass()) == True"""
    want = """from some_file import is_coroutine_callable
def test_large_callable_object():
    class LargeClass:
        async def __call__(self):
            pass
        def method1(self): pass
        def method2(self): pass
    # Add more methods if needed to simulate a large class
    codeflash_output = is_coroutine_callable(LargeClass())
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code."""
    transformed = remove_asserts_from_test(
        module=parse_module_to_cst(source),
        function_to_optimize=target,
        helper_function_names=[],
        module_path="some_file",
    )
    assert transformed.code == want
def test_assert_not_with_primitive() -> None:
    """``assert not f(...)`` keeps the ``not`` inside the captured expression."""
    target = FunctionToOptimize(
        function_name="is_coroutine_callable",
        file_path="some_file.py",
        parents=[],
        starting_line=None,
        ending_line=None,
    )
    source = """def test_primitive_types():
    assert not is_coroutine_callable(42)
    assert not is_coroutine_callable("string")
    assert not is_coroutine_callable([1, 2, 3])"""
    want = """from some_file import is_coroutine_callable
def test_primitive_types():
    codeflash_output = not is_coroutine_callable(42)
    codeflash_output = not is_coroutine_callable("string")
    codeflash_output = not is_coroutine_callable([1, 2, 3])
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code."""
    transformed = remove_asserts_from_test(
        module=parse_module_to_cst(source),
        function_to_optimize=target,
        helper_function_names=[],
        module_path="some_file",
    )
    assert transformed.code == want
def test_assert_with_primitive() -> None:
    """Bare truthiness asserts ``assert f(...)`` become plain captures of ``f(...)``."""
    target = FunctionToOptimize(
        function_name="is_coroutine_callable",
        file_path="some_file.py",
        parents=[],
        starting_line=None,
        ending_line=None,
    )
    source = """def test_primitive_types():
    assert is_coroutine_callable(42)
    assert is_coroutine_callable("string")
    assert is_coroutine_callable([1, 2, 3])"""
    want = """from some_file import is_coroutine_callable
def test_primitive_types():
    codeflash_output = is_coroutine_callable(42)
    codeflash_output = is_coroutine_callable("string")
    codeflash_output = is_coroutine_callable([1, 2, 3])
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code."""
    transformed = remove_asserts_from_test(
        module=parse_module_to_cst(source),
        function_to_optimize=target,
        helper_function_names=[],
        module_path="some_file",
    )
    assert transformed.code == want
def test_function_for_list_comprehension() -> None:
    """When the target is called inside a list comprehension, the whole comprehension is captured."""
    target = FunctionToOptimize(
        function_name="some_function", file_path="some_file.py", parents=[], starting_line=None, ending_line=None
    )
    source = """def test_list_comprehension():
    assert [some_function(x) for x in range(5)] == [0, 1, 2, 3, 4]"""
    want = """from some_file import some_function
def test_list_comprehension():
    codeflash_output = [some_function(x) for x in range(5)]
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code."""
    transformed = remove_asserts_from_test(
        module=parse_module_to_cst(source),
        function_to_optimize=target,
        helper_function_names=[],
        module_path="some_file",
    )
    assert transformed.code == want
def test_singleline_function_for_list_comprehension() -> None:
    """Two comprehension asserts in a one-line suite both become comprehension captures."""
    # Both comprehensions pass their own loop variable to ``some_function``, so the fixture
    # is valid Python. It previously read ``some_function(y) for x in ...``, with ``y`` undefined.
    original_test = """def test_list_comprehension(): assert [some_function(x) for x in range(5)] == [0, 1, 2, 3, 4]; assert [some_function(x) for x in range(4)] == [1,3,4,5]"""
    expected = """from some_file import some_function
def test_list_comprehension(): codeflash_output = [some_function(x) for x in range(5)]; codeflash_output = [some_function(x) for x in range(4)]
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code."""
    function_to_optimize = FunctionToOptimize(
        function_name="some_function", file_path="some_file.py", parents=[], starting_line=None, ending_line=None
    )
    result = remove_asserts_from_test(
        module=parse_module_to_cst(original_test),
        function_to_optimize=function_to_optimize,
        helper_function_names=[],
        module_path="some_file",
    ).code
    assert result == expected
def test_unittest_function_for_list_comprehension() -> None:
    """When ``self.assertEqual``'s first argument is a comprehension calling the target, it is captured."""
    target = FunctionToOptimize(
        function_name="some_function", file_path="some_file.py", parents=[], starting_line=None, ending_line=None
    )
    source = """def test_unittest_style():
    self.assertEqual([some_function(x) for x in range(5)], [0, 1, 2, 3, 4])"""
    want = """from some_file import some_function
def test_unittest_style():
    codeflash_output = [some_function(x) for x in range(5)]
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code."""
    transformed = remove_asserts_from_test(
        module=parse_module_to_cst(source),
        function_to_optimize=target,
        helper_function_names=[],
        module_path="some_file",
    )
    assert transformed.code == want
def test_remove_asserts_import_star() -> None:
    """Star imports stay where they are; the explicit import of the target moves to the first line."""
    target = FunctionToOptimize(
        function_name="sorter", file_path="bubble_sort.py", parents=[], starting_line=None, ending_line=None
    )
    source = """from x import *
from code_to_optimize.bubble_sort import sorter
from code_to_optimize.bubble_sort import *
from y import *
def test_basic_functionality():
    assert sorter([3, 1, 2]) == [1, 2, 3]
"""
    want = """from code_to_optimize.bubble_sort import sorter
from x import *
from code_to_optimize.bubble_sort import *
from y import *
def test_basic_functionality():
    codeflash_output = sorter([3, 1, 2])
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
"""
    transformed = remove_asserts_from_test(
        module=parse_module_to_cst(source),
        function_to_optimize=target,
        helper_function_names=[],
        module_path="code_to_optimize.bubble_sort",
    )
    assert transformed.code == want
def test_post_processing_iteration() -> None:
    """An assert that is the only statement of a ``for`` body is replaced by ``pass``.

    The fixture also defines ``extract_input_variables`` locally. The expected output still
    gains a ``from some_file import extract_input_variables`` line at the top, and the
    locally defined function body is left unchanged.
    """
    target = FunctionToOptimize(
        function_name="extract_input_variables",
        file_path="some_file.py",
        parents=[],
        starting_line=None,
        ending_line=None,
    )
    test_source = '''
import pytest # used for our unit tests
def extract_input_variables(nodes: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Extracts input variables from the template and adds them to the input_variables field."""
    for node in nodes:
        with contextlib.suppress(Exception):
            if "input_variables" in node["data"]["node"]["template"]:
                if node["data"]["node"]["template"]["_type"] == "prompt":
                    variables = re.findall(
                        r"\\{(.*?)\\}",
                        node["data"]["node"]["template"]["template"]["value"],
                    )
                elif node["data"]["node"]["template"]["_type"] == "few_shot":
                    variables = re.findall(
                        r"\\{(.*?)\\}",
                        node["data"]["node"]["template"]["prefix"]["value"]
                        + node["data"]["node"]["template"]["suffix"]["value"],
                    )
                else:
                    variables = []
                node["data"]["node"]["template"]["input_variables"]["value"] = variables
    return nodes
def test_large_list_of_nodes():
    # A large list containing thousands of nodes to assess performance
    nodes = [{"data": {"node": {"template": {"_type": "prompt", "template": {"value": "Hello, {name}!"}, "input_variables": {"value": []}}}}}] * 1000
    result = extract_input_variables(nodes)
    for res in result:
        assert res["data"]["node"]["template"]["input_variables"]["value"] == ["name"]'''
    expected = '''
from some_file import extract_input_variables
import pytest # used for our unit tests
def extract_input_variables(nodes: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Extracts input variables from the template and adds them to the input_variables field."""
    for node in nodes:
        with contextlib.suppress(Exception):
            if "input_variables" in node["data"]["node"]["template"]:
                if node["data"]["node"]["template"]["_type"] == "prompt":
                    variables = re.findall(
                        r"\\{(.*?)\\}",
                        node["data"]["node"]["template"]["template"]["value"],
                    )
                elif node["data"]["node"]["template"]["_type"] == "few_shot":
                    variables = re.findall(
                        r"\\{(.*?)\\}",
                        node["data"]["node"]["template"]["prefix"]["value"]
                        + node["data"]["node"]["template"]["suffix"]["value"],
                    )
                else:
                    variables = []
                node["data"]["node"]["template"]["input_variables"]["value"] = variables
    return nodes
def test_large_list_of_nodes():
    # A large list containing thousands of nodes to assess performance
    nodes = [{"data": {"node": {"template": {"_type": "prompt", "template": {"value": "Hello, {name}!"}, "input_variables": {"value": []}}}}}] * 1000
    codeflash_output = extract_input_variables(nodes); result = codeflash_output
    for res in result:
        pass
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.'''
    result = remove_asserts_from_test(
        module=parse_module_to_cst(test_source),
        function_to_optimize=target,
        helper_function_names=[],
        module_path="some_file",
    ).code
    assert result == expected
def test_remove_asserts_numpy_pandas() -> None:
    """numpy/pandas ``testing.assert_*`` helpers are rewritten the same way as plain asserts.

    The same unittest-style fixture is checked once for each way of importing the library
    (aliased and unaliased numpy/pandas). The variants differ only in the import line at the
    top of the fixture and in the fully qualified assertion helper used in the test body.
    """
    # (import line placed at the top of the fixture, assertion helper called in the body)
    variants = (
        ("import numpy as np", "np.testing.assert_equal"),
        ("import numpy", "numpy.testing.assert_equal"),
        ("import pandas as pd", "pd.testing.assert_series_equal"),
        ("import pandas", "pandas.testing.assert_series_equal"),
    )
    for import_line, assert_helper in variants:
        # NOTE: every variant builds arrays with `np.array(...)` whatever the import line says.
        # The fixture is only parsed and transformed here, never executed. Those assert lines
        # are removed, so `np` does not appear in the expected output.
        original_test = f"""{import_line}
import unittest
from code_to_optimize.bubble_sort import sorter
class TestPigLatin(unittest.TestCase):
    def test_sort(self):
        input = [5, 4, 3, 2, 1, 0]
        {assert_helper}(sorter(input), [0, 1, 2, 3, 4, 5])
        input = [5.0, 4.0, 3.0, 2.0, 1.0, 0.0]
        output = sorter(input)
        {assert_helper}(output, [0.0, 1.0, 2.0, 3.0, 4.0, 5.0])
        input = list(reversed(range(5000)))
        output = sorter(input)
        {assert_helper}(output, np.array(list(range(5000))))
        {assert_helper}(output, sorter(input))
        with open("file.txt", "w") as f:
            {assert_helper}(sorter(input), np.array(list(range(5000))))
"""
        expected = f"""from code_to_optimize.bubble_sort import sorter
{import_line}
import unittest
class TestPigLatin(unittest.TestCase):
    def test_sort(self):
        input = [5, 4, 3, 2, 1, 0]
        codeflash_output = sorter(input)
        input = [5.0, 4.0, 3.0, 2.0, 1.0, 0.0]
        codeflash_output = sorter(input); output = codeflash_output
        input = list(reversed(range(5000)))
        codeflash_output = sorter(input); output = codeflash_output
        codeflash_output = sorter(input)
        with open("file.txt", "w") as f:
            codeflash_output = sorter(input)
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
"""
        # Built fresh for every variant so no state can carry over between iterations.
        function_to_optimize = FunctionToOptimize(
            function_name="sorter", file_path="bubble_sort.py", parents=[], starting_line=None, ending_line=None
        )
        result = remove_asserts_from_test(
            module=parse_module_to_cst(original_test),
            function_to_optimize=function_to_optimize,
            helper_function_names=[],
            module_path="code_to_optimize.bubble_sort",
        ).code
        # The message names the failing variant, since all four share one test function.
        assert result == expected, f"assert removal mismatch for {import_line!r} / {assert_helper}"
def test_remove_asserts_empty_simple_statement_line_removed() -> None:
    """A statement line whose every statement is removed disappears entirely.

    Regression test: a ``SimpleStatementLine`` left with an empty body produced a malformed
    CST, which later failed with ``'NoneType' object has no attribute 'visit'``. Here all
    three asserts on the middle line are unrelated to the target, so the whole line must go.
    Only ``x = 5`` and ``y = 10`` survive.
    """
    target = FunctionToOptimize(
        function_name="some_function", file_path="some_file.py", parents=[], starting_line=None, ending_line=None
    )
    source = """def test_multiple_irrelevant_asserts():
    x = 5
    assert 1 == 1; assert True; assert 2 > 1
    y = 10
"""
    want = """from some_file import some_function
def test_multiple_irrelevant_asserts():
    x = 5
    y = 10
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
"""
    transformed = remove_asserts_from_test(
        module=parse_module_to_cst(source),
        function_to_optimize=target,
        helper_function_names=[],
        module_path="some_file",
    )
    assert transformed.code == want
def test_remove_asserts_one_line_function_all_removed() -> None:
    """A one-line test whose only statement is an unrelated assert is deleted outright.

    Instead of leaving a ``def test_irrelevant(): pass`` stub, the expected output keeps only
    the added import and the trailing ``codeflash_output`` comment.
    """
    target = FunctionToOptimize(
        function_name="some_function", file_path="some_file.py", parents=[], starting_line=None, ending_line=None
    )
    source = """def test_irrelevant(): assert 1 == 1"""
    want = """from some_file import some_function
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code."""
    transformed = remove_asserts_from_test(
        module=parse_module_to_cst(source),
        function_to_optimize=target,
        helper_function_names=[],
        module_path="some_file",
    )
    assert transformed.code == want
def test_handle_statement_with_empty_body() -> None:
    """``StatementHandler.handle_statement`` copes with nodes whose body is already empty.

    Regression test: empty bodies (left by earlier transforms or present in malformed input)
    used to be returned unchanged. That corrupted the CST and caused
    ``'NoneType' object has no attribute 'visit'`` in later transforms. Expected results:

    - an empty ``SimpleStatementLine`` yields ``RemoveFromParent()``;
    - an empty ``SimpleStatementSuite`` yields a suite holding exactly one ``Pass``.
    """
    handler = StatementHandler(
        RemoveAssertTransformer(
            FunctionToOptimize(
                function_name="some_function",
                file_path="some_file.py",
                parents=[],
                starting_line=None,
                ending_line=None,
            ),
            [],
        )
    )

    # Empty statement line: it must be dropped from its parent.
    line_outcome = handler.handle_statement(SimpleStatementLine(body=[]))
    assert line_outcome == RemoveFromParent()

    # Empty suite: it must come back with a single `pass` so it is not left empty.
    suite_outcome = handler.handle_statement(SimpleStatementSuite(body=[]))
    assert isinstance(suite_outcome, SimpleStatementSuite)
    assert len(suite_outcome.body) == 1
    assert isinstance(suite_outcome.body[0], Pass)