From b4e233a2cf004b5b90848849163c5aea86429cc4 Mon Sep 17 00:00:00 2001 From: HeshamHM28 Date: Fri, 13 Feb 2026 06:07:56 +0200 Subject: [PATCH] fix omni-java --- codeflash/code_utils/config_consts.py | 1 + codeflash/languages/java/instrumentation.py | 22 +++++-- codeflash/languages/java/replacement.py | 65 ++++++++++++++++++- codeflash/optimization/function_optimizer.py | 6 ++ codeflash/optimization/optimizer.py | 34 ++++++++++ codeflash/result/critic.py | 8 ++- codeflash/verification/verifier.py | 15 +++++ .../test_java/test_instrumentation.py | 22 +++---- .../test_java/test_replacement.py | 3 +- 9 files changed, 154 insertions(+), 22 deletions(-) diff --git a/codeflash/code_utils/config_consts.py b/codeflash/code_utils/config_consts.py index e9afbcc64..e24ce54f8 100644 --- a/codeflash/code_utils/config_consts.py +++ b/codeflash/code_utils/config_consts.py @@ -10,6 +10,7 @@ INDIVIDUAL_TESTCASE_TIMEOUT = 15 # For Python pytest JAVA_TESTCASE_TIMEOUT = 120 # Java Maven tests need more time due to startup overhead MAX_FUNCTION_TEST_SECONDS = 60 MIN_IMPROVEMENT_THRESHOLD = 0.05 +MIN_IMPROVEMENT_THRESHOLD_JAVA = 0.02 MIN_THROUGHPUT_IMPROVEMENT_THRESHOLD = 0.10 # 10% minimum improvement for async throughput MIN_CONCURRENCY_IMPROVEMENT_THRESHOLD = 0.20 # 20% concurrency ratio improvement required CONCURRENCY_FACTOR = 10 # Number of concurrent executions for concurrency benchmark diff --git a/codeflash/languages/java/instrumentation.py b/codeflash/languages/java/instrumentation.py index b36b33aef..9f5ac1391 100644 --- a/codeflash/languages/java/instrumentation.py +++ b/codeflash/languages/java/instrumentation.py @@ -508,11 +508,23 @@ def _add_behavior_instrumentation(source: str, class_name: str, func_name: str) ) wrapped_body_lines.append(serialize_line) - # Check if the line is now just a variable reference (invalid statement) - # This happens when the original line was just a void method call - # e.g., "BubbleSort.bubbleSort(original);" becomes "_cf_result1_1;" + # Check if the line is now just a variable reference (invalid statement). + # This happens when the original line was just a void method call: + # "BubbleSort.bubbleSort(original);" becomes "_cf_result1_1;" + # It also happens when assertThrows was transformed to try-catch: + # "try { func(args); } catch (...) {}" becomes + # "try { _cf_result1_1; } catch (...) {}" + # A bare variable is not a valid Java statement. stripped_new = new_line.strip().rstrip(";").strip() - if stripped_new and stripped_new not in (var_name, var_with_cast): + is_bare_var = stripped_new in (var_name, var_with_cast) + is_try_with_bare_var = bool(re.match( + r"try\s*\{\s*(?:" + + re.escape(var_name) + + (r"|" + re.escape(var_with_cast) if var_with_cast != var_name else "") + + r")\s*;\s*\}\s*catch\s*\(", + stripped_new, + )) + if stripped_new and not is_bare_var and not is_try_with_bare_var: wrapped_body_lines.append(new_line) else: wrapped_body_lines.append(body_line) @@ -834,7 +846,7 @@ def instrument_generated_java_test( original_class_name = class_match.group(1) - # For performance mode, add timing instrumentation + # Add mode-specific instrumentation # Use original class name (without suffix) in timing markers for consistency with Python if mode == "performance": diff --git a/codeflash/languages/java/replacement.py b/codeflash/languages/java/replacement.py index d12a2dd52..1d4727987 100644 --- a/codeflash/languages/java/replacement.py +++ b/codeflash/languages/java/replacement.py @@ -33,10 +33,11 @@ class ParsedOptimization: target_method_source: str new_fields: list[str] # Source text of new fields to add new_helper_methods: list[str] # Source text of new helper methods to add + new_imports: list[str] # Import statements to add (e.g., "import java.nio.file.Files;") def _parse_optimization_source(new_source: str, target_method_name: str, analyzer: JavaAnalyzer) -> ParsedOptimization: - """Parse optimization source to extract method and additional class members. + """Parse optimization source to extract method, imports, and additional class members. The new_source may contain: - Just a method definition @@ -48,13 +49,20 @@ def _parse_optimization_source(new_source: str, target_method_name: str, analyze analyzer: JavaAnalyzer instance. Returns: - ParsedOptimization with the method and any additional members. + ParsedOptimization with the method, imports, and any additional members. """ new_fields: list[str] = [] new_helper_methods: list[str] = [] target_method_source = new_source # Default to the whole source + # Extract import statements from the candidate code + new_imports: list[str] = [] + for imp in analyzer.find_imports(new_source): + prefix = "import static " if imp.is_static else "import " + suffix = ".*" if imp.is_wildcard else "" + new_imports.append(f"{prefix}{imp.import_path}{suffix};") + # Check if this is a full class or just a method classes = analyzer.find_classes(new_source) @@ -92,10 +100,57 @@ def _parse_optimization_source(new_source: str, target_method_name: str, analyze new_fields.append(field.source_text) return ParsedOptimization( - target_method_source=target_method_source, new_fields=new_fields, new_helper_methods=new_helper_methods + target_method_source=target_method_source, + new_fields=new_fields, + new_helper_methods=new_helper_methods, + new_imports=new_imports, ) +def _add_missing_imports(source: str, candidate_imports: list[str], analyzer: JavaAnalyzer) -> str: + """Add import statements from the optimization candidate that are missing in the original source. + + Args: + source: The original source code. + candidate_imports: Import statements from the candidate (e.g., ["import java.nio.file.Files;"]). + analyzer: JavaAnalyzer instance. + + Returns: + Source code with missing imports added. + + """ + existing_imports = analyzer.find_imports(source) + existing_import_strs = set() + for imp in existing_imports: + prefix = "import static " if imp.is_static else "import " + suffix = ".*" if imp.is_wildcard else "" + existing_import_strs.add(f"{prefix}{imp.import_path}{suffix};") + + missing_imports = [imp for imp in candidate_imports if imp not in existing_import_strs] + if not missing_imports: + return source + + logger.debug("Adding %d missing imports: %s", len(missing_imports), missing_imports) + + # Insert after the last existing import, or after the package declaration + lines = source.splitlines(keepends=True) + insert_line = 0 + + if existing_imports: + insert_line = max(imp.end_line for imp in existing_imports) + else: + # No existing imports — insert after package declaration + for i, line in enumerate(lines): + if line.strip().startswith("package "): + insert_line = i + 1 + break + + import_block = "".join(imp + "\n" for imp in missing_imports) + before = lines[:insert_line] + after = lines[insert_line:] + return "".join(before) + import_block + "".join(after) + + def _insert_class_members( source: str, class_name: str, fields: list[str], methods: list[str], analyzer: JavaAnalyzer ) -> str: @@ -237,6 +292,10 @@ def replace_function( # Parse the optimization to extract components parsed = _parse_optimization_source(new_source, func_name, analyzer) + # Add any new imports from the optimization candidate + if parsed.new_imports: + source = _add_missing_imports(source, parsed.new_imports, analyzer) + # Find the method in the original source methods = analyzer.find_methods(source) target_method = None diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 838b3e2da..bd69e05d0 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -48,6 +48,7 @@ from codeflash.code_utils.config_consts import ( COVERAGE_THRESHOLD, INDIVIDUAL_TESTCASE_TIMEOUT, MIN_CORRECT_CANDIDATES, + MIN_IMPROVEMENT_THRESHOLD_JAVA, OPTIMIZATION_CONTEXT_TOKEN_LIMIT, REFINED_CANDIDATE_RANKING_WEIGHTS, REPEAT_OPTIMIZATION_PROBABILITY, @@ -1364,6 +1365,8 @@ class FunctionOptimizer: eval_ctx.record_successful_candidate(candidate.optimization_id, candidate_result.best_test_runtime, perf_gain) # Check if this is a successful optimization + # Use a lower threshold for Java where I/O-bound functions have smaller optimization margins + java_override = MIN_IMPROVEMENT_THRESHOLD_JAVA if self.language_support.language == "java" else None is_successful_opt = speedup_critic( candidate_result, original_code_baseline.runtime, @@ -1372,6 +1375,7 @@ class FunctionOptimizer: best_throughput_until_now=None, original_concurrency_metrics=original_code_baseline.concurrency_metrics, best_concurrency_ratio_until_now=None, + min_improvement_override=java_override, ) and quantity_of_tests_critic(candidate_result) tree = self.build_runtime_info_tree( @@ -2272,6 +2276,7 @@ class FunctionOptimizer: fto_benchmark_timings=self.function_benchmark_timings, total_benchmark_timings=self.total_benchmark_timings, ) + java_override = MIN_IMPROVEMENT_THRESHOLD_JAVA if self.language_support.language == "java" else None acceptance_reason = get_acceptance_reason( original_runtime_ns=original_code_baseline.runtime, optimized_runtime_ns=best_optimization.runtime, @@ -2279,6 +2284,7 @@ class FunctionOptimizer: optimized_async_throughput=best_optimization.async_throughput, original_concurrency_metrics=original_code_baseline.concurrency_metrics, optimized_concurrency_metrics=best_optimization.concurrency_metrics, + min_improvement_override=java_override, ) explanation = Explanation( raw_explanation_message=best_optimization.candidate.explanation, diff --git a/codeflash/optimization/optimizer.py b/codeflash/optimization/optimizer.py index ae30813a6..c20ecb7b6 100644 --- a/codeflash/optimization/optimizer.py +++ b/codeflash/optimization/optimizer.py @@ -117,6 +117,38 @@ class Optimizer: except Exception as e: logger.debug(f"Failed to verify JS requirements: {e}") + def _setup_java_runtime(self) -> None: + from pathlib import Path as _Path + + from codeflash.languages.java.build_tools import add_codeflash_dependency_to_pom, install_codeflash_runtime + + runtime_jar = _Path(__file__).parent.parent / "languages" / "java" / "resources" / "codeflash-runtime-1.0.0.jar" + project_root = self.args.project_root + + if not runtime_jar.exists(): + logger.warning("codeflash-runtime JAR not found at %s, behavior capture may not work", runtime_jar) + return + + if not install_codeflash_runtime(project_root, runtime_jar): + logger.warning("Failed to install codeflash-runtime to local Maven repo") + return + + # Add dependency to the test module's pom.xml (or root pom.xml) + test_root = _Path(self.args.tests_root) if hasattr(self.args, "tests_root") and self.args.tests_root else None + pom_path = project_root / "pom.xml" + + # For multi-module projects, find the test module's pom.xml + if test_root and test_root != project_root: + candidate = test_root + while candidate != project_root and candidate != candidate.parent: + if (candidate / "pom.xml").exists(): + pom_path = candidate / "pom.xml" + break + candidate = candidate.parent + + if pom_path.exists(): + add_codeflash_dependency_to_pom(pom_path) + def run_benchmarks( self, file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]], num_optimizable_functions: int ) -> tuple[dict[str, dict[BenchmarkKey, float]], dict[BenchmarkKey, float]]: @@ -495,6 +527,8 @@ class Optimizer: self.test_cfg.js_project_root = self._find_js_project_root(file_path) # Verify JS requirements before proceeding self._verify_js_requirements() + elif is_java(): + self._setup_java_runtime() break if self.args.all: diff --git a/codeflash/result/critic.py b/codeflash/result/critic.py index f51762ddf..5cc00cbdb 100644 --- a/codeflash/result/critic.py +++ b/codeflash/result/critic.py @@ -72,6 +72,7 @@ def speedup_critic( best_throughput_until_now: int | None = None, original_concurrency_metrics: ConcurrencyMetrics | None = None, best_concurrency_ratio_until_now: float | None = None, + min_improvement_override: float | None = None, ) -> bool: """Take in a correct optimized Test Result and decide if the optimization should actually be surfaced to the user. @@ -92,7 +93,8 @@ def speedup_critic( - Concurrency improvements detect when blocking calls are replaced with non-blocking equivalents """ # Runtime performance evaluation - noise_floor = 3 * MIN_IMPROVEMENT_THRESHOLD if original_code_runtime < 10000 else MIN_IMPROVEMENT_THRESHOLD + threshold = min_improvement_override if min_improvement_override is not None else MIN_IMPROVEMENT_THRESHOLD + noise_floor = 3 * threshold if original_code_runtime < 10000 else threshold if not disable_gh_action_noise and env_utils.is_ci(): noise_floor = noise_floor * 2 # Increase the noise floor in GitHub Actions mode @@ -146,13 +148,15 @@ def get_acceptance_reason( optimized_async_throughput: int | None = None, original_concurrency_metrics: ConcurrencyMetrics | None = None, optimized_concurrency_metrics: ConcurrencyMetrics | None = None, + min_improvement_override: float | None = None, ) -> AcceptanceReason: """Determine why an optimization was accepted. Returns the primary reason for acceptance, with priority: concurrency > throughput > runtime (for async code). """ - noise_floor = 3 * MIN_IMPROVEMENT_THRESHOLD if original_runtime_ns < 10000 else MIN_IMPROVEMENT_THRESHOLD + threshold = min_improvement_override if min_improvement_override is not None else MIN_IMPROVEMENT_THRESHOLD + noise_floor = 3 * threshold if original_runtime_ns < 10000 else threshold if env_utils.is_ci(): noise_floor = noise_floor * 2 diff --git a/codeflash/verification/verifier.py b/codeflash/verification/verifier.py index d80b02013..f3c03f45c 100644 --- a/codeflash/verification/verifier.py +++ b/codeflash/verification/verifier.py @@ -122,6 +122,21 @@ def generate_tests( ) logger.debug(f"Instrumented Java tests locally for {func_name}") + logger.debug( + f"=== Java Generated Tests (raw) for {func_name} ===\n" + f"{generated_test_source}\n" + f"=== End Java Generated Tests ===" + ) + logger.debug( + f"=== Java Instrumented Behavior Tests for {func_name} ===\n" + f"{instrumented_behavior_test_source}\n" + f"=== End Java Instrumented Behavior Tests ===" + ) + logger.debug( + f"=== Java Instrumented Perf Tests for {func_name} ===\n" + f"{instrumented_perf_test_source}\n" + f"=== End Java Instrumented Perf Tests ===" + ) else: # Python: instrumentation is done by aiservice, just replace temp dir placeholders instrumented_behavior_test_source = instrumented_behavior_test_source.replace( diff --git a/tests/test_languages/test_java/test_instrumentation.py b/tests/test_languages/test_java/test_instrumentation.py index 56fcd897a..7731efd26 100644 --- a/tests/test_languages/test_java/test_instrumentation.py +++ b/tests/test_languages/test_java/test_instrumentation.py @@ -915,6 +915,7 @@ class TestInstrumentGeneratedJavaTest: 1. Remove assertions containing the target function call 2. Capture the function return value instead 3. Rename the class with __perfinstrumented suffix + 4. Add SQLite behavior instrumentation to capture return values """ test_code = """import org.junit.jupiter.api.Test; @@ -932,17 +933,16 @@ public class CalculatorTest { mode="behavior", ) - # Behavior mode transforms assertions to capture return values - expected = """import org.junit.jupiter.api.Test; - -public class CalculatorTest__perfinstrumented { - @Test - public void testAdd() { - Object _cf_result1 = new Calculator().add(2, 2); - } -} -""" - assert result == expected + # Behavior mode transforms assertions, renames class, and adds SQLite instrumentation + assert "class CalculatorTest__perfinstrumented" in result + assert "import java.sql.Connection;" in result + assert "import java.sql.DriverManager;" in result + assert "import java.sql.PreparedStatement;" in result + assert "CODEFLASH_OUTPUT_FILE" in result + assert "CREATE TABLE IF NOT EXISTS test_results" in result + assert "INSERT INTO test_results VALUES" in result + assert "_cf_serializedResult1" in result + assert "com.codeflash.Serializer.serialize" in result def test_instrument_generated_test_performance_mode(self): """Test instrumenting generated test in performance mode with inner loop.""" diff --git a/tests/test_languages/test_java/test_replacement.py b/tests/test_languages/test_java/test_replacement.py index a56e584ce..56a7d1ccd 100644 --- a/tests/test_languages/test_java/test_replacement.py +++ b/tests/test_languages/test_java/test_replacement.py @@ -647,7 +647,8 @@ public class NullChecker {{ assert result is True new_code = java_file.read_text(encoding="utf-8") - expected = """public class NullChecker { + expected = """import java.util.Objects; +public class NullChecker { public boolean isEqual(String s1, String s2) { return Objects.equals(s1, s2); }