tests pass
This commit is contained in:
parent
54fe71f336
commit
5fd112a2f9
5 changed files with 52 additions and 29 deletions
|
|
@ -1,8 +1,10 @@
|
|||
def sorter(arr):
|
||||
print("codeflash stdout: Sorting list")
|
||||
for i in range(len(arr)):
|
||||
for j in range(len(arr) - 1):
|
||||
if arr[j] > arr[j + 1]:
|
||||
temp = arr[j]
|
||||
arr[j] = arr[j + 1]
|
||||
arr[j + 1] = temp
|
||||
print(f"result: {arr}")
|
||||
return arr
|
||||
|
|
|
|||
28
code_to_optimize/process_and_bubble_sort_codeflash_trace.py
Normal file
28
code_to_optimize/process_and_bubble_sort_codeflash_trace.py
Normal file
|
|
@ -0,0 +1,28 @@
|
|||
from code_to_optimize.bubble_sort import sorter
|
||||
from codeflash.benchmarking.codeflash_trace import codeflash_trace
|
||||
|
||||
def calculate_pairwise_products(arr):
|
||||
"""
|
||||
Calculate the average of all pairwise products in the array.
|
||||
"""
|
||||
sum_of_products = 0
|
||||
count = 0
|
||||
|
||||
for i in range(len(arr)):
|
||||
for j in range(len(arr)):
|
||||
if i != j:
|
||||
sum_of_products += arr[i] * arr[j]
|
||||
count += 1
|
||||
|
||||
# The average of all pairwise products
|
||||
return sum_of_products / count if count > 0 else 0
|
||||
|
||||
@codeflash_trace
|
||||
def compute_and_sort(arr):
|
||||
# Compute pairwise sums average
|
||||
pairwise_average = calculate_pairwise_products(arr)
|
||||
|
||||
# Call sorter function
|
||||
sorter(arr.copy())
|
||||
|
||||
return pairwise_average
|
||||
|
|
@ -363,23 +363,25 @@ class TopLevelFunctionOrMethodVisitor(ast.NodeVisitor):
|
|||
for decorator in body_node.decorator_list
|
||||
):
|
||||
self.is_staticmethod = True
|
||||
print(f"static method found: {self.function_name}")
|
||||
return
|
||||
elif self.line_no:
|
||||
# If we have line number info, check if class has a static method with the same line number
|
||||
# This way, if we don't have the class name, we can still find the static method
|
||||
for body_node in node.body:
|
||||
if (
|
||||
isinstance(body_node, ast.FunctionDef)
|
||||
and body_node.name == self.function_name
|
||||
and body_node.lineno in {self.line_no, self.line_no + 1}
|
||||
and any(
|
||||
isinstance(decorator, ast.Name) and decorator.id == "staticmethod"
|
||||
for decorator in body_node.decorator_list
|
||||
)
|
||||
):
|
||||
self.is_staticmethod = True
|
||||
self.is_top_level = True
|
||||
self.class_name = node.name
|
||||
return
|
||||
# else:
|
||||
# # search if the class has a staticmethod with the same name and on the same line number
|
||||
# for body_node in node.body:
|
||||
# if (
|
||||
# isinstance(body_node, ast.FunctionDef)
|
||||
# and body_node.name == self.function_name
|
||||
# # and body_node.lineno in {self.line_no, self.line_no + 1}
|
||||
# and any(
|
||||
# isinstance(decorator, ast.Name) and decorator.id == "staticmethod"
|
||||
# for decorator in body_node.decorator_list
|
||||
# )
|
||||
# ):
|
||||
# self.is_staticmethod = True
|
||||
# self.is_top_level = True
|
||||
# self.class_name = node.name
|
||||
# return
|
||||
|
||||
return
|
||||
|
||||
|
|
|
|||
|
|
@ -16,12 +16,6 @@ class PytestCollectionPlugin:
|
|||
collected_tests.extend(session.items)
|
||||
pytest_rootdir = session.config.rootdir
|
||||
|
||||
def pytest_collection_modifyitems(config, items):
|
||||
skip_benchmark = pytest.mark.skip(reason="Skipping benchmark tests")
|
||||
for item in items:
|
||||
if "benchmark" in item.fixturenames:
|
||||
item.add_marker(skip_benchmark)
|
||||
|
||||
|
||||
def parse_pytest_collection_results(pytest_tests: list[Any]) -> list[dict[str, str]]:
|
||||
test_results = []
|
||||
|
|
@ -40,7 +34,7 @@ if __name__ == "__main__":
|
|||
|
||||
try:
|
||||
exitcode = pytest.main(
|
||||
[tests_root, "-p no:logging", "--collect-only", "-m", "not skip"], plugins=[PytestCollectionPlugin()]
|
||||
[tests_root, "-pno:logging", "--collect-only", "-m", "not skip", "--benchmark-skip"], plugins=[PytestCollectionPlugin()]
|
||||
)
|
||||
except Exception as e: # noqa: BLE001
|
||||
print(f"Failed to collect tests: {e!s}") # noqa: T201
|
||||
|
|
|
|||
|
|
@ -18,11 +18,11 @@ def test_unit_test_discovery_pytest():
|
|||
)
|
||||
tests = discover_unit_tests(test_config)
|
||||
assert len(tests) > 0
|
||||
# print(tests)
|
||||
|
||||
|
||||
def test_benchmark_test_discovery_pytest():
|
||||
project_path = Path(__file__).parent.parent.resolve() / "code_to_optimize"
|
||||
tests_path = project_path / "tests" / "pytest" / "benchmarks" / "test_benchmark_bubble_sort.py"
|
||||
tests_path = project_path / "tests" / "pytest" / "benchmarks"
|
||||
test_config = TestConfig(
|
||||
tests_root=tests_path,
|
||||
project_root_path=project_path,
|
||||
|
|
@ -30,10 +30,7 @@ def test_benchmark_test_discovery_pytest():
|
|||
tests_project_rootdir=tests_path.parent,
|
||||
)
|
||||
tests = discover_unit_tests(test_config)
|
||||
assert len(tests) > 0
|
||||
assert 'bubble_sort.sorter' in tests
|
||||
benchmark_tests = sum(1 for test in tests['bubble_sort.sorter'] if test.tests_in_file.test_type == TestType.BENCHMARK_TEST)
|
||||
assert benchmark_tests == 1
|
||||
assert len(tests) == 1 # Should not discover benchmark tests
|
||||
|
||||
|
||||
def test_unit_test_discovery_unittest():
|
||||
|
|
|
|||
Loading…
Reference in a new issue