mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
fix omni-java
This commit is contained in:
parent
ae4eb7c91b
commit
b4e233a2cf
9 changed files with 154 additions and 22 deletions
|
|
@ -10,6 +10,7 @@ INDIVIDUAL_TESTCASE_TIMEOUT = 15 # For Python pytest
|
|||
JAVA_TESTCASE_TIMEOUT = 120 # Java Maven tests need more time due to startup overhead
|
||||
MAX_FUNCTION_TEST_SECONDS = 60
|
||||
MIN_IMPROVEMENT_THRESHOLD = 0.05
|
||||
MIN_IMPROVEMENT_THRESHOLD_JAVA = 0.02
|
||||
MIN_THROUGHPUT_IMPROVEMENT_THRESHOLD = 0.10 # 10% minimum improvement for async throughput
|
||||
MIN_CONCURRENCY_IMPROVEMENT_THRESHOLD = 0.20 # 20% concurrency ratio improvement required
|
||||
CONCURRENCY_FACTOR = 10 # Number of concurrent executions for concurrency benchmark
|
||||
|
|
|
|||
|
|
@ -508,11 +508,23 @@ def _add_behavior_instrumentation(source: str, class_name: str, func_name: str)
|
|||
)
|
||||
wrapped_body_lines.append(serialize_line)
|
||||
|
||||
# Check if the line is now just a variable reference (invalid statement)
|
||||
# This happens when the original line was just a void method call
|
||||
# e.g., "BubbleSort.bubbleSort(original);" becomes "_cf_result1_1;"
|
||||
# Check if the line is now just a variable reference (invalid statement).
|
||||
# This happens when the original line was just a void method call:
|
||||
# "BubbleSort.bubbleSort(original);" becomes "_cf_result1_1;"
|
||||
# It also happens when assertThrows was transformed to try-catch:
|
||||
# "try { func(args); } catch (...) {}" becomes
|
||||
# "try { _cf_result1_1; } catch (...) {}"
|
||||
# A bare variable is not a valid Java statement.
|
||||
stripped_new = new_line.strip().rstrip(";").strip()
|
||||
if stripped_new and stripped_new not in (var_name, var_with_cast):
|
||||
is_bare_var = stripped_new in (var_name, var_with_cast)
|
||||
is_try_with_bare_var = bool(re.match(
|
||||
r"try\s*\{\s*(?:"
|
||||
+ re.escape(var_name)
|
||||
+ (r"|" + re.escape(var_with_cast) if var_with_cast != var_name else "")
|
||||
+ r")\s*;\s*\}\s*catch\s*\(",
|
||||
stripped_new,
|
||||
))
|
||||
if stripped_new and not is_bare_var and not is_try_with_bare_var:
|
||||
wrapped_body_lines.append(new_line)
|
||||
else:
|
||||
wrapped_body_lines.append(body_line)
|
||||
|
|
@ -834,7 +846,7 @@ def instrument_generated_java_test(
|
|||
original_class_name = class_match.group(1)
|
||||
|
||||
|
||||
# For performance mode, add timing instrumentation
|
||||
# Add mode-specific instrumentation
|
||||
# Use original class name (without suffix) in timing markers for consistency with Python
|
||||
if mode == "performance":
|
||||
|
||||
|
|
|
|||
|
|
@ -33,10 +33,11 @@ class ParsedOptimization:
|
|||
target_method_source: str
|
||||
new_fields: list[str] # Source text of new fields to add
|
||||
new_helper_methods: list[str] # Source text of new helper methods to add
|
||||
new_imports: list[str] # Import statements to add (e.g., "import java.nio.file.Files;")
|
||||
|
||||
|
||||
def _parse_optimization_source(new_source: str, target_method_name: str, analyzer: JavaAnalyzer) -> ParsedOptimization:
|
||||
"""Parse optimization source to extract method and additional class members.
|
||||
"""Parse optimization source to extract method, imports, and additional class members.
|
||||
|
||||
The new_source may contain:
|
||||
- Just a method definition
|
||||
|
|
@ -48,13 +49,20 @@ def _parse_optimization_source(new_source: str, target_method_name: str, analyze
|
|||
analyzer: JavaAnalyzer instance.
|
||||
|
||||
Returns:
|
||||
ParsedOptimization with the method and any additional members.
|
||||
ParsedOptimization with the method, imports, and any additional members.
|
||||
|
||||
"""
|
||||
new_fields: list[str] = []
|
||||
new_helper_methods: list[str] = []
|
||||
target_method_source = new_source # Default to the whole source
|
||||
|
||||
# Extract import statements from the candidate code
|
||||
new_imports: list[str] = []
|
||||
for imp in analyzer.find_imports(new_source):
|
||||
prefix = "import static " if imp.is_static else "import "
|
||||
suffix = ".*" if imp.is_wildcard else ""
|
||||
new_imports.append(f"{prefix}{imp.import_path}{suffix};")
|
||||
|
||||
# Check if this is a full class or just a method
|
||||
classes = analyzer.find_classes(new_source)
|
||||
|
||||
|
|
@ -92,10 +100,57 @@ def _parse_optimization_source(new_source: str, target_method_name: str, analyze
|
|||
new_fields.append(field.source_text)
|
||||
|
||||
return ParsedOptimization(
|
||||
target_method_source=target_method_source, new_fields=new_fields, new_helper_methods=new_helper_methods
|
||||
target_method_source=target_method_source,
|
||||
new_fields=new_fields,
|
||||
new_helper_methods=new_helper_methods,
|
||||
new_imports=new_imports,
|
||||
)
|
||||
|
||||
|
||||
def _add_missing_imports(source: str, candidate_imports: list[str], analyzer: JavaAnalyzer) -> str:
|
||||
"""Add import statements from the optimization candidate that are missing in the original source.
|
||||
|
||||
Args:
|
||||
source: The original source code.
|
||||
candidate_imports: Import statements from the candidate (e.g., ["import java.nio.file.Files;"]).
|
||||
analyzer: JavaAnalyzer instance.
|
||||
|
||||
Returns:
|
||||
Source code with missing imports added.
|
||||
|
||||
"""
|
||||
existing_imports = analyzer.find_imports(source)
|
||||
existing_import_strs = set()
|
||||
for imp in existing_imports:
|
||||
prefix = "import static " if imp.is_static else "import "
|
||||
suffix = ".*" if imp.is_wildcard else ""
|
||||
existing_import_strs.add(f"{prefix}{imp.import_path}{suffix};")
|
||||
|
||||
missing_imports = [imp for imp in candidate_imports if imp not in existing_import_strs]
|
||||
if not missing_imports:
|
||||
return source
|
||||
|
||||
logger.debug("Adding %d missing imports: %s", len(missing_imports), missing_imports)
|
||||
|
||||
# Insert after the last existing import, or after the package declaration
|
||||
lines = source.splitlines(keepends=True)
|
||||
insert_line = 0
|
||||
|
||||
if existing_imports:
|
||||
insert_line = max(imp.end_line for imp in existing_imports)
|
||||
else:
|
||||
# No existing imports — insert after package declaration
|
||||
for i, line in enumerate(lines):
|
||||
if line.strip().startswith("package "):
|
||||
insert_line = i + 1
|
||||
break
|
||||
|
||||
import_block = "".join(imp + "\n" for imp in missing_imports)
|
||||
before = lines[:insert_line]
|
||||
after = lines[insert_line:]
|
||||
return "".join(before) + import_block + "".join(after)
|
||||
|
||||
|
||||
def _insert_class_members(
|
||||
source: str, class_name: str, fields: list[str], methods: list[str], analyzer: JavaAnalyzer
|
||||
) -> str:
|
||||
|
|
@ -237,6 +292,10 @@ def replace_function(
|
|||
# Parse the optimization to extract components
|
||||
parsed = _parse_optimization_source(new_source, func_name, analyzer)
|
||||
|
||||
# Add any new imports from the optimization candidate
|
||||
if parsed.new_imports:
|
||||
source = _add_missing_imports(source, parsed.new_imports, analyzer)
|
||||
|
||||
# Find the method in the original source
|
||||
methods = analyzer.find_methods(source)
|
||||
target_method = None
|
||||
|
|
|
|||
|
|
@ -48,6 +48,7 @@ from codeflash.code_utils.config_consts import (
|
|||
COVERAGE_THRESHOLD,
|
||||
INDIVIDUAL_TESTCASE_TIMEOUT,
|
||||
MIN_CORRECT_CANDIDATES,
|
||||
MIN_IMPROVEMENT_THRESHOLD_JAVA,
|
||||
OPTIMIZATION_CONTEXT_TOKEN_LIMIT,
|
||||
REFINED_CANDIDATE_RANKING_WEIGHTS,
|
||||
REPEAT_OPTIMIZATION_PROBABILITY,
|
||||
|
|
@ -1364,6 +1365,8 @@ class FunctionOptimizer:
|
|||
eval_ctx.record_successful_candidate(candidate.optimization_id, candidate_result.best_test_runtime, perf_gain)
|
||||
|
||||
# Check if this is a successful optimization
|
||||
# Use a lower threshold for Java where I/O-bound functions have smaller optimization margins
|
||||
java_override = MIN_IMPROVEMENT_THRESHOLD_JAVA if self.language_support.language == "java" else None
|
||||
is_successful_opt = speedup_critic(
|
||||
candidate_result,
|
||||
original_code_baseline.runtime,
|
||||
|
|
@ -1372,6 +1375,7 @@ class FunctionOptimizer:
|
|||
best_throughput_until_now=None,
|
||||
original_concurrency_metrics=original_code_baseline.concurrency_metrics,
|
||||
best_concurrency_ratio_until_now=None,
|
||||
min_improvement_override=java_override,
|
||||
) and quantity_of_tests_critic(candidate_result)
|
||||
|
||||
tree = self.build_runtime_info_tree(
|
||||
|
|
@ -2272,6 +2276,7 @@ class FunctionOptimizer:
|
|||
fto_benchmark_timings=self.function_benchmark_timings,
|
||||
total_benchmark_timings=self.total_benchmark_timings,
|
||||
)
|
||||
java_override = MIN_IMPROVEMENT_THRESHOLD_JAVA if self.language_support.language == "java" else None
|
||||
acceptance_reason = get_acceptance_reason(
|
||||
original_runtime_ns=original_code_baseline.runtime,
|
||||
optimized_runtime_ns=best_optimization.runtime,
|
||||
|
|
@ -2279,6 +2284,7 @@ class FunctionOptimizer:
|
|||
optimized_async_throughput=best_optimization.async_throughput,
|
||||
original_concurrency_metrics=original_code_baseline.concurrency_metrics,
|
||||
optimized_concurrency_metrics=best_optimization.concurrency_metrics,
|
||||
min_improvement_override=java_override,
|
||||
)
|
||||
explanation = Explanation(
|
||||
raw_explanation_message=best_optimization.candidate.explanation,
|
||||
|
|
|
|||
|
|
@ -117,6 +117,38 @@ class Optimizer:
|
|||
except Exception as e:
|
||||
logger.debug(f"Failed to verify JS requirements: {e}")
|
||||
|
||||
def _setup_java_runtime(self) -> None:
|
||||
from pathlib import Path as _Path
|
||||
|
||||
from codeflash.languages.java.build_tools import add_codeflash_dependency_to_pom, install_codeflash_runtime
|
||||
|
||||
runtime_jar = _Path(__file__).parent.parent / "languages" / "java" / "resources" / "codeflash-runtime-1.0.0.jar"
|
||||
project_root = self.args.project_root
|
||||
|
||||
if not runtime_jar.exists():
|
||||
logger.warning("codeflash-runtime JAR not found at %s, behavior capture may not work", runtime_jar)
|
||||
return
|
||||
|
||||
if not install_codeflash_runtime(project_root, runtime_jar):
|
||||
logger.warning("Failed to install codeflash-runtime to local Maven repo")
|
||||
return
|
||||
|
||||
# Add dependency to the test module's pom.xml (or root pom.xml)
|
||||
test_root = _Path(self.args.tests_root) if hasattr(self.args, "tests_root") and self.args.tests_root else None
|
||||
pom_path = project_root / "pom.xml"
|
||||
|
||||
# For multi-module projects, find the test module's pom.xml
|
||||
if test_root and test_root != project_root:
|
||||
candidate = test_root
|
||||
while candidate != project_root and candidate != candidate.parent:
|
||||
if (candidate / "pom.xml").exists():
|
||||
pom_path = candidate / "pom.xml"
|
||||
break
|
||||
candidate = candidate.parent
|
||||
|
||||
if pom_path.exists():
|
||||
add_codeflash_dependency_to_pom(pom_path)
|
||||
|
||||
def run_benchmarks(
|
||||
self, file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]], num_optimizable_functions: int
|
||||
) -> tuple[dict[str, dict[BenchmarkKey, float]], dict[BenchmarkKey, float]]:
|
||||
|
|
@ -495,6 +527,8 @@ class Optimizer:
|
|||
self.test_cfg.js_project_root = self._find_js_project_root(file_path)
|
||||
# Verify JS requirements before proceeding
|
||||
self._verify_js_requirements()
|
||||
elif is_java():
|
||||
self._setup_java_runtime()
|
||||
break
|
||||
|
||||
if self.args.all:
|
||||
|
|
|
|||
|
|
@ -72,6 +72,7 @@ def speedup_critic(
|
|||
best_throughput_until_now: int | None = None,
|
||||
original_concurrency_metrics: ConcurrencyMetrics | None = None,
|
||||
best_concurrency_ratio_until_now: float | None = None,
|
||||
min_improvement_override: float | None = None,
|
||||
) -> bool:
|
||||
"""Take in a correct optimized Test Result and decide if the optimization should actually be surfaced to the user.
|
||||
|
||||
|
|
@ -92,7 +93,8 @@ def speedup_critic(
|
|||
- Concurrency improvements detect when blocking calls are replaced with non-blocking equivalents
|
||||
"""
|
||||
# Runtime performance evaluation
|
||||
noise_floor = 3 * MIN_IMPROVEMENT_THRESHOLD if original_code_runtime < 10000 else MIN_IMPROVEMENT_THRESHOLD
|
||||
threshold = min_improvement_override if min_improvement_override is not None else MIN_IMPROVEMENT_THRESHOLD
|
||||
noise_floor = 3 * threshold if original_code_runtime < 10000 else threshold
|
||||
if not disable_gh_action_noise and env_utils.is_ci():
|
||||
noise_floor = noise_floor * 2 # Increase the noise floor in GitHub Actions mode
|
||||
|
||||
|
|
@ -146,13 +148,15 @@ def get_acceptance_reason(
|
|||
optimized_async_throughput: int | None = None,
|
||||
original_concurrency_metrics: ConcurrencyMetrics | None = None,
|
||||
optimized_concurrency_metrics: ConcurrencyMetrics | None = None,
|
||||
min_improvement_override: float | None = None,
|
||||
) -> AcceptanceReason:
|
||||
"""Determine why an optimization was accepted.
|
||||
|
||||
Returns the primary reason for acceptance, with priority:
|
||||
concurrency > throughput > runtime (for async code).
|
||||
"""
|
||||
noise_floor = 3 * MIN_IMPROVEMENT_THRESHOLD if original_runtime_ns < 10000 else MIN_IMPROVEMENT_THRESHOLD
|
||||
threshold = min_improvement_override if min_improvement_override is not None else MIN_IMPROVEMENT_THRESHOLD
|
||||
noise_floor = 3 * threshold if original_runtime_ns < 10000 else threshold
|
||||
if env_utils.is_ci():
|
||||
noise_floor = noise_floor * 2
|
||||
|
||||
|
|
|
|||
|
|
@ -122,6 +122,21 @@ def generate_tests(
|
|||
)
|
||||
|
||||
logger.debug(f"Instrumented Java tests locally for {func_name}")
|
||||
logger.debug(
|
||||
f"=== Java Generated Tests (raw) for {func_name} ===\n"
|
||||
f"{generated_test_source}\n"
|
||||
f"=== End Java Generated Tests ==="
|
||||
)
|
||||
logger.debug(
|
||||
f"=== Java Instrumented Behavior Tests for {func_name} ===\n"
|
||||
f"{instrumented_behavior_test_source}\n"
|
||||
f"=== End Java Instrumented Behavior Tests ==="
|
||||
)
|
||||
logger.debug(
|
||||
f"=== Java Instrumented Perf Tests for {func_name} ===\n"
|
||||
f"{instrumented_perf_test_source}\n"
|
||||
f"=== End Java Instrumented Perf Tests ==="
|
||||
)
|
||||
else:
|
||||
# Python: instrumentation is done by aiservice, just replace temp dir placeholders
|
||||
instrumented_behavior_test_source = instrumented_behavior_test_source.replace(
|
||||
|
|
|
|||
|
|
@ -915,6 +915,7 @@ class TestInstrumentGeneratedJavaTest:
|
|||
1. Remove assertions containing the target function call
|
||||
2. Capture the function return value instead
|
||||
3. Rename the class with __perfinstrumented suffix
|
||||
4. Add SQLite behavior instrumentation to capture return values
|
||||
"""
|
||||
test_code = """import org.junit.jupiter.api.Test;
|
||||
|
||||
|
|
@ -932,17 +933,16 @@ public class CalculatorTest {
|
|||
mode="behavior",
|
||||
)
|
||||
|
||||
# Behavior mode transforms assertions to capture return values
|
||||
expected = """import org.junit.jupiter.api.Test;
|
||||
|
||||
public class CalculatorTest__perfinstrumented {
|
||||
@Test
|
||||
public void testAdd() {
|
||||
Object _cf_result1 = new Calculator().add(2, 2);
|
||||
}
|
||||
}
|
||||
"""
|
||||
assert result == expected
|
||||
# Behavior mode transforms assertions, renames class, and adds SQLite instrumentation
|
||||
assert "class CalculatorTest__perfinstrumented" in result
|
||||
assert "import java.sql.Connection;" in result
|
||||
assert "import java.sql.DriverManager;" in result
|
||||
assert "import java.sql.PreparedStatement;" in result
|
||||
assert "CODEFLASH_OUTPUT_FILE" in result
|
||||
assert "CREATE TABLE IF NOT EXISTS test_results" in result
|
||||
assert "INSERT INTO test_results VALUES" in result
|
||||
assert "_cf_serializedResult1" in result
|
||||
assert "com.codeflash.Serializer.serialize" in result
|
||||
|
||||
def test_instrument_generated_test_performance_mode(self):
|
||||
"""Test instrumenting generated test in performance mode with inner loop."""
|
||||
|
|
|
|||
|
|
@ -647,7 +647,8 @@ public class NullChecker {{
|
|||
|
||||
assert result is True
|
||||
new_code = java_file.read_text(encoding="utf-8")
|
||||
expected = """public class NullChecker {
|
||||
expected = """import java.util.Objects;
|
||||
public class NullChecker {
|
||||
public boolean isEqual(String s1, String s2) {
|
||||
return Objects.equals(s1, s2);
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue