fix omni-java

HeshamHM28 2026-02-13 06:07:56 +02:00
parent ae4eb7c91b
commit b4e233a2cf
9 changed files with 154 additions and 22 deletions

View file

@@ -10,6 +10,7 @@ INDIVIDUAL_TESTCASE_TIMEOUT = 15 # For Python pytest
JAVA_TESTCASE_TIMEOUT = 120 # Java Maven tests need more time due to startup overhead
MAX_FUNCTION_TEST_SECONDS = 60
MIN_IMPROVEMENT_THRESHOLD = 0.05
MIN_IMPROVEMENT_THRESHOLD_JAVA = 0.02
MIN_THROUGHPUT_IMPROVEMENT_THRESHOLD = 0.10 # 10% minimum improvement for async throughput
MIN_CONCURRENCY_IMPROVEMENT_THRESHOLD = 0.20 # 20% concurrency ratio improvement required
CONCURRENCY_FACTOR = 10 # Number of concurrent executions for concurrency benchmark
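
A quick sanity check on what the new Java threshold changes (the gains below are hypothetical, and the real gate is speedup_critic, shown further down; this is only a sketch of the comparison):

# Hypothetical perf gains illustrating the lower Java bar: I/O-bound Java
# functions tend to leave smaller code-level optimization margins.
MIN_IMPROVEMENT_THRESHOLD = 0.05
MIN_IMPROVEMENT_THRESHOLD_JAVA = 0.02

def clears_bar(perf_gain: float, language: str) -> bool:
    threshold = MIN_IMPROVEMENT_THRESHOLD_JAVA if language == "java" else MIN_IMPROVEMENT_THRESHOLD
    return perf_gain > threshold

print(clears_bar(0.03, "java"))    # True:  3% clears the 2% Java bar
print(clears_bar(0.03, "python"))  # False: 3% misses the 5% default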

View file

@@ -508,11 +508,23 @@ def _add_behavior_instrumentation(source: str, class_name: str, func_name: str)
)
wrapped_body_lines.append(serialize_line)
# Check if the line is now just a variable reference (invalid statement)
# This happens when the original line was just a void method call
# e.g., "BubbleSort.bubbleSort(original);" becomes "_cf_result1_1;"
# Check if the line is now just a variable reference (invalid statement).
# This happens when the original line was just a void method call:
# "BubbleSort.bubbleSort(original);" becomes "_cf_result1_1;"
# It also happens when assertThrows was transformed to try-catch:
# "try { func(args); } catch (...) {}" becomes
# "try { _cf_result1_1; } catch (...) {}"
# A bare variable is not a valid Java statement.
stripped_new = new_line.strip().rstrip(";").strip()
if stripped_new and stripped_new not in (var_name, var_with_cast):
is_bare_var = stripped_new in (var_name, var_with_cast)
is_try_with_bare_var = bool(re.match(
r"try\s*\{\s*(?:"
+ re.escape(var_name)
+ (r"|" + re.escape(var_with_cast) if var_with_cast != var_name else "")
+ r")\s*;\s*\}\s*catch\s*\(",
stripped_new,
))
if stripped_new and not is_bare_var and not is_try_with_bare_var:
wrapped_body_lines.append(new_line)
else:
wrapped_body_lines.append(body_line)
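
The regex can be exercised standalone against the two cases the comment names (variable name copied from the example above; the surrounding instrumentation code is elided):

import re

var_name = "_cf_result1_1"
var_with_cast = var_name  # no cast in this example, so the alternation branch is skipped
pattern = (
    r"try\s*\{\s*(?:"
    + re.escape(var_name)
    + (r"|" + re.escape(var_with_cast) if var_with_cast != var_name else "")
    + r")\s*;\s*\}\s*catch\s*\("
)

for line in ("_cf_result1_1;", "try { _cf_result1_1; } catch (Exception e) {}"):
    stripped = line.strip().rstrip(";").strip()
    is_bare = stripped in (var_name, var_with_cast)
    is_try_bare = bool(re.match(pattern, stripped))
    print(f"{line!r} -> {'dropped' if is_bare or is_try_bare else 'kept'}")
# Both lines are dropped, since a bare variable is not a valid Java statement.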
@@ -834,7 +846,7 @@ def instrument_generated_java_test(
original_class_name = class_match.group(1)
# For performance mode, add timing instrumentation
# Add mode-specific instrumentation
# Use original class name (without suffix) in timing markers for consistency with Python
if mode == "performance":

View file

@@ -33,10 +33,11 @@ class ParsedOptimization:
target_method_source: str
new_fields: list[str] # Source text of new fields to add
new_helper_methods: list[str] # Source text of new helper methods to add
new_imports: list[str] # Import statements to add (e.g., "import java.nio.file.Files;")
def _parse_optimization_source(new_source: str, target_method_name: str, analyzer: JavaAnalyzer) -> ParsedOptimization:
"""Parse optimization source to extract method and additional class members.
"""Parse optimization source to extract method, imports, and additional class members.
The new_source may contain:
- Just a method definition
@@ -48,13 +49,20 @@ def _parse_optimization_source(new_source: str, target_method_name: str, analyze
analyzer: JavaAnalyzer instance.
Returns:
ParsedOptimization with the method and any additional members.
ParsedOptimization with the method, imports, and any additional members.
"""
new_fields: list[str] = []
new_helper_methods: list[str] = []
target_method_source = new_source # Default to the whole source
# Extract import statements from the candidate code
new_imports: list[str] = []
for imp in analyzer.find_imports(new_source):
prefix = "import static " if imp.is_static else "import "
suffix = ".*" if imp.is_wildcard else ""
new_imports.append(f"{prefix}{imp.import_path}{suffix};")
# Check if this is a full class or just a method
classes = analyzer.find_classes(new_source)
@@ -92,10 +100,57 @@ def _parse_optimization_source(new_source: str, target_method_name: str, analyze
new_fields.append(field.source_text)
return ParsedOptimization(
target_method_source=target_method_source, new_fields=new_fields, new_helper_methods=new_helper_methods
target_method_source=target_method_source,
new_fields=new_fields,
new_helper_methods=new_helper_methods,
new_imports=new_imports,
)
def _add_missing_imports(source: str, candidate_imports: list[str], analyzer: JavaAnalyzer) -> str:
"""Add import statements from the optimization candidate that are missing in the original source.
Args:
source: The original source code.
candidate_imports: Import statements from the candidate (e.g., ["import java.nio.file.Files;"]).
analyzer: JavaAnalyzer instance.
Returns:
Source code with missing imports added.
"""
existing_imports = analyzer.find_imports(source)
existing_import_strs = set()
for imp in existing_imports:
prefix = "import static " if imp.is_static else "import "
suffix = ".*" if imp.is_wildcard else ""
existing_import_strs.add(f"{prefix}{imp.import_path}{suffix};")
missing_imports = [imp for imp in candidate_imports if imp not in existing_import_strs]
if not missing_imports:
return source
logger.debug("Adding %d missing imports: %s", len(missing_imports), missing_imports)
# Insert after the last existing import, or after the package declaration
lines = source.splitlines(keepends=True)
insert_line = 0
if existing_imports:
insert_line = max(imp.end_line for imp in existing_imports)
else:
# No existing imports — insert after package declaration
for i, line in enumerate(lines):
if line.strip().startswith("package "):
insert_line = i + 1
break
import_block = "".join(imp + "\n" for imp in missing_imports)
before = lines[:insert_line]
after = lines[insert_line:]
return "".join(before) + import_block + "".join(after)
def _insert_class_members(
source: str, class_name: str, fields: list[str], methods: list[str], analyzer: JavaAnalyzer
) -> str:
@@ -237,6 +292,10 @@ def replace_function(
# Parse the optimization to extract components
parsed = _parse_optimization_source(new_source, func_name, analyzer)
# Add any new imports from the optimization candidate
if parsed.new_imports:
source = _add_missing_imports(source, parsed.new_imports, analyzer)
# Find the method in the original source
methods = analyzer.find_methods(source)
target_method = None
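
A simplified end-to-end illustration of the import merge (this toy version scans lines directly, whereas the real _add_missing_imports resolves imports through JavaAnalyzer and its 1-based end_line):

original = """package com.example;

import java.util.List;

public class NullChecker {}
"""
candidate_imports = ["import java.util.List;", "import java.util.Objects;"]

lines = original.splitlines(keepends=True)
existing = {ln.strip() for ln in lines if ln.strip().startswith("import ")}
missing = [imp for imp in candidate_imports if imp not in existing]
# Insert after the last existing import line, mirroring the logic above
insert_at = max(i + 1 for i, ln in enumerate(lines) if ln.strip().startswith("import "))
print("".join(lines[:insert_at]) + "".join(imp + "\n" for imp in missing) + "".join(lines[insert_at:]))
# java.util.Objects is added once; the duplicate java.util.List is skipped.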

View file

@@ -48,6 +48,7 @@ from codeflash.code_utils.config_consts import (
COVERAGE_THRESHOLD,
INDIVIDUAL_TESTCASE_TIMEOUT,
MIN_CORRECT_CANDIDATES,
MIN_IMPROVEMENT_THRESHOLD_JAVA,
OPTIMIZATION_CONTEXT_TOKEN_LIMIT,
REFINED_CANDIDATE_RANKING_WEIGHTS,
REPEAT_OPTIMIZATION_PROBABILITY,
@@ -1364,6 +1365,8 @@ class FunctionOptimizer:
eval_ctx.record_successful_candidate(candidate.optimization_id, candidate_result.best_test_runtime, perf_gain)
# Check if this is a successful optimization
# Use a lower threshold for Java where I/O-bound functions have smaller optimization margins
java_override = MIN_IMPROVEMENT_THRESHOLD_JAVA if self.language_support.language == "java" else None
is_successful_opt = speedup_critic(
candidate_result,
original_code_baseline.runtime,
@@ -1372,6 +1375,7 @@ class FunctionOptimizer:
best_throughput_until_now=None,
original_concurrency_metrics=original_code_baseline.concurrency_metrics,
best_concurrency_ratio_until_now=None,
min_improvement_override=java_override,
) and quantity_of_tests_critic(candidate_result)
tree = self.build_runtime_info_tree(
@@ -2272,6 +2276,7 @@ class FunctionOptimizer:
fto_benchmark_timings=self.function_benchmark_timings,
total_benchmark_timings=self.total_benchmark_timings,
)
java_override = MIN_IMPROVEMENT_THRESHOLD_JAVA if self.language_support.language == "java" else None
acceptance_reason = get_acceptance_reason(
original_runtime_ns=original_code_baseline.runtime,
optimized_runtime_ns=best_optimization.runtime,
@@ -2279,6 +2284,7 @@ class FunctionOptimizer:
optimized_async_throughput=best_optimization.async_throughput,
original_concurrency_metrics=original_code_baseline.concurrency_metrics,
optimized_concurrency_metrics=best_optimization.concurrency_metrics,
min_improvement_override=java_override,
)
explanation = Explanation(
raw_explanation_message=best_optimization.candidate.explanation,

View file

@@ -117,6 +117,38 @@ class Optimizer:
except Exception as e:
logger.debug(f"Failed to verify JS requirements: {e}")
def _setup_java_runtime(self) -> None:
from pathlib import Path as _Path
from codeflash.languages.java.build_tools import add_codeflash_dependency_to_pom, install_codeflash_runtime
runtime_jar = _Path(__file__).parent.parent / "languages" / "java" / "resources" / "codeflash-runtime-1.0.0.jar"
project_root = self.args.project_root
if not runtime_jar.exists():
logger.warning("codeflash-runtime JAR not found at %s, behavior capture may not work", runtime_jar)
return
if not install_codeflash_runtime(project_root, runtime_jar):
logger.warning("Failed to install codeflash-runtime to local Maven repo")
return
# Add dependency to the test module's pom.xml (or root pom.xml)
test_root = _Path(self.args.tests_root) if hasattr(self.args, "tests_root") and self.args.tests_root else None
pom_path = project_root / "pom.xml"
# For multi-module projects, find the test module's pom.xml
if test_root and test_root != project_root:
candidate = test_root
while candidate != project_root and candidate != candidate.parent:
if (candidate / "pom.xml").exists():
pom_path = candidate / "pom.xml"
break
candidate = candidate.parent
if pom_path.exists():
add_codeflash_dependency_to_pom(pom_path)
def run_benchmarks(
self, file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]], num_optimizable_functions: int
) -> tuple[dict[str, dict[BenchmarkKey, float]], dict[BenchmarkKey, float]]:
@@ -495,6 +527,8 @@ class Optimizer:
self.test_cfg.js_project_root = self._find_js_project_root(file_path)
# Verify JS requirements before proceeding
self._verify_js_requirements()
elif is_java():
self._setup_java_runtime()
break
if self.args.all:
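
The pom discovery above can be exercised in isolation; a sketch against a hypothetical multi-module layout:

# Hypothetical layout: a root pom.xml plus a nested "service" module with its
# own pom.xml. The walk-up stops at the first pom.xml between tests_root and root.
import tempfile
from pathlib import Path

root = Path(tempfile.mkdtemp())
(root / "pom.xml").write_text("<project/>")
test_root = root / "service" / "src" / "test" / "java"
test_root.mkdir(parents=True)
(root / "service" / "pom.xml").write_text("<project/>")

pom_path = root / "pom.xml"
candidate = test_root
while candidate != root and candidate != candidate.parent:
    if (candidate / "pom.xml").exists():
        pom_path = candidate / "pom.xml"
        break
    candidate = candidate.parent
print(pom_path)  # .../service/pom.xml, i.e. the test module's pom, not the root's

Targeting the module that owns the tests presumably keeps the codeflash-runtime dependency on the test classpath in multi-module builds.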

View file

@@ -72,6 +72,7 @@ def speedup_critic(
best_throughput_until_now: int | None = None,
original_concurrency_metrics: ConcurrencyMetrics | None = None,
best_concurrency_ratio_until_now: float | None = None,
min_improvement_override: float | None = None,
) -> bool:
"""Take in a correct optimized Test Result and decide if the optimization should actually be surfaced to the user.
@@ -92,7 +93,8 @@ def speedup_critic(
- Concurrency improvements detect when blocking calls are replaced with non-blocking equivalents
"""
# Runtime performance evaluation
noise_floor = 3 * MIN_IMPROVEMENT_THRESHOLD if original_code_runtime < 10000 else MIN_IMPROVEMENT_THRESHOLD
threshold = min_improvement_override if min_improvement_override is not None else MIN_IMPROVEMENT_THRESHOLD
noise_floor = 3 * threshold if original_code_runtime < 10000 else threshold
if not disable_gh_action_noise and env_utils.is_ci():
noise_floor = noise_floor * 2 # Increase the noise floor in GitHub Actions mode
@@ -146,13 +148,15 @@ def get_acceptance_reason(
optimized_async_throughput: int | None = None,
original_concurrency_metrics: ConcurrencyMetrics | None = None,
optimized_concurrency_metrics: ConcurrencyMetrics | None = None,
min_improvement_override: float | None = None,
) -> AcceptanceReason:
"""Determine why an optimization was accepted.
Returns the primary reason for acceptance, with priority:
concurrency > throughput > runtime (for async code).
"""
noise_floor = 3 * MIN_IMPROVEMENT_THRESHOLD if original_runtime_ns < 10000 else MIN_IMPROVEMENT_THRESHOLD
threshold = min_improvement_override if min_improvement_override is not None else MIN_IMPROVEMENT_THRESHOLD
noise_floor = 3 * threshold if original_runtime_ns < 10000 else threshold
if env_utils.is_ci():
noise_floor = noise_floor * 2
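
Worked numbers for the threshold plumbing (runtimes hypothetical; the constants and scaling rules come from the diff above):

MIN_IMPROVEMENT_THRESHOLD = 0.05

def noise_floor(original_runtime_ns: int, override: float | None, in_ci: bool) -> float:
    threshold = override if override is not None else MIN_IMPROVEMENT_THRESHOLD
    floor = 3 * threshold if original_runtime_ns < 10000 else threshold
    return floor * 2 if in_ci else floor

print(noise_floor(8_000, 0.02, in_ci=True))    # ~0.12: a tiny runtime triples the Java floor, CI doubles it
print(noise_floor(50_000, None, in_ci=False))  # 0.05: the unchanged default path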

View file

@@ -122,6 +122,21 @@ def generate_tests(
)
logger.debug(f"Instrumented Java tests locally for {func_name}")
logger.debug(
f"=== Java Generated Tests (raw) for {func_name} ===\n"
f"{generated_test_source}\n"
f"=== End Java Generated Tests ==="
)
logger.debug(
f"=== Java Instrumented Behavior Tests for {func_name} ===\n"
f"{instrumented_behavior_test_source}\n"
f"=== End Java Instrumented Behavior Tests ==="
)
logger.debug(
f"=== Java Instrumented Perf Tests for {func_name} ===\n"
f"{instrumented_perf_test_source}\n"
f"=== End Java Instrumented Perf Tests ==="
)
else:
# Python: instrumentation is done by aiservice, just replace temp dir placeholders
instrumented_behavior_test_source = instrumented_behavior_test_source.replace(

View file

@@ -915,6 +915,7 @@ class TestInstrumentGeneratedJavaTest:
1. Remove assertions containing the target function call
2. Capture the function return value instead
3. Rename the class with __perfinstrumented suffix
4. Add SQLite behavior instrumentation to capture return values
"""
test_code = """import org.junit.jupiter.api.Test;
@@ -932,17 +933,16 @@ public class CalculatorTest {
mode="behavior",
)
# Behavior mode transforms assertions to capture return values
expected = """import org.junit.jupiter.api.Test;
public class CalculatorTest__perfinstrumented {
@Test
public void testAdd() {
Object _cf_result1 = new Calculator().add(2, 2);
}
}
"""
assert result == expected
# Behavior mode transforms assertions, renames class, and adds SQLite instrumentation
assert "class CalculatorTest__perfinstrumented" in result
assert "import java.sql.Connection;" in result
assert "import java.sql.DriverManager;" in result
assert "import java.sql.PreparedStatement;" in result
assert "CODEFLASH_OUTPUT_FILE" in result
assert "CREATE TABLE IF NOT EXISTS test_results" in result
assert "INSERT INTO test_results VALUES" in result
assert "_cf_serializedResult1" in result
assert "com.codeflash.Serializer.serialize" in result
def test_instrument_generated_test_performance_mode(self):
"""Test instrumenting generated test in performance mode with inner loop."""

View file

@@ -647,7 +647,8 @@ public class NullChecker {{
assert result is True
new_code = java_file.read_text(encoding="utf-8")
expected = """public class NullChecker {
expected = """import java.util.Objects;
public class NullChecker {
public boolean isEqual(String s1, String s2) {
return Objects.equals(s1, s2);
}