Merge remote-tracking branch 'origin/omni-java' into merge/misc-fixes-into-omni-java
# Conflicts: # codeflash/api/aiservice.py # codeflash/languages/base.py # codeflash/languages/java/support.py # codeflash/languages/javascript/support.py # codeflash/languages/python/support.py # codeflash/verification/verifier.py
This commit is contained in:
commit
2fb0145895
11 changed files with 141 additions and 65 deletions
|
|
@ -64,7 +64,7 @@ Core protocol in `languages/base.py`. Each language (`PythonSupport`, `JavaScrip
|
|||
|----------|----------------|---------|
|
||||
| Identity | `language`, `file_extensions`, `default_file_extension` | Language identification |
|
||||
| Identity | `comment_prefix`, `dir_excludes` | Language conventions |
|
||||
| AI service | `default_language_version` | Language version for API payloads (`None` for Python, `"ES2022"` for JS) |
|
||||
| AI service | `language_version` | Detected language version for API payloads (e.g., `"3.11.0"` for Python, `"17"` for Java) |
|
||||
| AI service | `valid_test_frameworks` | Allowed test frameworks for validation |
|
||||
| Discovery | `discover_functions`, `discover_tests` | Find optimizable functions and their tests |
|
||||
| Discovery | `adjust_test_config_for_discovery` | Pre-discovery config adjustment (no-op default) |
|
||||
|
|
|
|||
|
|
@ -57,10 +57,12 @@ class AiServiceClient:
|
|||
payload: dict[str, Any], language_version: str | None = None, module_system: str | None = None
|
||||
) -> None:
|
||||
"""Add language version and module system metadata to an API payload."""
|
||||
payload["python_version"] = platform.python_version()
|
||||
default_lang_version = current_language_support().default_language_version
|
||||
if default_lang_version is not None:
|
||||
payload["language_version"] = language_version or default_lang_version
|
||||
# Canonical for all languages
|
||||
payload["language_version"] = language_version
|
||||
# Backward compat: Python backend still expects python_version
|
||||
payload["python_version"] = language_version if current_language() == Language.PYTHON else None
|
||||
|
||||
if current_language() != Language.PYTHON:
|
||||
if module_system:
|
||||
payload["module_system"] = module_system
|
||||
|
||||
|
|
@ -142,8 +144,7 @@ class AiServiceClient:
|
|||
experiment_metadata: ExperimentMetadata | None = None,
|
||||
*,
|
||||
language: str = "python",
|
||||
language_version: str
|
||||
| None = None, # TODO:{claude} add language version to the language support and it should be cached
|
||||
language_version: str | None = None,
|
||||
module_system: str | None = None,
|
||||
is_async: bool = False,
|
||||
n_candidates: int = 5,
|
||||
|
|
@ -264,7 +265,7 @@ class AiServiceClient:
|
|||
"source_code": source_code,
|
||||
"trace_id": trace_id,
|
||||
"dependency_code": "", # dummy value to please the api endpoint
|
||||
"python_version": "3.12.1", # dummy value to please the api endpoint
|
||||
"python_version": platform.python_version(), # backward compat
|
||||
"current_username": get_last_commit_author_if_pr_exists(None),
|
||||
"repo_owner": git_repo_owner,
|
||||
"repo_name": git_repo_name,
|
||||
|
|
@ -331,18 +332,15 @@ class AiServiceClient:
|
|||
logger.info("Generating optimized candidates with line profiler…")
|
||||
console.rule()
|
||||
|
||||
# Set python_version for backward compatibility with Python, or use language_version
|
||||
python_version = language_version if language_version else platform.python_version()
|
||||
|
||||
payload = {
|
||||
"source_code": source_code,
|
||||
"dependency_code": dependency_code,
|
||||
"n_candidates": n_candidates,
|
||||
"line_profiler_results": line_profiler_results,
|
||||
"trace_id": trace_id,
|
||||
"python_version": python_version,
|
||||
"language": language,
|
||||
"language_version": language_version,
|
||||
"python_version": language_version if current_language() == Language.PYTHON else None,
|
||||
"experiment_metadata": experiment_metadata,
|
||||
"codeflash_version": codeflash_version,
|
||||
"call_sequence": self.get_next_sequence(),
|
||||
|
|
@ -644,7 +642,7 @@ class AiServiceClient:
|
|||
"diffs": diffs,
|
||||
"speedups": speedups,
|
||||
"optimization_ids": optimization_ids,
|
||||
"python_version": platform.python_version(),
|
||||
"python_version": platform.python_version(), # backward compat
|
||||
"function_references": function_references,
|
||||
}
|
||||
logger.info("loading|Generating ranking")
|
||||
|
|
@ -770,6 +768,8 @@ class AiServiceClient:
|
|||
"is_async": function_to_optimize.is_async,
|
||||
"call_sequence": self.get_next_sequence(),
|
||||
"is_numerical_code": is_numerical_code,
|
||||
"class_name": function_to_optimize.class_name,
|
||||
"qualified_name": function_to_optimize.qualified_name,
|
||||
}
|
||||
|
||||
self.add_language_metadata(payload, language_version, module_system)
|
||||
|
|
@ -858,6 +858,7 @@ class AiServiceClient:
|
|||
"codeflash_version": codeflash_version,
|
||||
"calling_fn_details": calling_fn_details,
|
||||
"language": language,
|
||||
"language_version": platform.python_version() if current_language() == Language.PYTHON else None,
|
||||
"python_version": platform.python_version() if current_language() == Language.PYTHON else None,
|
||||
"call_sequence": self.get_next_sequence(),
|
||||
}
|
||||
|
|
|
|||
|
|
@ -137,14 +137,15 @@ class OptimizeRequest:
|
|||
"is_numerical_code": self.is_numerical_code,
|
||||
}
|
||||
|
||||
# Add language-specific fields
|
||||
if self.language_info.version:
|
||||
payload["language_version"] = self.language_info.version
|
||||
# Add language version (canonical for all languages)
|
||||
payload["language_version"] = self.language_info.version
|
||||
|
||||
# Backward compat: always include python_version
|
||||
# Backward compat: backend still expects python_version
|
||||
import platform
|
||||
|
||||
payload["python_version"] = platform.python_version()
|
||||
payload["python_version"] = (
|
||||
self.language_info.version if self.language_info.name == "python" else platform.python_version()
|
||||
)
|
||||
|
||||
# Module system for JS/TS
|
||||
if self.language_info.module_system != ModuleSystem.UNKNOWN:
|
||||
|
|
@ -205,14 +206,15 @@ class TestGenRequest:
|
|||
"is_numerical_code": self.is_numerical_code,
|
||||
}
|
||||
|
||||
# Add language version
|
||||
if self.language_info.version:
|
||||
payload["language_version"] = self.language_info.version
|
||||
# Add language version (canonical for all languages)
|
||||
payload["language_version"] = self.language_info.version
|
||||
|
||||
# Backward compat: always include python_version
|
||||
# Backward compat: backend still expects python_version
|
||||
import platform
|
||||
|
||||
payload["python_version"] = platform.python_version()
|
||||
payload["python_version"] = (
|
||||
self.language_info.version if self.language_info.name == "python" else platform.python_version()
|
||||
)
|
||||
|
||||
# Module system for JS/TS
|
||||
if self.language_info.module_system != ModuleSystem.UNKNOWN:
|
||||
|
|
|
|||
|
|
@ -326,12 +326,8 @@ class LanguageSupport(Protocol):
|
|||
...
|
||||
|
||||
@property
|
||||
def default_language_version(self) -> str | None:
|
||||
"""Default language version string sent to AI service.
|
||||
|
||||
Returns None for languages where the runtime version is auto-detected (e.g. Python).
|
||||
Returns a version string (e.g. "ES2022") for languages that need an explicit default.
|
||||
"""
|
||||
def language_version(self) -> str | None:
|
||||
"""The detected language version (e.g., "17" for Java, "ES2022" for JS)."""
|
||||
...
|
||||
|
||||
@property
|
||||
|
|
@ -900,6 +896,31 @@ class LanguageSupport(Protocol):
|
|||
"""
|
||||
return {}, ""
|
||||
|
||||
def run_line_profile_tests(
|
||||
self,
|
||||
test_paths: Any,
|
||||
test_env: dict[str, str],
|
||||
cwd: Path,
|
||||
timeout: int | None = None,
|
||||
project_root: Path | None = None,
|
||||
line_profile_output_file: Path | None = None,
|
||||
) -> tuple[Path, Any]:
|
||||
"""Run tests for line profiling.
|
||||
|
||||
Args:
|
||||
test_paths: TestFiles object containing test file information.
|
||||
test_env: Environment variables for the test run.
|
||||
cwd: Working directory for running tests.
|
||||
timeout: Optional timeout in seconds.
|
||||
project_root: Project root directory.
|
||||
line_profile_output_file: Path where line profile results will be written.
|
||||
|
||||
Returns:
|
||||
Tuple of (result_file_path, subprocess_result).
|
||||
|
||||
"""
|
||||
...
|
||||
|
||||
def run_behavioral_tests(
|
||||
self,
|
||||
test_paths: Any,
|
||||
|
|
@ -958,32 +979,6 @@ class LanguageSupport(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
def run_line_profile_tests(
|
||||
self,
|
||||
test_paths: Any,
|
||||
test_env: dict[str, str],
|
||||
cwd: Path,
|
||||
timeout: int | None = None,
|
||||
project_root: Path | None = None,
|
||||
line_profile_output_file: Path | None = None,
|
||||
) -> tuple[Path, Any]:
|
||||
"""Run tests for line profiling.
|
||||
|
||||
Args:
|
||||
test_paths: TestFiles object containing test file information.
|
||||
test_env: Environment variables for the test run.
|
||||
cwd: Working directory for running tests.
|
||||
timeout: Optional timeout in seconds.
|
||||
project_root: Project root directory.
|
||||
line_profile_output_file: Path where line profile results will be written.
|
||||
|
||||
Returns:
|
||||
Tuple of (result_file_path, subprocess_result).
|
||||
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
def convert_parents_to_tuple(parents: list | tuple) -> tuple[FunctionParent, ...]:
|
||||
"""Convert a list of parent objects to a tuple of FunctionParent.
|
||||
|
||||
|
|
|
|||
|
|
@ -258,6 +258,7 @@ def wrap_target_calls_with_treesitter(
|
|||
precise_call_timing: bool = False,
|
||||
class_name: str = "",
|
||||
test_method_name: str = "",
|
||||
target_return_type: str = "",
|
||||
) -> tuple[list[str], int]:
|
||||
"""Replace target method calls in body_lines with capture + serialize using tree-sitter.
|
||||
|
||||
|
|
@ -327,6 +328,8 @@ def wrap_target_calls_with_treesitter(
|
|||
call_counter += 1
|
||||
var_name = f"_cf_result{iter_id}_{call_counter}"
|
||||
cast_type = _infer_array_cast_type(body_line)
|
||||
if not cast_type and target_return_type and target_return_type != "void":
|
||||
cast_type = target_return_type
|
||||
var_with_cast = f"({cast_type}){var_name}" if cast_type else var_name
|
||||
|
||||
# Use per-call unique variables (with call_counter suffix) for behavior mode
|
||||
|
|
@ -524,6 +527,26 @@ def _infer_array_cast_type(line: str) -> str | None:
|
|||
return None
|
||||
|
||||
|
||||
def _extract_return_type(function_to_optimize: Any) -> str:
    """Look up the declared return type of the target Java method.

    The file referenced by ``function_to_optimize.file_path`` is read and handed
    to the Java analyzer; the ``return_type`` of the first reported method whose
    name equals the target name (and whose return type is non-empty) is returned.

    Args:
        function_to_optimize: Description of the target function. Only its
            ``file_path`` attribute (optional) and whatever
            ``_get_function_name`` reads from it are used here.

    Returns:
        The return type string, or ``""`` when the file is missing, no method
        matches, or anything goes wrong while reading/parsing the file.

    """
    # NOTE(review): matching is by name only, so for overloaded methods the
    # first overload with a return type wins — confirm this is acceptable.
    source_path = getattr(function_to_optimize, "file_path", None)
    target_name = _get_function_name(function_to_optimize)
    if not (source_path and source_path.exists()):
        return ""

    try:
        from codeflash.languages.java.parser import get_java_analyzer

        java_analyzer = get_java_analyzer()
        java_source = source_path.read_text(encoding="utf-8")
        matching_types = (
            candidate.return_type
            for candidate in java_analyzer.find_methods(java_source)
            if candidate.name == target_name and candidate.return_type
        )
        return next(matching_types, "")
    except Exception:
        # Best effort: a missing return type only means no cast is emitted.
        logger.debug("Could not extract return type for %s", target_name)
        return ""
|
||||
|
||||
|
||||
def _get_qualified_name(func: Any) -> str:
|
||||
"""Get the qualified name from FunctionToOptimize."""
|
||||
if hasattr(func, "qualified_name"):
|
||||
|
|
@ -617,6 +640,7 @@ def instrument_existing_test(
|
|||
"""
|
||||
source = test_string
|
||||
func_name = _get_function_name(function_to_optimize)
|
||||
target_return_type = _extract_return_type(function_to_optimize)
|
||||
|
||||
# Get the original class name from the file name
|
||||
if test_path:
|
||||
|
|
@ -654,14 +678,16 @@ def instrument_existing_test(
|
|||
)
|
||||
else:
|
||||
# Behavior mode: add timing instrumentation that also writes to SQLite
|
||||
modified_source = _add_behavior_instrumentation(modified_source, original_class_name, func_name)
|
||||
modified_source = _add_behavior_instrumentation(
|
||||
modified_source, original_class_name, func_name, target_return_type
|
||||
)
|
||||
|
||||
logger.debug("Java %s testing for %s: renamed class %s -> %s", mode, func_name, original_class_name, new_class_name)
|
||||
# Why return True here?
|
||||
return True, modified_source
|
||||
|
||||
|
||||
def _add_behavior_instrumentation(source: str, class_name: str, func_name: str) -> str:
|
||||
def _add_behavior_instrumentation(source: str, class_name: str, func_name: str, target_return_type: str = "") -> str:
|
||||
"""Add behavior instrumentation to test methods.
|
||||
|
||||
For behavior mode, this adds:
|
||||
|
|
@ -802,6 +828,7 @@ def _add_behavior_instrumentation(source: str, class_name: str, func_name: str)
|
|||
precise_call_timing=True,
|
||||
class_name=class_name,
|
||||
test_method_name=test_method_name,
|
||||
target_return_type=target_return_type,
|
||||
)
|
||||
|
||||
# Add behavior instrumentation setup code (shared variables for all calls in the method)
|
||||
|
|
|
|||
|
|
@ -64,6 +64,7 @@ class JavaSupport(LanguageSupport):
|
|||
self._analyzer = get_java_analyzer()
|
||||
self.line_profiler_agent_arg: str | None = None
|
||||
self.line_profiler_warmup_iterations: int = 0
|
||||
self._language_version: str | None = None
|
||||
|
||||
@property
|
||||
def language(self) -> Language:
|
||||
|
|
@ -94,8 +95,8 @@ class JavaSupport(LanguageSupport):
|
|||
return frozenset({"target", "build", ".gradle", ".mvn", ".idea", "apidocs", "javadoc"})
|
||||
|
||||
@property
|
||||
def default_language_version(self) -> str | None:
|
||||
return "17"
|
||||
def language_version(self) -> str | None:
|
||||
return self._language_version
|
||||
|
||||
@property
|
||||
def valid_test_frameworks(self) -> tuple[str, ...]:
|
||||
|
|
@ -498,10 +499,39 @@ class JavaSupport(LanguageSupport):
|
|||
if config is None:
|
||||
return False
|
||||
|
||||
self._language_version = config.java_version
|
||||
if self._language_version is None:
|
||||
self._detect_java_version()
|
||||
|
||||
# For now, assume the runtime is available
|
||||
# A full implementation would check/install the JAR
|
||||
return True
|
||||
|
||||
def _detect_java_version(self) -> None:
|
||||
"""Detect and cache the Java runtime version."""
|
||||
if self._language_version is not None:
|
||||
return
|
||||
|
||||
import subprocess
|
||||
|
||||
try:
|
||||
result = subprocess.run(["java", "-version"], check=False, capture_output=True, text=True, timeout=10)
|
||||
# java -version outputs to stderr, e.g. 'openjdk version "17.0.2"'
|
||||
output = result.stderr or result.stdout
|
||||
for line in output.splitlines():
|
||||
if "version" in line:
|
||||
# Extract version between quotes: "17.0.2" -> "17"
|
||||
start = line.find('"')
|
||||
end = line.find('"', start + 1)
|
||||
if start != -1 and end != -1:
|
||||
full_version = line[start + 1 : end]
|
||||
# Use major version only: "17.0.2" -> "17", "1.8.0_292" -> "8"
|
||||
major = full_version.split(".")[0]
|
||||
self._language_version = "8" if major == "1" else major
|
||||
return
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def instrument_existing_test(
|
||||
self,
|
||||
test_path: Path,
|
||||
|
|
|
|||
|
|
@ -37,6 +37,9 @@ class JavaScriptSupport:
|
|||
using tree-sitter for code analysis and Jest for test execution.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._language_version: str | None = None
|
||||
|
||||
# === Properties ===
|
||||
|
||||
@property
|
||||
|
|
@ -69,8 +72,8 @@ class JavaScriptSupport:
|
|||
return frozenset({"node_modules", "dist", "build", ".next", ".nuxt", "coverage", ".cache", ".turbo", ".vercel"})
|
||||
|
||||
@property
|
||||
def default_language_version(self) -> str | None:
|
||||
return "ES2022"
|
||||
def language_version(self) -> str | None:
|
||||
return self._language_version
|
||||
|
||||
@property
|
||||
def valid_test_frameworks(self) -> tuple[str, ...]:
|
||||
|
|
@ -2240,6 +2243,15 @@ class JavaScriptSupport:
|
|||
|
||||
return len(errors) == 0, errors
|
||||
|
||||
def _detect_node_version(self) -> None:
|
||||
"""Detect and cache the Node.js runtime version."""
|
||||
try:
|
||||
result = subprocess.run(["node", "--version"], check=False, capture_output=True, text=True, timeout=10)
|
||||
if result.returncode == 0 and result.stdout.strip():
|
||||
self._language_version = result.stdout.strip().lstrip("v")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def ensure_runtime_environment(self, project_root: Path) -> bool:
|
||||
"""Ensure codeflash npm package is installed.
|
||||
|
||||
|
|
@ -2254,6 +2266,8 @@ class JavaScriptSupport:
|
|||
"""
|
||||
from codeflash.cli_cmds.console import logger
|
||||
|
||||
self._detect_node_version()
|
||||
|
||||
node_modules_pkg = project_root / "node_modules" / "codeflash"
|
||||
if node_modules_pkg.exists():
|
||||
logger.debug("codeflash already installed")
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import platform
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
|
|
@ -180,8 +181,8 @@ class PythonSupport:
|
|||
)
|
||||
|
||||
@property
|
||||
def default_language_version(self) -> str | None:
|
||||
return None
|
||||
def language_version(self) -> str | None:
|
||||
return platform.python_version()
|
||||
|
||||
@property
|
||||
def valid_test_frameworks(self) -> tuple[str, ...]:
|
||||
|
|
|
|||
|
|
@ -1216,6 +1216,7 @@ class FunctionOptimizer:
|
|||
optimized_line_profiler_results=best_optimization.line_profiler_test_results["str_out"],
|
||||
function_references=function_references,
|
||||
language=self.function_to_optimize.language,
|
||||
language_version=self.language_support.language_version,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
|
@ -1277,6 +1278,7 @@ class FunctionOptimizer:
|
|||
else None,
|
||||
is_numerical_code=self.is_numerical_code and not self.args.no_jit_opts,
|
||||
language=self.function_to_optimize.language,
|
||||
language_version=self.language_support.language_version,
|
||||
)
|
||||
|
||||
processor = CandidateProcessor(
|
||||
|
|
@ -1775,6 +1777,7 @@ class FunctionOptimizer:
|
|||
self.function_trace_id[:-4] + "EXP0" if run_experiment else self.function_trace_id,
|
||||
ExperimentMetadata(id=self.experiment_id, group="control") if run_experiment else None,
|
||||
language=self.function_to_optimize.language,
|
||||
language_version=self.language_support.language_version,
|
||||
is_async=self.function_to_optimize.is_async,
|
||||
n_candidates=n_candidates,
|
||||
is_numerical_code=is_numerical_code,
|
||||
|
|
@ -1801,6 +1804,7 @@ class FunctionOptimizer:
|
|||
self.function_trace_id[:-4] + "EXP1",
|
||||
ExperimentMetadata(id=self.experiment_id, group="experiment"),
|
||||
language=self.function_to_optimize.language,
|
||||
language_version=self.language_support.language_version,
|
||||
is_async=self.function_to_optimize.is_async,
|
||||
n_candidates=n_candidates,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ from typing import TYPE_CHECKING
|
|||
|
||||
from codeflash.cli_cmds.console import logger
|
||||
from codeflash.code_utils.code_utils import module_name_from_file_path
|
||||
from codeflash.languages.current import current_language_support
|
||||
from codeflash.languages import current_language_support
|
||||
from codeflash.verification.verification_utils import ModifyInspiredTests, delete_multiple_if_name_main
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
|
@ -70,6 +70,7 @@ def generate_tests(
|
|||
trace_id=function_trace_id,
|
||||
test_index=test_index,
|
||||
language=function_to_optimize.language,
|
||||
language_version=current_language_support().language_version,
|
||||
module_system=project_module_system,
|
||||
is_numerical_code=is_numerical_code,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -219,6 +219,7 @@ public class FibonacciTest {
|
|||
}
|
||||
}
|
||||
"""
|
||||
|
||||
test_file.write_text(source)
|
||||
|
||||
func = FunctionToOptimize(
|
||||
|
|
@ -2688,7 +2689,7 @@ public class CounterTest__perfinstrumented {
|
|||
}
|
||||
}
|
||||
}
|
||||
assertEquals(1, _cf_result1_1);
|
||||
assertEquals(1, (int)_cf_result1_1);
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
|
|
|||
Loading…
Reference in a new issue