diff --git a/codeflash/code_utils/instrument_existing_tests.py b/codeflash/code_utils/instrument_existing_tests.py index 466d8f70c..006ed63cf 100644 --- a/codeflash/code_utils/instrument_existing_tests.py +++ b/codeflash/code_utils/instrument_existing_tests.py @@ -715,7 +715,12 @@ def inject_profiling_into_existing_test( from codeflash.languages.javascript.instrument import inject_profiling_into_existing_js_test return inject_profiling_into_existing_js_test( - test_string=test_string, call_positions=call_positions, function_to_optimize=function_to_optimize, tests_project_root=tests_project_root, mode= mode.value, test_path=test_path + test_string=test_string, + call_positions=call_positions, + function_to_optimize=function_to_optimize, + tests_project_root=tests_project_root, + mode=mode.value, + test_path=test_path, ) if is_java(): @@ -725,11 +730,14 @@ def inject_profiling_into_existing_test( if function_to_optimize.is_async: return inject_async_profiling_into_existing_test( - test_string=test_string, call_positions=call_positions, function_to_optimize=function_to_optimize, tests_project_root=tests_project_root, mode=mode.value, test_path=test_path + test_string=test_string, + call_positions=call_positions, + function_to_optimize=function_to_optimize, + tests_project_root=tests_project_root, + mode=mode.value, + test_path=test_path, ) - - used_frameworks = detect_frameworks_from_code(test_string) try: tree = ast.parse(test_string) diff --git a/codeflash/languages/base.py b/codeflash/languages/base.py index 224ee6cdb..d1cb357e7 100644 --- a/codeflash/languages/base.py +++ b/codeflash/languages/base.py @@ -572,7 +572,7 @@ class LanguageSupport(Protocol): function_to_optimize: Any, tests_project_root: Path, mode: str, - test_path: Path | None + test_path: Path | None, ) -> tuple[bool, str | None]: """Inject profiling code into an existing test file. diff --git a/codeflash/languages/java/build_tools.py b/codeflash/languages/java/build_tools.py index 5e218587e..4460a6d9e 100644 --- a/codeflash/languages/java/build_tools.py +++ b/codeflash/languages/java/build_tools.py @@ -184,7 +184,6 @@ def _get_maven_project_info(project_root: Path) -> JavaProjectInfo | None: if test_src.exists(): test_roots.append(test_src) - # Check for custom source directories in pom.xml section for build in [root.find("m:build", ns), root.find("build")]: if build is not None: @@ -660,7 +659,7 @@ def add_codeflash_dependency_to_pom(pom_path: Path) -> bool: new_content = content[:idx] + CODEFLASH_DEPENDENCY_SNIPPET # Skip the original tag since our snippet includes it - new_content += content[idx + len(closing_tag):] + new_content += content[idx + len(closing_tag) :] pom_path.write_text(new_content, encoding="utf-8") logger.info("Added codeflash-runtime dependency to pom.xml") diff --git a/codeflash/languages/java/comparator.py b/codeflash/languages/java/comparator.py index 3deb9c692..baa1cd042 100644 --- a/codeflash/languages/java/comparator.py +++ b/codeflash/languages/java/comparator.py @@ -179,13 +179,20 @@ def compare_test_results( [ java_exe, # Java 16+ module system: Kryo needs reflective access to internal JDK classes - "--add-opens", "java.base/java.util=ALL-UNNAMED", - "--add-opens", "java.base/java.lang=ALL-UNNAMED", - "--add-opens", "java.base/java.lang.reflect=ALL-UNNAMED", - "--add-opens", "java.base/java.io=ALL-UNNAMED", - "--add-opens", "java.base/java.math=ALL-UNNAMED", - "--add-opens", "java.base/java.net=ALL-UNNAMED", - "--add-opens", "java.base/java.util.zip=ALL-UNNAMED", + "--add-opens", + "java.base/java.util=ALL-UNNAMED", + "--add-opens", + "java.base/java.lang=ALL-UNNAMED", + "--add-opens", + "java.base/java.lang.reflect=ALL-UNNAMED", + "--add-opens", + "java.base/java.io=ALL-UNNAMED", + "--add-opens", + "java.base/java.math=ALL-UNNAMED", + "--add-opens", + "java.base/java.net=ALL-UNNAMED", + "--add-opens", + "java.base/java.util.zip=ALL-UNNAMED", "-cp", str(jar_path), "com.codeflash.Comparator", diff --git a/codeflash/languages/java/concurrency_analyzer.py b/codeflash/languages/java/concurrency_analyzer.py index 90a7aaa56..0fde83a1a 100644 --- a/codeflash/languages/java/concurrency_analyzer.py +++ b/codeflash/languages/java/concurrency_analyzer.py @@ -18,8 +18,6 @@ from pathlib import Path from typing import TYPE_CHECKING if TYPE_CHECKING: - from tree_sitter import Node - from codeflash.languages.base import FunctionInfo logger = logging.getLogger(__name__) @@ -306,9 +304,7 @@ class JavaConcurrencyAnalyzer: return suggestions -def analyze_function_concurrency( - func: FunctionInfo, source: str | None = None, analyzer=None -) -> ConcurrencyInfo: +def analyze_function_concurrency(func: FunctionInfo, source: str | None = None, analyzer=None) -> ConcurrencyInfo: """Analyze a function for concurrency patterns. Convenience function that creates a JavaConcurrencyAnalyzer and analyzes the function. diff --git a/codeflash/languages/java/line_profiler.py b/codeflash/languages/java/line_profiler.py index 314d3dad9..527a3ab2c 100644 --- a/codeflash/languages/java/line_profiler.py +++ b/codeflash/languages/java/line_profiler.py @@ -34,6 +34,7 @@ class JavaLineProfiler: instrumented = profiler.instrument_source(source, file_path, functions) # Run instrumented code results = JavaLineProfiler.parse_results(Path("profile.json")) + """ def __init__(self, output_file: Path) -> None: @@ -48,13 +49,7 @@ class JavaLineProfiler: self.profiler_var = "__codeflashProfiler__" self.line_contents: dict[str, str] = {} - def instrument_source( - self, - source: str, - file_path: Path, - functions: list[FunctionInfo], - analyzer=None, - ) -> str: + def instrument_source(self, source: str, file_path: Path, functions: list[FunctionInfo], analyzer=None) -> str: """Instrument Java source code with line profiling. Adds profiling instrumentation to track line-level execution for the @@ -106,9 +101,7 @@ class JavaLineProfiler: import_end_idx = i break - lines_with_profiler = ( - lines[:import_end_idx] + [profiler_class_code + "\n"] + lines[import_end_idx:] - ) + lines_with_profiler = lines[:import_end_idx] + [profiler_class_code + "\n"] + lines[import_end_idx:] result = "".join(lines_with_profiler) if not analyzer.validate_syntax(result): @@ -121,7 +114,7 @@ class JavaLineProfiler: # Store line contents as a simple map (embedded directly in code) line_contents_code = self._generate_line_contents_map() - return f''' + return f""" /** * Codeflash line profiler - tracks per-line execution statistics. * Auto-generated - do not modify. @@ -132,7 +125,7 @@ class {self.profiler_class} {{ private static final ThreadLocal lastLineTime = new ThreadLocal<>(); private static final ThreadLocal lastKey = new ThreadLocal<>(); private static final java.util.concurrent.atomic.AtomicInteger totalHits = new java.util.concurrent.atomic.AtomicInteger(0); - private static final String OUTPUT_FILE = "{str(self.output_file)}"; + private static final String OUTPUT_FILE = "{self.output_file!s}"; static class LineStats {{ public final java.util.concurrent.atomic.AtomicLong hits = new java.util.concurrent.atomic.AtomicLong(0); @@ -247,15 +240,9 @@ class {self.profiler_class} {{ Runtime.getRuntime().addShutdownHook(new Thread(() -> save())); }} }} -''' +""" - def _instrument_function( - self, - func: FunctionInfo, - lines: list[str], - file_path: Path, - analyzer, - ) -> list[str]: + def _instrument_function(self, func: FunctionInfo, lines: list[str], file_path: Path, analyzer) -> list[str]: """Instrument a single function with line profiling. Args: @@ -300,9 +287,7 @@ class {self.profiler_class} {{ # Add the line with enterFunction() call after it instrumented_lines.append(line) - instrumented_lines.append( - f"{body_indent}{self.profiler_class}.enterFunction();\n" - ) + instrumented_lines.append(f"{body_indent}{self.profiler_class}.enterFunction();\n") function_entry_added = True continue @@ -326,8 +311,7 @@ class {self.profiler_class} {{ # Add hit() call before the line profiled_line = ( - f"{indent_str}{self.profiler_class}.hit(" - f'"{file_path.as_posix()}", {global_line_num});\n{line}' + f'{indent_str}{self.profiler_class}.hit("{file_path.as_posix()}", {global_line_num});\n{line}' ) instrumented_lines.append(profiled_line) else: @@ -497,8 +481,6 @@ def format_line_profile_results(results: dict, file_path: Path | None = None) -> avg_ms = time_ms / hits if hits > 0 else 0 content = stats.get("content", "")[:50] # Truncate long lines - output.append( - f"{line_num:6d} | {hits:10d} | {time_ms:12.3f} | {avg_ms:12.6f} | {content}" - ) + output.append(f"{line_num:6d} | {hits:10d} | {time_ms:12.3f} | {avg_ms:12.6f} | {content}") return "\n".join(output) diff --git a/codeflash/languages/java/remove_asserts.py b/codeflash/languages/java/remove_asserts.py index 5bb86de5b..56160f67b 100644 --- a/codeflash/languages/java/remove_asserts.py +++ b/codeflash/languages/java/remove_asserts.py @@ -298,9 +298,7 @@ class JavaAssertTransformer: # - Assertions.assertEquals (JUnit 5) # - org.junit.jupiter.api.Assertions.assertEquals (fully qualified) all_assertions = "|".join(JUNIT5_ALL_ASSERTIONS) - pattern = re.compile( - rf"(\s*)((?:(?:\w+\.)*Assert(?:ions)?\.)?({all_assertions}))\s*\(", re.MULTILINE - ) + pattern = re.compile(rf"(\s*)((?:(?:\w+\.)*Assert(?:ions)?\.)?({all_assertions}))\s*\(", re.MULTILINE) for match in pattern.finditer(source): leading_ws = match.group(1) @@ -549,8 +547,12 @@ class JavaAssertTransformer: return results def _collect_target_invocations( - self, node, wrapper_bytes: bytes, content_bytes: bytes, - base_offset: int, out: list[TargetCall], + self, + node, + wrapper_bytes: bytes, + content_bytes: bytes, + base_offset: int, + out: list[TargetCall], seen_top_level: set[tuple[int, int]] | None = None, ) -> None: """Recursively walk the AST and collect method_invocation nodes that match self.func_name. @@ -574,22 +576,24 @@ class JavaAssertTransformer: seen_top_level.add(range_key) start = top_node.start_byte - prefix_len end = top_node.end_byte - prefix_len - if 0 <= start and end <= len(content_bytes): + if start >= 0 and end <= len(content_bytes): full_call = self.analyzer.get_node_text(top_node, wrapper_bytes) start_char = len(content_bytes[:start].decode("utf8")) end_char = len(content_bytes[:end].decode("utf8")) - out.append(TargetCall( - receiver=None, - method_name=self.func_name, - arguments="", - full_call=full_call, - start_pos=base_offset + start_char, - end_pos=base_offset + end_char, - )) + out.append( + TargetCall( + receiver=None, + method_name=self.func_name, + arguments="", + full_call=full_call, + start_pos=base_offset + start_char, + end_pos=base_offset + end_char, + ) + ) else: start = node.start_byte - prefix_len end = node.end_byte - prefix_len - if 0 <= start and end <= len(content_bytes): + if start >= 0 and end <= len(content_bytes): out.append(self._build_target_call(node, wrapper_bytes, content_bytes, start, end, base_offset)) return @@ -597,8 +601,7 @@ class JavaAssertTransformer: self._collect_target_invocations(child, wrapper_bytes, content_bytes, base_offset, out, seen_top_level) def _build_target_call( - self, node, wrapper_bytes: bytes, content_bytes: bytes, - start_byte: int, end_byte: int, base_offset: int, + self, node, wrapper_bytes: bytes, content_bytes: bytes, start_byte: int, end_byte: int, base_offset: int ) -> TargetCall: """Build a TargetCall from a tree-sitter method_invocation node.""" get_text = self.analyzer.get_node_text @@ -679,7 +682,6 @@ class JavaAssertTransformer: # Handle generic types: Type varName = ... match = self._assign_re.search(source, line_start, assertion_start) - if match: var_type = match.group(1).strip() var_name = match.group(2).strip() @@ -934,18 +936,12 @@ class JavaAssertTransformer: f"catch (Exception _cf_ignored{counter}) {{}}" ) - return ( - f"{ws}try {{ {code_to_run} }} " - f"catch (Exception _cf_ignored{counter}) {{}}" - ) + return f"{ws}try {{ {code_to_run} }} catch (Exception _cf_ignored{counter}) {{}}" # If no lambda body found, try to extract from target calls if assertion.target_calls: call = assertion.target_calls[0] - return ( - f"{ws}try {{ {call.full_call}; }} " - f"catch (Exception _cf_ignored{counter}) {{}}" - ) + return f"{ws}try {{ {call.full_call}; }} catch (Exception _cf_ignored{counter}) {{}}" # Fallback: comment out the assertion return f"{ws}// Removed assertThrows: could not extract callable" diff --git a/codeflash/languages/java/support.py b/codeflash/languages/java/support.py index e33e98dcf..d9ae798fe 100644 --- a/codeflash/languages/java/support.py +++ b/codeflash/languages/java/support.py @@ -12,11 +12,8 @@ from typing import TYPE_CHECKING, Any from codeflash.languages.base import Language, LanguageSupport from codeflash.languages.java.build_tools import find_test_root from codeflash.languages.java.comparator import compare_test_results as _compare_test_results +from codeflash.languages.java.concurrency_analyzer import analyze_function_concurrency from codeflash.languages.java.config import detect_java_project -from codeflash.languages.java.concurrency_analyzer import ( - JavaConcurrencyAnalyzer, - analyze_function_concurrency, -) from codeflash.languages.java.context import extract_code_context, find_helper_functions from codeflash.languages.java.discovery import discover_functions, discover_functions_from_source from codeflash.languages.java.formatter import format_java_code, normalize_java_code @@ -288,14 +285,11 @@ class JavaSupport(LanguageSupport): function_to_optimize: Any, tests_project_root: Path, mode: str, - test_path: Path | None + test_path: Path | None, ) -> tuple[bool, str | None]: """Inject profiling code into an existing test file.""" return instrument_existing_test( - test_string=test_string, - function_to_optimize=function_to_optimize, - mode=mode, - test_path=test_path + test_string=test_string, function_to_optimize=function_to_optimize, mode=mode, test_path=test_path ) def instrument_source_for_line_profiler( diff --git a/codeflash/languages/java/test_runner.py b/codeflash/languages/java/test_runner.py index cd5aa488a..53084c932 100644 --- a/codeflash/languages/java/test_runner.py +++ b/codeflash/languages/java/test_runner.py @@ -610,13 +610,20 @@ def _run_tests_direct( cmd = [ str(java), # Java 16+ module system: Kryo needs reflective access to internal JDK classes - "--add-opens", "java.base/java.util=ALL-UNNAMED", - "--add-opens", "java.base/java.lang=ALL-UNNAMED", - "--add-opens", "java.base/java.lang.reflect=ALL-UNNAMED", - "--add-opens", "java.base/java.io=ALL-UNNAMED", - "--add-opens", "java.base/java.math=ALL-UNNAMED", - "--add-opens", "java.base/java.net=ALL-UNNAMED", - "--add-opens", "java.base/java.util.zip=ALL-UNNAMED", + "--add-opens", + "java.base/java.util=ALL-UNNAMED", + "--add-opens", + "java.base/java.lang=ALL-UNNAMED", + "--add-opens", + "java.base/java.lang.reflect=ALL-UNNAMED", + "--add-opens", + "java.base/java.io=ALL-UNNAMED", + "--add-opens", + "java.base/java.math=ALL-UNNAMED", + "--add-opens", + "java.base/java.net=ALL-UNNAMED", + "--add-opens", + "java.base/java.util.zip=ALL-UNNAMED", "-cp", classpath, "org.junit.platform.console.ConsoleLauncher", diff --git a/codeflash/languages/javascript/support.py b/codeflash/languages/javascript/support.py index 10d3b96d9..149e2bcd7 100644 --- a/codeflash/languages/javascript/support.py +++ b/codeflash/languages/javascript/support.py @@ -1941,7 +1941,7 @@ class JavaScriptSupport: function_to_optimize: Any, tests_project_root: Path, mode: str, - test_path: Path|None, + test_path: Path | None, ) -> tuple[bool, str | None]: """Inject profiling code into an existing JavaScript test file. diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 601169dd2..bc5d77f13 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -800,7 +800,9 @@ class FunctionOptimizer: logger.debug( f"[JAVA] Detected Java sources root: {java_sources_root} (from tests_root: {tests_root})" ) - logger.debug(f"[JAVA-ROOT] Returning Java sources root: {java_sources_root}, tests_root was: {tests_root}") + logger.debug( + f"[JAVA-ROOT] Returning Java sources root: {java_sources_root}, tests_root was: {tests_root}" + ) return java_sources_root # If no standard package prefix found, check if there's a 'java' directory @@ -810,7 +812,9 @@ class FunctionOptimizer: # Return up to and including 'java' java_sources_root = Path(*parts[: i + 1]) logger.debug(f"[JAVA] Detected Maven-style Java sources root: {java_sources_root}") - logger.debug(f"[JAVA-ROOT] Returning Java sources root: {java_sources_root}, tests_root was: {tests_root}") + logger.debug( + f"[JAVA-ROOT] Returning Java sources root: {java_sources_root}, tests_root was: {tests_root}" + ) return java_sources_root # Default: return tests_root as-is (original behavior) @@ -862,7 +866,7 @@ class FunctionOptimizer: if main_match: main_module_name = main_match.group(1) if package_name.startswith(main_module_name): - suffix = package_name[len(main_module_name):] + suffix = package_name[len(main_module_name) :] new_package = test_module_name + suffix old_decl = f"package {package_name};" new_decl = f"package {new_package};" diff --git a/codeflash/setup/detector.py b/codeflash/setup/detector.py index ea9c3b858..defe1a22d 100644 --- a/codeflash/setup/detector.py +++ b/codeflash/setup/detector.py @@ -164,7 +164,15 @@ def _find_project_root(start_path: Path) -> Path | None: while current != current.parent: # Check for project markers - markers = [".git", "pyproject.toml", "package.json", "Cargo.toml", "pom.xml", "build.gradle", "build.gradle.kts"] + markers = [ + ".git", + "pyproject.toml", + "package.json", + "Cargo.toml", + "pom.xml", + "build.gradle", + "build.gradle.kts", + ] for marker in markers: if (current / marker).exists(): return current @@ -489,10 +497,17 @@ def _detect_tests_root(project_root: Path, language: str) -> tuple[Path | None, for elem in [build.find("m:testSourceDirectory", ns), build.find("testSourceDirectory")]: if elem is not None and elem.text: # Resolve ${project.basedir}/src -> test_module_dir/src - dir_text = elem.text.strip().replace("${project.basedir}/", "").replace("${project.basedir}", ".") + dir_text = ( + elem.text.strip() + .replace("${project.basedir}/", "") + .replace("${project.basedir}", ".") + ) resolved = test_module_dir / dir_text if resolved.is_dir(): - return resolved, f"{test_module_name}/{dir_text} (from {test_module_name}/pom.xml testSourceDirectory)" + return ( + resolved, + f"{test_module_name}/{dir_text} (from {test_module_name}/pom.xml testSourceDirectory)", + ) except ET.ParseError: pass # Test module exists but no custom testSourceDirectory - use the module root @@ -548,8 +563,6 @@ def _detect_test_runner(project_root: Path, language: str) -> tuple[str, str]: def _detect_java_test_runner(project_root: Path) -> tuple[str, str]: """Detect Java test framework.""" - import xml.etree.ElementTree as ET - pom_path = project_root / "pom.xml" if pom_path.exists(): try: diff --git a/codeflash/verification/coverage_utils.py b/codeflash/verification/coverage_utils.py index c73c7982f..c77f5e7df 100644 --- a/codeflash/verification/coverage_utils.py +++ b/codeflash/verification/coverage_utils.py @@ -231,7 +231,9 @@ class JacocoCoverageUtils: f"File preview: {content_preview!r}" ) except Exception as read_err: - logger.warning(f"Failed to parse JaCoCo XML file at '{jacoco_xml_path}': {e}. Could not read file: {read_err}") + logger.warning( + f"Failed to parse JaCoCo XML file at '{jacoco_xml_path}': {e}. Could not read file: {read_err}" + ) return CoverageData.create_empty(source_code_path, function_name, code_context) # Determine expected source file name from path diff --git a/codeflash/verification/equivalence.py b/codeflash/verification/equivalence.py index 9a4f7d91e..c9d067458 100644 --- a/codeflash/verification/equivalence.py +++ b/codeflash/verification/equivalence.py @@ -27,9 +27,7 @@ def safe_repr(obj: object) -> str: return f"" -def compare_test_results( - original_results: TestResults, candidate_results: TestResults -) -> tuple[bool, list[TestDiff]]: +def compare_test_results(original_results: TestResults, candidate_results: TestResults) -> tuple[bool, list[TestDiff]]: # This is meant to be only called with test results for the first loop index if len(original_results) == 0 or len(candidate_results) == 0: return False, [] # empty test results are not equal @@ -102,9 +100,7 @@ def compare_test_results( ) ) - elif not comparator( - original_test_result.return_value, cdd_test_result.return_value, superset_obj=superset_obj - ): + elif not comparator(original_test_result.return_value, cdd_test_result.return_value, superset_obj=superset_obj): test_diffs.append( TestDiff( scope=TestDiffScope.RETURN_VALUE, @@ -129,9 +125,8 @@ def compare_test_results( ) except Exception as e: logger.error(e) - elif ( - (original_test_result.stdout and cdd_test_result.stdout) - and not comparator(original_test_result.stdout, cdd_test_result.stdout) + elif (original_test_result.stdout and cdd_test_result.stdout) and not comparator( + original_test_result.stdout, cdd_test_result.stdout ): test_diffs.append( TestDiff( diff --git a/codeflash/verification/parse_test_output.py b/codeflash/verification/parse_test_output.py index 1d8853a7e..d8382320d 100644 --- a/codeflash/verification/parse_test_output.py +++ b/codeflash/verification/parse_test_output.py @@ -1002,7 +1002,9 @@ def parse_test_xml( # Always use tests_project_rootdir since pytest is now the test runner for all frameworks base_dir = test_config.tests_project_rootdir logger.debug(f"[PARSE-XML] base_dir for resolution: {base_dir}") - logger.debug(f"[PARSE-XML] Registered test files: {[str(tf.instrumented_behavior_file_path) for tf in test_files.test_files]}") + logger.debug( + f"[PARSE-XML] Registered test files: {[str(tf.instrumented_behavior_file_path) for tf in test_files.test_files]}" + ) # For Java: pre-parse fallback stdout once (not per testcase) to avoid O(n²) complexity java_fallback_stdout = None @@ -1067,7 +1069,9 @@ def parse_test_xml( test_file_path = resolve_test_file_from_class_path(test_class_path, base_dir) if test_file_path is None: - logger.error(f"[PARSE-XML] ERROR: Could not resolve test_class_path={test_class_path}, base_dir={base_dir}") + logger.error( + f"[PARSE-XML] ERROR: Could not resolve test_class_path={test_class_path}, base_dir={base_dir}" + ) logger.warning(f"Could not find the test for file name - {test_class_path} ") continue else: @@ -1271,9 +1275,7 @@ def parse_test_xml( str(test_file.instrumented_behavior_file_path or test_file.original_file_path) for test_file in test_files.test_files ] - logger.info( - f"Tests {test_paths_display} failed to run, skipping" - ) + logger.info(f"Tests {test_paths_display} failed to run, skipping") if run_result is not None: stdout, stderr = "", "" try: diff --git a/codeflash/verification/verifier.py b/codeflash/verification/verifier.py index d80b02013..b677d1819 100644 --- a/codeflash/verification/verifier.py +++ b/codeflash/verification/verifier.py @@ -109,7 +109,11 @@ def generate_tests( # Instrument for behavior verification (renames class) instrumented_behavior_test_source = instrument_generated_java_test( - test_code=generated_test_source, function_name=func_name, qualified_name=qualified_name, mode="behavior", function_to_optimize=function_to_optimize + test_code=generated_test_source, + function_name=func_name, + qualified_name=qualified_name, + mode="behavior", + function_to_optimize=function_to_optimize, ) # Instrument for performance measurement (adds timing markers) @@ -118,7 +122,7 @@ def generate_tests( function_name=func_name, qualified_name=qualified_name, mode="performance", - function_to_optimize=function_to_optimize + function_to_optimize=function_to_optimize, ) logger.debug(f"Instrumented Java tests locally for {func_name}") diff --git a/tests/test_java_assertion_removal.py b/tests/test_java_assertion_removal.py index d0861ee53..7b991db99 100644 --- a/tests/test_java_assertion_removal.py +++ b/tests/test_java_assertion_removal.py @@ -1153,7 +1153,7 @@ void testWithThreadSleep() throws InterruptedException { assert result == expected def test_synchronized_method_signature_preserved(self): - """synchronized modifier on a test method is preserved after transformation.""" + """Synchronized modifier on a test method is preserved after transformation.""" source = """\ @Test synchronized void testSyncMethod() { diff --git a/tests/test_languages/test_java/test_instrumentation.py b/tests/test_languages/test_java/test_instrumentation.py index 499bcc159..d00d6e982 100644 --- a/tests/test_languages/test_java/test_instrumentation.py +++ b/tests/test_languages/test_java/test_instrumentation.py @@ -12,8 +12,6 @@ Also includes end-to-end execution tests that: import os import re -import shutil -import subprocess from pathlib import Path import pytest @@ -24,7 +22,6 @@ os.environ["CODEFLASH_API_KEY"] = "cf-test-key" from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.languages.base import Language from codeflash.languages.current import set_current_language -from codeflash.models.function_types import FunctionParent from codeflash.languages.java.build_tools import find_maven_executable from codeflash.languages.java.discovery import discover_functions_from_source from codeflash.languages.java.instrumentation import ( @@ -1148,7 +1145,7 @@ public class TargetBenchmark { "public class TargetBenchmark {\n" "\n" " @Test\n" - " @DisplayName(\"Benchmark multiply\")\n" + ' @DisplayName("Benchmark multiply")\n' " public void benchmarkMultiply() {\n" " \n" # Empty test_setup_code with 8-space indent "\n" @@ -1167,7 +1164,7 @@ public class TargetBenchmark { " long totalNanos = endTime - startTime;\n" " long avgNanos = totalNanos / 5000;\n" "\n" - " System.out.println(\"CODEFLASH_BENCHMARK:multiply:total_ns=\" + totalNanos + \",avg_ns=\" + avgNanos + \",iterations=5000\");\n" + ' System.out.println("CODEFLASH_BENCHMARK:multiply:total_ns=" + totalNanos + ",avg_ns=" + avgNanos + ",iterations=5000");\n' " }\n" "}\n" ) @@ -1934,9 +1931,6 @@ class TestRunAndParseTests: @pytest.fixture def java_project(self, tmp_path: Path): """Create a temporary Maven project and set up Java language context.""" - from codeflash.languages.base import Language - from codeflash.languages.current import set_current_language - # Force set the language to Java (reset the singleton first) import codeflash.languages.current as current_module current_module._current_language = None @@ -2432,7 +2426,6 @@ public class BrokenCalcTest { def test_behavior_mode_writes_to_sqlite(self, java_project): """Test that behavior mode correctly writes results to SQLite file.""" import sqlite3 - from argparse import Namespace from codeflash.code_utils.code_utils import get_run_tmp_file diff --git a/tests/test_languages/test_java/test_remove_asserts.py b/tests/test_languages/test_java/test_remove_asserts.py index 9487bd4b4..e0a252ad8 100644 --- a/tests/test_languages/test_java/test_remove_asserts.py +++ b/tests/test_languages/test_java/test_remove_asserts.py @@ -16,10 +16,7 @@ Covers: - Edge cases: static calls, qualified calls, method chaining """ -from codeflash.languages.java.remove_asserts import ( - JavaAssertTransformer, - transform_java_assertions, -) +from codeflash.languages.java.remove_asserts import JavaAssertTransformer, transform_java_assertions class TestJUnit4Assertions: