diff --git a/code_to_optimize/bubble_sort.py b/code_to_optimize/bubble_sort.py index db7db5f92..9e97f63a0 100644 --- a/code_to_optimize/bubble_sort.py +++ b/code_to_optimize/bubble_sort.py @@ -1,8 +1,10 @@ def sorter(arr): + print("codeflash stdout: Sorting list") for i in range(len(arr)): for j in range(len(arr) - 1): if arr[j] > arr[j + 1]: temp = arr[j] arr[j] = arr[j + 1] arr[j + 1] = temp + print(f"result: {arr}") return arr diff --git a/code_to_optimize/process_and_bubble_sort_codeflash_trace.py b/code_to_optimize/process_and_bubble_sort_codeflash_trace.py new file mode 100644 index 000000000..37c2abab8 --- /dev/null +++ b/code_to_optimize/process_and_bubble_sort_codeflash_trace.py @@ -0,0 +1,28 @@ +from code_to_optimize.bubble_sort import sorter +from codeflash.benchmarking.codeflash_trace import codeflash_trace + +def calculate_pairwise_products(arr): + """ + Calculate the average of all pairwise products in the array. + """ + sum_of_products = 0 + count = 0 + + for i in range(len(arr)): + for j in range(len(arr)): + if i != j: + sum_of_products += arr[i] * arr[j] + count += 1 + + # The average of all pairwise products + return sum_of_products / count if count > 0 else 0 + +@codeflash_trace +def compute_and_sort(arr): + # Compute pairwise sums average + pairwise_average = calculate_pairwise_products(arr) + + # Call sorter function + sorter(arr.copy()) + + return pairwise_average diff --git a/codeflash/discovery/functions_to_optimize.py b/codeflash/discovery/functions_to_optimize.py index fb80541aa..cd0bfc50a 100644 --- a/codeflash/discovery/functions_to_optimize.py +++ b/codeflash/discovery/functions_to_optimize.py @@ -363,23 +363,25 @@ class TopLevelFunctionOrMethodVisitor(ast.NodeVisitor): for decorator in body_node.decorator_list ): self.is_staticmethod = True + print(f"static method found: {self.function_name}") + return + elif self.line_no: + # If we have line number info, check if class has a static method with the same line number + # This way, if we don't have the class name, we can still find the static method + for body_node in node.body: + if ( + isinstance(body_node, ast.FunctionDef) + and body_node.name == self.function_name + and body_node.lineno in {self.line_no, self.line_no + 1} + and any( + isinstance(decorator, ast.Name) and decorator.id == "staticmethod" + for decorator in body_node.decorator_list + ) + ): + self.is_staticmethod = True + self.is_top_level = True + self.class_name = node.name return - # else: - # # search if the class has a staticmethod with the same name and on the same line number - # for body_node in node.body: - # if ( - # isinstance(body_node, ast.FunctionDef) - # and body_node.name == self.function_name - # # and body_node.lineno in {self.line_no, self.line_no + 1} - # and any( - # isinstance(decorator, ast.Name) and decorator.id == "staticmethod" - # for decorator in body_node.decorator_list - # ) - # ): - # self.is_staticmethod = True - # self.is_top_level = True - # self.class_name = node.name - # return return diff --git a/codeflash/discovery/pytest_new_process_discovery.py b/codeflash/discovery/pytest_new_process_discovery.py index 2d8583255..d5a80f501 100644 --- a/codeflash/discovery/pytest_new_process_discovery.py +++ b/codeflash/discovery/pytest_new_process_discovery.py @@ -16,12 +16,6 @@ class PytestCollectionPlugin: collected_tests.extend(session.items) pytest_rootdir = session.config.rootdir - def pytest_collection_modifyitems(config, items): - skip_benchmark = pytest.mark.skip(reason="Skipping benchmark tests") - for item in items: - if "benchmark" in item.fixturenames: - item.add_marker(skip_benchmark) - def parse_pytest_collection_results(pytest_tests: list[Any]) -> list[dict[str, str]]: test_results = [] @@ -40,7 +34,7 @@ if __name__ == "__main__": try: exitcode = pytest.main( - [tests_root, "-p no:logging", "--collect-only", "-m", "not skip"], plugins=[PytestCollectionPlugin()] + [tests_root, "-pno:logging", "--collect-only", "-m", "not skip", "--benchmark-skip"], plugins=[PytestCollectionPlugin()] ) except Exception as e: # noqa: BLE001 print(f"Failed to collect tests: {e!s}") # noqa: T201 diff --git a/tests/test_unit_test_discovery.py b/tests/test_unit_test_discovery.py index 4bf99c049..c05b79e63 100644 --- a/tests/test_unit_test_discovery.py +++ b/tests/test_unit_test_discovery.py @@ -18,11 +18,11 @@ def test_unit_test_discovery_pytest(): ) tests = discover_unit_tests(test_config) assert len(tests) > 0 - # print(tests) + def test_benchmark_test_discovery_pytest(): project_path = Path(__file__).parent.parent.resolve() / "code_to_optimize" - tests_path = project_path / "tests" / "pytest" / "benchmarks" / "test_benchmark_bubble_sort.py" + tests_path = project_path / "tests" / "pytest" / "benchmarks" test_config = TestConfig( tests_root=tests_path, project_root_path=project_path, @@ -30,10 +30,7 @@ def test_benchmark_test_discovery_pytest(): tests_project_rootdir=tests_path.parent, ) tests = discover_unit_tests(test_config) - assert len(tests) > 0 - assert 'bubble_sort.sorter' in tests - benchmark_tests = sum(1 for test in tests['bubble_sort.sorter'] if test.tests_in_file.test_type == TestType.BENCHMARK_TEST) - assert benchmark_tests == 1 + assert len(tests) == 1 # Should not discover benchmark tests def test_unit_test_discovery_unittest():