tests pass

This commit is contained in:
Alvin Ryanputra 2025-03-19 15:59:13 -07:00
parent 54fe71f336
commit 5fd112a2f9
5 changed files with 52 additions and 29 deletions

View file

@ -1,8 +1,10 @@
def sorter(arr):
print("codeflash stdout: Sorting list")
for i in range(len(arr)):
for j in range(len(arr) - 1):
if arr[j] > arr[j + 1]:
temp = arr[j]
arr[j] = arr[j + 1]
arr[j + 1] = temp
print(f"result: {arr}")
return arr

View file

@ -0,0 +1,28 @@
from code_to_optimize.bubble_sort import sorter
from codeflash.benchmarking.codeflash_trace import codeflash_trace
def calculate_pairwise_products(arr):
"""
Calculate the average of all pairwise products in the array.
"""
sum_of_products = 0
count = 0
for i in range(len(arr)):
for j in range(len(arr)):
if i != j:
sum_of_products += arr[i] * arr[j]
count += 1
# The average of all pairwise products
return sum_of_products / count if count > 0 else 0
@codeflash_trace
def compute_and_sort(arr):
# Compute pairwise sums average
pairwise_average = calculate_pairwise_products(arr)
# Call sorter function
sorter(arr.copy())
return pairwise_average

View file

@ -363,23 +363,25 @@ class TopLevelFunctionOrMethodVisitor(ast.NodeVisitor):
for decorator in body_node.decorator_list
):
self.is_staticmethod = True
print(f"static method found: {self.function_name}")
return
elif self.line_no:
# If we have line number info, check if class has a static method with the same line number
# This way, if we don't have the class name, we can still find the static method
for body_node in node.body:
if (
isinstance(body_node, ast.FunctionDef)
and body_node.name == self.function_name
and body_node.lineno in {self.line_no, self.line_no + 1}
and any(
isinstance(decorator, ast.Name) and decorator.id == "staticmethod"
for decorator in body_node.decorator_list
)
):
self.is_staticmethod = True
self.is_top_level = True
self.class_name = node.name
return
# else:
# # search if the class has a staticmethod with the same name and on the same line number
# for body_node in node.body:
# if (
# isinstance(body_node, ast.FunctionDef)
# and body_node.name == self.function_name
# # and body_node.lineno in {self.line_no, self.line_no + 1}
# and any(
# isinstance(decorator, ast.Name) and decorator.id == "staticmethod"
# for decorator in body_node.decorator_list
# )
# ):
# self.is_staticmethod = True
# self.is_top_level = True
# self.class_name = node.name
# return
return

View file

@ -16,12 +16,6 @@ class PytestCollectionPlugin:
collected_tests.extend(session.items)
pytest_rootdir = session.config.rootdir
def pytest_collection_modifyitems(config, items):
skip_benchmark = pytest.mark.skip(reason="Skipping benchmark tests")
for item in items:
if "benchmark" in item.fixturenames:
item.add_marker(skip_benchmark)
def parse_pytest_collection_results(pytest_tests: list[Any]) -> list[dict[str, str]]:
test_results = []
@ -40,7 +34,7 @@ if __name__ == "__main__":
try:
exitcode = pytest.main(
[tests_root, "-p no:logging", "--collect-only", "-m", "not skip"], plugins=[PytestCollectionPlugin()]
[tests_root, "-pno:logging", "--collect-only", "-m", "not skip", "--benchmark-skip"], plugins=[PytestCollectionPlugin()]
)
except Exception as e: # noqa: BLE001
print(f"Failed to collect tests: {e!s}") # noqa: T201

View file

@ -18,11 +18,11 @@ def test_unit_test_discovery_pytest():
)
tests = discover_unit_tests(test_config)
assert len(tests) > 0
# print(tests)
def test_benchmark_test_discovery_pytest():
project_path = Path(__file__).parent.parent.resolve() / "code_to_optimize"
tests_path = project_path / "tests" / "pytest" / "benchmarks" / "test_benchmark_bubble_sort.py"
tests_path = project_path / "tests" / "pytest" / "benchmarks"
test_config = TestConfig(
tests_root=tests_path,
project_root_path=project_path,
@ -30,10 +30,7 @@ def test_benchmark_test_discovery_pytest():
tests_project_rootdir=tests_path.parent,
)
tests = discover_unit_tests(test_config)
assert len(tests) > 0
assert 'bubble_sort.sorter' in tests
benchmark_tests = sum(1 for test in tests['bubble_sort.sorter'] if test.tests_in_file.test_type == TestType.BENCHMARK_TEST)
assert benchmark_tests == 1
assert len(tests) == 1 # Should not discover benchmark tests
def test_unit_test_discovery_unittest():