Merge remote-tracking branch 'origin/omni-java' into merge/misc-fixes-into-omni-java

# Conflicts:
#	codeflash/api/aiservice.py
#	codeflash/languages/base.py
#	codeflash/languages/java/support.py
#	codeflash/languages/javascript/support.py
#	codeflash/languages/python/support.py
#	codeflash/verification/verifier.py
This commit is contained in:
Kevin Turcios 2026-03-04 01:23:39 -05:00
commit 2fb0145895
11 changed files with 141 additions and 65 deletions

View file

@ -64,7 +64,7 @@ Core protocol in `languages/base.py`. Each language (`PythonSupport`, `JavaScrip
|----------|----------------|---------|
| Identity | `language`, `file_extensions`, `default_file_extension` | Language identification |
| Identity | `comment_prefix`, `dir_excludes` | Language conventions |
| AI service | `default_language_version` | Language version for API payloads (`None` for Python, `"ES2022"` for JS) |
| AI service | `language_version` | Detected language version for API payloads (e.g., `"3.11.0"` for Python, `"17"` for Java) |
| AI service | `valid_test_frameworks` | Allowed test frameworks for validation |
| Discovery | `discover_functions`, `discover_tests` | Find optimizable functions and their tests |
| Discovery | `adjust_test_config_for_discovery` | Pre-discovery config adjustment (no-op default) |

View file

@ -57,10 +57,12 @@ class AiServiceClient:
payload: dict[str, Any], language_version: str | None = None, module_system: str | None = None
) -> None:
"""Add language version and module system metadata to an API payload."""
payload["python_version"] = platform.python_version()
default_lang_version = current_language_support().default_language_version
if default_lang_version is not None:
payload["language_version"] = language_version or default_lang_version
# Canonical for all languages
payload["language_version"] = language_version
# Backward compat: Python backend still expects python_version
payload["python_version"] = language_version if current_language() == Language.PYTHON else None
if current_language() != Language.PYTHON:
if module_system:
payload["module_system"] = module_system
@ -142,8 +144,7 @@ class AiServiceClient:
experiment_metadata: ExperimentMetadata | None = None,
*,
language: str = "python",
language_version: str
| None = None, # TODO:{claude} add language version to the language support and it should be cached
language_version: str | None = None,
module_system: str | None = None,
is_async: bool = False,
n_candidates: int = 5,
@ -264,7 +265,7 @@ class AiServiceClient:
"source_code": source_code,
"trace_id": trace_id,
"dependency_code": "", # dummy value to please the api endpoint
"python_version": "3.12.1", # dummy value to please the api endpoint
"python_version": platform.python_version(), # backward compat
"current_username": get_last_commit_author_if_pr_exists(None),
"repo_owner": git_repo_owner,
"repo_name": git_repo_name,
@ -331,18 +332,15 @@ class AiServiceClient:
logger.info("Generating optimized candidates with line profiler…")
console.rule()
# Set python_version for backward compatibility with Python, or use language_version
python_version = language_version if language_version else platform.python_version()
payload = {
"source_code": source_code,
"dependency_code": dependency_code,
"n_candidates": n_candidates,
"line_profiler_results": line_profiler_results,
"trace_id": trace_id,
"python_version": python_version,
"language": language,
"language_version": language_version,
"python_version": language_version if current_language() == Language.PYTHON else None,
"experiment_metadata": experiment_metadata,
"codeflash_version": codeflash_version,
"call_sequence": self.get_next_sequence(),
@ -644,7 +642,7 @@ class AiServiceClient:
"diffs": diffs,
"speedups": speedups,
"optimization_ids": optimization_ids,
"python_version": platform.python_version(),
"python_version": platform.python_version(), # backward compat
"function_references": function_references,
}
logger.info("loading|Generating ranking")
@ -770,6 +768,8 @@ class AiServiceClient:
"is_async": function_to_optimize.is_async,
"call_sequence": self.get_next_sequence(),
"is_numerical_code": is_numerical_code,
"class_name": function_to_optimize.class_name,
"qualified_name": function_to_optimize.qualified_name,
}
self.add_language_metadata(payload, language_version, module_system)
@ -858,6 +858,7 @@ class AiServiceClient:
"codeflash_version": codeflash_version,
"calling_fn_details": calling_fn_details,
"language": language,
"language_version": platform.python_version() if current_language() == Language.PYTHON else None,
"python_version": platform.python_version() if current_language() == Language.PYTHON else None,
"call_sequence": self.get_next_sequence(),
}

View file

@ -137,14 +137,15 @@ class OptimizeRequest:
"is_numerical_code": self.is_numerical_code,
}
# Add language-specific fields
if self.language_info.version:
payload["language_version"] = self.language_info.version
# Add language version (canonical for all languages)
payload["language_version"] = self.language_info.version
# Backward compat: always include python_version
# Backward compat: backend still expects python_version
import platform
payload["python_version"] = platform.python_version()
payload["python_version"] = (
self.language_info.version if self.language_info.name == "python" else platform.python_version()
)
# Module system for JS/TS
if self.language_info.module_system != ModuleSystem.UNKNOWN:
@ -205,14 +206,15 @@ class TestGenRequest:
"is_numerical_code": self.is_numerical_code,
}
# Add language version
if self.language_info.version:
payload["language_version"] = self.language_info.version
# Add language version (canonical for all languages)
payload["language_version"] = self.language_info.version
# Backward compat: always include python_version
# Backward compat: backend still expects python_version
import platform
payload["python_version"] = platform.python_version()
payload["python_version"] = (
self.language_info.version if self.language_info.name == "python" else platform.python_version()
)
# Module system for JS/TS
if self.language_info.module_system != ModuleSystem.UNKNOWN:

View file

@ -326,12 +326,8 @@ class LanguageSupport(Protocol):
...
@property
def default_language_version(self) -> str | None:
"""Default language version string sent to AI service.
Returns None for languages where the runtime version is auto-detected (e.g. Python).
Returns a version string (e.g. "ES2022") for languages that need an explicit default.
"""
def language_version(self) -> str | None:
"""The detected language version (e.g., "17" for Java, "ES2022" for JS)."""
...
@property
@ -900,6 +896,31 @@ class LanguageSupport(Protocol):
"""
return {}, ""
def run_line_profile_tests(
    self,
    test_paths: Any,
    test_env: dict[str, str],
    cwd: Path,
    timeout: int | None = None,
    project_root: Path | None = None,
    line_profile_output_file: Path | None = None,
) -> tuple[Path, Any]:
    """Run tests for line profiling.

    This is an interface stub: the body is ``...`` and each language
    support class is expected to supply the real implementation.

    Args:
        test_paths: TestFiles object containing test file information.
        test_env: Environment variables for the test run.
        cwd: Working directory for running tests.
        timeout: Optional timeout in seconds.
        project_root: Project root directory.
        line_profile_output_file: Path where line profile results will be written.

    Returns:
        Tuple of (result_file_path, subprocess_result).
    """
    ...
def run_behavioral_tests(
self,
test_paths: Any,
@ -958,32 +979,6 @@ class LanguageSupport(Protocol):
"""
...
def run_line_profile_tests(
self,
test_paths: Any,
test_env: dict[str, str],
cwd: Path,
timeout: int | None = None,
project_root: Path | None = None,
line_profile_output_file: Path | None = None,
) -> tuple[Path, Any]:
"""Run tests for line profiling.
Args:
test_paths: TestFiles object containing test file information.
test_env: Environment variables for the test run.
cwd: Working directory for running tests.
timeout: Optional timeout in seconds.
project_root: Project root directory.
line_profile_output_file: Path where line profile results will be written.
Returns:
Tuple of (result_file_path, subprocess_result).
"""
...
def convert_parents_to_tuple(parents: list | tuple) -> tuple[FunctionParent, ...]:
"""Convert a list of parent objects to a tuple of FunctionParent.

View file

@ -258,6 +258,7 @@ def wrap_target_calls_with_treesitter(
precise_call_timing: bool = False,
class_name: str = "",
test_method_name: str = "",
target_return_type: str = "",
) -> tuple[list[str], int]:
"""Replace target method calls in body_lines with capture + serialize using tree-sitter.
@ -327,6 +328,8 @@ def wrap_target_calls_with_treesitter(
call_counter += 1
var_name = f"_cf_result{iter_id}_{call_counter}"
cast_type = _infer_array_cast_type(body_line)
if not cast_type and target_return_type and target_return_type != "void":
cast_type = target_return_type
var_with_cast = f"({cast_type}){var_name}" if cast_type else var_name
# Use per-call unique variables (with call_counter suffix) for behavior mode
@ -524,6 +527,26 @@ def _infer_array_cast_type(line: str) -> str | None:
return None
def _extract_return_type(function_to_optimize: Any) -> str:
    """Extract the return type of a Java function from its source file using tree-sitter.

    Args:
        function_to_optimize: Object describing the target function. Its
            ``file_path`` attribute (a ``Path`` or path string) locates the
            Java source; the function name is obtained via ``_get_function_name``.

    Returns:
        The ``return_type`` of the first parsed method whose name matches the
        target function and whose return type is non-empty, or ``""`` when the
        file is missing, no such method is found, or parsing fails.
    """
    from pathlib import Path

    raw_path = getattr(function_to_optimize, "file_path", None)
    if not raw_path:
        return ""
    # Accept plain strings as well as Path objects so a str path yields ""
    # instead of raising AttributeError on `.exists()`.
    source_path = Path(raw_path)
    if not source_path.exists():
        return ""

    func_name = _get_function_name(function_to_optimize)
    try:
        from codeflash.languages.java.parser import get_java_analyzer

        source_text = source_path.read_text(encoding="utf-8")
        for method in get_java_analyzer().find_methods(source_text):
            if method.name == func_name and method.return_type:
                return method.return_type
    except Exception:
        # Best effort: a missing return type only means no cast is emitted,
        # so keep going, but record the cause for debugging.
        logger.debug("Could not extract return type for %s", func_name, exc_info=True)
    return ""
def _get_qualified_name(func: Any) -> str:
"""Get the qualified name from FunctionToOptimize."""
if hasattr(func, "qualified_name"):
@ -617,6 +640,7 @@ def instrument_existing_test(
"""
source = test_string
func_name = _get_function_name(function_to_optimize)
target_return_type = _extract_return_type(function_to_optimize)
# Get the original class name from the file name
if test_path:
@ -654,14 +678,16 @@ def instrument_existing_test(
)
else:
# Behavior mode: add timing instrumentation that also writes to SQLite
modified_source = _add_behavior_instrumentation(modified_source, original_class_name, func_name)
modified_source = _add_behavior_instrumentation(
modified_source, original_class_name, func_name, target_return_type
)
logger.debug("Java %s testing for %s: renamed class %s -> %s", mode, func_name, original_class_name, new_class_name)
# Why return True here?
return True, modified_source
def _add_behavior_instrumentation(source: str, class_name: str, func_name: str) -> str:
def _add_behavior_instrumentation(source: str, class_name: str, func_name: str, target_return_type: str = "") -> str:
"""Add behavior instrumentation to test methods.
For behavior mode, this adds:
@ -802,6 +828,7 @@ def _add_behavior_instrumentation(source: str, class_name: str, func_name: str)
precise_call_timing=True,
class_name=class_name,
test_method_name=test_method_name,
target_return_type=target_return_type,
)
# Add behavior instrumentation setup code (shared variables for all calls in the method)

View file

@ -64,6 +64,7 @@ class JavaSupport(LanguageSupport):
self._analyzer = get_java_analyzer()
self.line_profiler_agent_arg: str | None = None
self.line_profiler_warmup_iterations: int = 0
self._language_version: str | None = None
@property
def language(self) -> Language:
@ -94,8 +95,8 @@ class JavaSupport(LanguageSupport):
return frozenset({"target", "build", ".gradle", ".mvn", ".idea", "apidocs", "javadoc"})
@property
def default_language_version(self) -> str | None:
return "17"
def language_version(self) -> str | None:
return self._language_version
@property
def valid_test_frameworks(self) -> tuple[str, ...]:
@ -498,10 +499,39 @@ class JavaSupport(LanguageSupport):
if config is None:
return False
self._language_version = config.java_version
if self._language_version is None:
self._detect_java_version()
# For now, assume the runtime is available
# A full implementation would check/install the JAR
return True
def _detect_java_version(self) -> None:
    """Detect and cache the major version of the ``java`` runtime.

    Runs ``java -version`` and stores the major version as a string in
    ``self._language_version`` (e.g. ``"17.0.2"`` -> ``"17"``,
    ``"1.8.0_292"`` -> ``"8"``, ``"21-ea"`` -> ``"21"``). Does nothing when a
    version is already cached. Detection is best effort: if ``java`` cannot
    be run or its output has no parsable version, the cached value is left
    as ``None``.
    """
    if self._language_version is not None:
        return

    import re
    import subprocess

    try:
        result = subprocess.run(["java", "-version"], check=False, capture_output=True, text=True, timeout=10)
    except (OSError, ValueError, subprocess.SubprocessError):
        # java missing / not executable, timed out, or produced undecodable
        # output: leave the version undetected rather than failing.
        return

    # java -version outputs to stderr, e.g. 'openjdk version "17.0.2"'
    output = result.stderr or result.stdout
    for line in output.splitlines():
        if "version" not in line:
            continue
        # The version is the text between the first pair of double quotes.
        _, open_quote, remainder = line.partition('"')
        full_version, close_quote, _ = remainder.partition('"')
        if not (open_quote and close_quote):
            continue
        fields = full_version.split(".")
        # Legacy scheme "1.N.x" (Java 8 and earlier): the major version is N.
        major_field = fields[1] if fields[0] == "1" and len(fields) > 1 else fields[0]
        # Keep only the leading digits so suffixes like "-ea" or "_292" are dropped.
        digits = re.match(r"\d+", major_field)
        if digits:
            self._language_version = digits.group(0)
            return
def instrument_existing_test(
self,
test_path: Path,

View file

@ -37,6 +37,9 @@ class JavaScriptSupport:
using tree-sitter for code analysis and Jest for test execution.
"""
def __init__(self) -> None:
self._language_version: str | None = None
# === Properties ===
@property
@ -69,8 +72,8 @@ class JavaScriptSupport:
return frozenset({"node_modules", "dist", "build", ".next", ".nuxt", "coverage", ".cache", ".turbo", ".vercel"})
@property
def default_language_version(self) -> str | None:
return "ES2022"
def language_version(self) -> str | None:
return self._language_version
@property
def valid_test_frameworks(self) -> tuple[str, ...]:
@ -2240,6 +2243,15 @@ class JavaScriptSupport:
return len(errors) == 0, errors
def _detect_node_version(self) -> None:
    """Detect and cache the Node.js runtime version.

    Runs ``node --version`` and stores its output without the leading ``v``
    (e.g. ``"v20.11.0"`` -> ``"20.11.0"``) in ``self._language_version``.
    Does nothing when a version is already cached, so repeated calls do not
    spawn another process. Detection is best effort: if ``node`` cannot be
    run, exits non-zero, or prints nothing, the cached value is unchanged.
    """
    # getattr keeps this safe for instances created without __init__.
    if getattr(self, "_language_version", None) is not None:
        return
    try:
        result = subprocess.run(["node", "--version"], check=False, capture_output=True, text=True, timeout=10)
    except (OSError, ValueError, subprocess.SubprocessError):
        # node missing / not executable, timed out, or produced undecodable
        # output: leave the version undetected rather than failing.
        return
    reported = result.stdout.strip()
    if result.returncode == 0 and reported:
        self._language_version = reported.lstrip("v")
def ensure_runtime_environment(self, project_root: Path) -> bool:
"""Ensure codeflash npm package is installed.
@ -2254,6 +2266,8 @@ class JavaScriptSupport:
"""
from codeflash.cli_cmds.console import logger
self._detect_node_version()
node_modules_pkg = project_root / "node_modules" / "codeflash"
if node_modules_pkg.exists():
logger.debug("codeflash already installed")

View file

@ -3,6 +3,7 @@
from __future__ import annotations
import logging
import platform
from pathlib import Path
from typing import TYPE_CHECKING, Any
@ -180,8 +181,8 @@ class PythonSupport:
)
@property
def default_language_version(self) -> str | None:
return None
def language_version(self) -> str | None:
return platform.python_version()
@property
def valid_test_frameworks(self) -> tuple[str, ...]:

View file

@ -1216,6 +1216,7 @@ class FunctionOptimizer:
optimized_line_profiler_results=best_optimization.line_profiler_test_results["str_out"],
function_references=function_references,
language=self.function_to_optimize.language,
language_version=self.language_support.language_version,
)
],
)
@ -1277,6 +1278,7 @@ class FunctionOptimizer:
else None,
is_numerical_code=self.is_numerical_code and not self.args.no_jit_opts,
language=self.function_to_optimize.language,
language_version=self.language_support.language_version,
)
processor = CandidateProcessor(
@ -1775,6 +1777,7 @@ class FunctionOptimizer:
self.function_trace_id[:-4] + "EXP0" if run_experiment else self.function_trace_id,
ExperimentMetadata(id=self.experiment_id, group="control") if run_experiment else None,
language=self.function_to_optimize.language,
language_version=self.language_support.language_version,
is_async=self.function_to_optimize.is_async,
n_candidates=n_candidates,
is_numerical_code=is_numerical_code,
@ -1801,6 +1804,7 @@ class FunctionOptimizer:
self.function_trace_id[:-4] + "EXP1",
ExperimentMetadata(id=self.experiment_id, group="experiment"),
language=self.function_to_optimize.language,
language_version=self.language_support.language_version,
is_async=self.function_to_optimize.is_async,
n_candidates=n_candidates,
)

View file

@ -7,7 +7,7 @@ from typing import TYPE_CHECKING
from codeflash.cli_cmds.console import logger
from codeflash.code_utils.code_utils import module_name_from_file_path
from codeflash.languages.current import current_language_support
from codeflash.languages import current_language_support
from codeflash.verification.verification_utils import ModifyInspiredTests, delete_multiple_if_name_main
if TYPE_CHECKING:
@ -70,6 +70,7 @@ def generate_tests(
trace_id=function_trace_id,
test_index=test_index,
language=function_to_optimize.language,
language_version=current_language_support().language_version,
module_system=project_module_system,
is_numerical_code=is_numerical_code,
)

View file

@ -219,6 +219,7 @@ public class FibonacciTest {
}
}
"""
test_file.write_text(source)
func = FunctionToOptimize(
@ -2688,7 +2689,7 @@ public class CounterTest__perfinstrumented {
}
}
}
assertEquals(1, _cf_result1_1);
assertEquals(1, (int)_cf_result1_1);
}
}
"""