e2e working java

This commit is contained in:
misrasaurabh1 2026-01-30 10:52:45 -08:00
parent 29f266ee63
commit 06353ea13f
16 changed files with 2903 additions and 330 deletions

View file

@ -756,6 +756,7 @@ class AiServiceClient:
# Validate test framework based on language
python_frameworks = ["pytest", "unittest"]
javascript_frameworks = ["jest", "mocha", "vitest"]
java_frameworks = ["junit5", "junit4", "testng"]
if is_python():
assert test_framework in python_frameworks, (
f"Invalid test framework for Python, got {test_framework} but expected one of {python_frameworks}"
@ -764,6 +765,10 @@ class AiServiceClient:
assert test_framework in javascript_frameworks, (
f"Invalid test framework for JavaScript, got {test_framework} but expected one of {javascript_frameworks}"
)
elif is_java():
assert test_framework in java_frameworks, (
f"Invalid test framework for Java, got {test_framework} but expected one of {java_frameworks}"
)
payload: dict[str, Any] = {
"source_code_being_tested": source_code_being_tested,

View file

@ -0,0 +1,553 @@
"""Java project initialization for Codeflash."""
from __future__ import annotations
import os
import sys
import xml.etree.ElementTree as ET
from dataclasses import dataclass
from enum import Enum, auto
from pathlib import Path
from typing import Any, Union
import click
import inquirer
from git import InvalidGitRepositoryError, Repo
from rich.console import Group
from rich.panel import Panel
from rich.table import Table
from rich.text import Text
from codeflash.cli_cmds.cli_common import apologize_and_exit
from codeflash.cli_cmds.console import console
from codeflash.code_utils.code_utils import validate_relative_directory_path
from codeflash.code_utils.compat import LF
from codeflash.code_utils.git_utils import get_git_remotes
from codeflash.code_utils.shell_utils import get_shell_rc_path, is_powershell
from codeflash.telemetry.posthog_cf import ph
class JavaBuildTool(Enum):
"""Java build tools."""
MAVEN = auto()
GRADLE = auto()
UNKNOWN = auto()
@dataclass(frozen=True)
class JavaSetupInfo:
"""Setup info for Java projects.
Only stores values that override auto-detection or user preferences.
Most config is auto-detected from pom.xml/build.gradle and project structure.
"""
# Override values (None means use auto-detected value)
module_root_override: Union[str, None] = None
test_root_override: Union[str, None] = None
formatter_override: Union[list[str], None] = None
# User preferences (stored in config only if non-default)
git_remote: str = "origin"
disable_telemetry: bool = False
ignore_paths: list[str] | None = None
benchmarks_root: Union[str, None] = None
def _get_theme():
"""Get the CodeflashTheme - imported lazily to avoid circular imports."""
from codeflash.cli_cmds.cmd_init import CodeflashTheme
return CodeflashTheme()
def detect_java_build_tool(project_root: Path) -> JavaBuildTool:
"""Detect which Java build tool is being used."""
if (project_root / "pom.xml").exists():
return JavaBuildTool.MAVEN
if (project_root / "build.gradle").exists() or (project_root / "build.gradle.kts").exists():
return JavaBuildTool.GRADLE
return JavaBuildTool.UNKNOWN
def detect_java_source_root(project_root: Path) -> str:
"""Detect the Java source root directory."""
# Standard Maven/Gradle layout
standard_src = project_root / "src" / "main" / "java"
if standard_src.is_dir():
return "src/main/java"
# Try to detect from pom.xml
pom_path = project_root / "pom.xml"
if pom_path.exists():
try:
tree = ET.parse(pom_path)
root = tree.getroot()
# Handle Maven namespace
ns = {"m": "http://maven.apache.org/POM/4.0.0"}
source_dir = root.find(".//m:sourceDirectory", ns)
if source_dir is not None and source_dir.text:
return source_dir.text
except ET.ParseError:
pass
# Fallback to src directory
if (project_root / "src").is_dir():
return "src"
return "."
def detect_java_test_root(project_root: Path) -> str:
"""Detect the Java test root directory."""
# Standard Maven/Gradle layout
standard_test = project_root / "src" / "test" / "java"
if standard_test.is_dir():
return "src/test/java"
# Try to detect from pom.xml
pom_path = project_root / "pom.xml"
if pom_path.exists():
try:
tree = ET.parse(pom_path)
root = tree.getroot()
ns = {"m": "http://maven.apache.org/POM/4.0.0"}
test_source_dir = root.find(".//m:testSourceDirectory", ns)
if test_source_dir is not None and test_source_dir.text:
return test_source_dir.text
except ET.ParseError:
pass
# Fallback patterns
if (project_root / "test").is_dir():
return "test"
if (project_root / "tests").is_dir():
return "tests"
return "src/test/java"
def detect_java_test_framework(project_root: Path) -> str:
"""Detect the Java test framework in use."""
pom_path = project_root / "pom.xml"
if pom_path.exists():
try:
content = pom_path.read_text(encoding="utf-8")
if "junit-jupiter" in content or "junit.jupiter" in content:
return "junit5"
if "junit" in content.lower():
return "junit4"
if "testng" in content.lower():
return "testng"
except Exception:
pass
gradle_file = project_root / "build.gradle"
if gradle_file.exists():
try:
content = gradle_file.read_text(encoding="utf-8")
if "junit-jupiter" in content or "useJUnitPlatform" in content:
return "junit5"
if "junit" in content.lower():
return "junit4"
if "testng" in content.lower():
return "testng"
except Exception:
pass
return "junit5" # Default to JUnit 5
def init_java_project() -> None:
"""Initialize Codeflash for a Java project."""
from codeflash.cli_cmds.cmd_init import install_github_actions, install_github_app, prompt_api_key
lang_panel = Panel(
Text(
"Java project detected!\n\nI'll help you set up Codeflash for your project.",
style="cyan",
justify="center",
),
title="Java Setup",
border_style="bright_red",
)
console.print(lang_panel)
console.print()
did_add_new_key = prompt_api_key()
should_modify, _config = should_modify_java_config()
# Default git remote
git_remote = "origin"
if should_modify:
setup_info = collect_java_setup_info()
git_remote = setup_info.git_remote or "origin"
configured = configure_java_project(setup_info)
if not configured:
apologize_and_exit()
install_github_app(git_remote)
install_github_actions(override_formatter_check=True)
# Show completion message
usage_table = Table(show_header=False, show_lines=False, border_style="dim")
usage_table.add_column("Command", style="cyan")
usage_table.add_column("Description", style="white")
usage_table.add_row("codeflash --file <path-to-file> --function <function-name>", "Optimize a specific function")
usage_table.add_row("codeflash --all", "Optimize all functions in all files")
usage_table.add_row("codeflash --help", "See all available options")
completion_message = "Codeflash is now set up for your Java project!\n\nYou can now run any of these commands:"
if did_add_new_key:
completion_message += "\n\nDon't forget to restart your shell to load the CODEFLASH_API_KEY environment variable!"
if os.name == "nt":
reload_cmd = f". {get_shell_rc_path()}" if is_powershell() else f"call {get_shell_rc_path()}"
else:
reload_cmd = f"source {get_shell_rc_path()}"
completion_message += f"\nOr run: {reload_cmd}"
completion_panel = Panel(
Group(Text(completion_message, style="bold green"), Text(""), usage_table),
title="Setup Complete!",
border_style="bright_green",
padding=(1, 2),
)
console.print(completion_panel)
ph("cli-java-installation-successful", {"did_add_new_key": did_add_new_key})
sys.exit(0)
def should_modify_java_config() -> tuple[bool, dict[str, Any] | None]:
"""Check if the project already has Codeflash config."""
from rich.prompt import Confirm
project_root = Path.cwd()
# Check for existing codeflash config in pom.xml or a separate config file
codeflash_config_path = project_root / "codeflash.toml"
if codeflash_config_path.exists():
return Confirm.ask(
"A Codeflash config already exists. Do you want to re-configure it?",
default=False,
show_default=True,
), None
return True, None
def collect_java_setup_info() -> JavaSetupInfo:
"""Collect setup information for Java projects."""
from rich.prompt import Confirm
from codeflash.cli_cmds.cmd_init import ask_for_telemetry
curdir = Path.cwd()
if not os.access(curdir, os.W_OK):
click.echo(f"The current directory isn't writable, please check your folder permissions and try again.{LF}")
sys.exit(1)
# Auto-detect values
build_tool = detect_java_build_tool(curdir)
detected_source_root = detect_java_source_root(curdir)
detected_test_root = detect_java_test_root(curdir)
detected_test_framework = detect_java_test_framework(curdir)
# Build detection summary
build_tool_name = build_tool.name.lower() if build_tool != JavaBuildTool.UNKNOWN else "unknown"
detection_table = Table(show_header=False, box=None, padding=(0, 2))
detection_table.add_column("Setting", style="cyan")
detection_table.add_column("Value", style="green")
detection_table.add_row("Build tool", build_tool_name)
detection_table.add_row("Source root", detected_source_root)
detection_table.add_row("Test root", detected_test_root)
detection_table.add_row("Test framework", detected_test_framework)
detection_panel = Panel(
Group(Text("Auto-detected settings for your Java project:\n", style="cyan"), detection_table),
title="Auto-Detection Results",
border_style="bright_blue",
)
console.print(detection_panel)
console.print()
# Ask if user wants to change any settings
module_root_override = None
test_root_override = None
formatter_override = None
if Confirm.ask("Would you like to change any of these settings?", default=False):
# Source root override
module_root_override = _prompt_directory_override(
"source", detected_source_root, curdir
)
# Test root override
test_root_override = _prompt_directory_override(
"test", detected_test_root, curdir
)
# Formatter override
formatter_questions = [
inquirer.List(
"formatter",
message="Which code formatter do you use?",
choices=[
(f"keep detected (google-java-format)", "keep"),
("google-java-format", "google-java-format"),
("spotless", "spotless"),
("other", "other"),
("don't use a formatter", "disabled"),
],
default="keep",
carousel=True,
)
]
formatter_answers = inquirer.prompt(formatter_questions, theme=_get_theme())
if not formatter_answers:
apologize_and_exit()
formatter_choice = formatter_answers["formatter"]
if formatter_choice != "keep":
formatter_override = get_java_formatter_cmd(formatter_choice, build_tool)
ph("cli-java-formatter-provided", {"overridden": formatter_override is not None})
# Git remote
git_remote = _get_git_remote_for_setup()
# Telemetry
disable_telemetry = not ask_for_telemetry()
return JavaSetupInfo(
module_root_override=module_root_override,
test_root_override=test_root_override,
formatter_override=formatter_override,
git_remote=git_remote,
disable_telemetry=disable_telemetry,
)
def _prompt_directory_override(dir_type: str, detected: str, curdir: Path) -> str | None:
"""Prompt for a directory override."""
keep_detected_option = f"keep detected ({detected})"
custom_dir_option = "enter a custom directory..."
# Get subdirectories that might be relevant
subdirs = [d.name for d in curdir.iterdir() if d.is_dir() and not d.name.startswith(".")]
subdirs = [d for d in subdirs if d not in ("target", "build", ".git", ".idea", detected)]
options = [keep_detected_option] + subdirs[:5] + [custom_dir_option]
questions = [
inquirer.List(
f"{dir_type}_root",
message=f"Which directory contains your {dir_type} code?",
choices=options,
default=keep_detected_option,
carousel=True,
)
]
answers = inquirer.prompt(questions, theme=_get_theme())
if not answers:
apologize_and_exit()
answer = answers[f"{dir_type}_root"]
if answer == keep_detected_option:
return None
elif answer == custom_dir_option:
return _prompt_custom_directory(dir_type)
else:
return answer
def _prompt_custom_directory(dir_type: str) -> str:
"""Prompt for a custom directory path."""
while True:
custom_questions = [
inquirer.Path(
"custom_path",
message=f"Enter the path to your {dir_type} directory",
path_type=inquirer.Path.DIRECTORY,
exists=True,
)
]
custom_answers = inquirer.prompt(custom_questions, theme=_get_theme())
if not custom_answers:
apologize_and_exit()
custom_path_str = str(custom_answers["custom_path"])
is_valid, error_msg = validate_relative_directory_path(custom_path_str)
if is_valid:
return custom_path_str
click.echo(f"Invalid path: {error_msg}")
click.echo("Please enter a valid relative directory path.")
console.print()
def _get_git_remote_for_setup() -> str:
"""Get git remote for project setup."""
try:
repo = Repo(Path.cwd(), search_parent_directories=True)
git_remotes = get_git_remotes(repo)
if not git_remotes:
return ""
if len(git_remotes) == 1:
return git_remotes[0]
git_panel = Panel(
Text(
"Configure Git Remote for Pull Requests.\n\nCodeflash will use this remote to create pull requests.",
style="blue",
),
title="Git Remote Setup",
border_style="bright_blue",
)
console.print(git_panel)
console.print()
git_questions = [
inquirer.List(
"git_remote",
message="Which git remote should Codeflash use?",
choices=git_remotes,
default="origin",
carousel=True,
)
]
git_answers = inquirer.prompt(git_questions, theme=_get_theme())
return git_answers["git_remote"] if git_answers else git_remotes[0]
except InvalidGitRepositoryError:
return ""
def get_java_formatter_cmd(formatter: str, build_tool: JavaBuildTool) -> list[str]:
"""Get formatter commands for Java."""
if formatter == "google-java-format":
return ["google-java-format --replace $file"]
if formatter == "spotless":
if build_tool == JavaBuildTool.MAVEN:
return ["mvn spotless:apply -DspotlessFiles=$file"]
elif build_tool == JavaBuildTool.GRADLE:
return ["./gradlew spotlessApply"]
return ["spotless $file"]
if formatter == "other":
click.echo("In codeflash.toml, please replace 'your-formatter' with your formatter command.")
return ["your-formatter $file"]
return ["disabled"]
def configure_java_project(setup_info: JavaSetupInfo) -> bool:
"""Configure codeflash.toml for Java projects."""
import tomlkit
codeflash_config_path = Path.cwd() / "codeflash.toml"
# Build config
config: dict[str, Any] = {}
# Detect values
curdir = Path.cwd()
source_root = setup_info.module_root_override or detect_java_source_root(curdir)
test_root = setup_info.test_root_override or detect_java_test_root(curdir)
config["module-root"] = source_root
config["tests-root"] = test_root
# Formatter
if setup_info.formatter_override is not None:
if setup_info.formatter_override != ["disabled"]:
config["formatter-cmds"] = setup_info.formatter_override
else:
config["formatter-cmds"] = []
# Git remote
if setup_info.git_remote and setup_info.git_remote not in ("", "origin"):
config["git-remote"] = setup_info.git_remote
# User preferences
if setup_info.disable_telemetry:
config["disable-telemetry"] = True
if setup_info.ignore_paths:
config["ignore-paths"] = setup_info.ignore_paths
if setup_info.benchmarks_root:
config["benchmarks-root"] = setup_info.benchmarks_root
try:
# Create TOML document
doc = tomlkit.document()
doc.add(tomlkit.comment("Codeflash configuration for Java project"))
doc.add(tomlkit.nl())
codeflash_table = tomlkit.table()
for key, value in config.items():
codeflash_table.add(key, value)
doc.add("tool", tomlkit.table())
doc["tool"]["codeflash"] = codeflash_table
with codeflash_config_path.open("w", encoding="utf-8") as f:
f.write(tomlkit.dumps(doc))
click.echo(f"Created Codeflash configuration in {codeflash_config_path}")
click.echo()
return True
except OSError as e:
click.echo(f"Failed to create codeflash.toml: {e}")
return False
# ============================================================================
# GitHub Actions Workflow Helpers for Java
# ============================================================================
def get_java_runtime_setup_steps(build_tool: JavaBuildTool) -> str:
"""Generate the appropriate Java setup steps for GitHub Actions."""
java_setup = """- name: Set up JDK 17
uses: actions/setup-java@v4
with:
java-version: '17'
distribution: 'temurin'"""
if build_tool == JavaBuildTool.MAVEN:
java_setup += """
cache: 'maven'"""
elif build_tool == JavaBuildTool.GRADLE:
java_setup += """
cache: 'gradle'"""
return java_setup
def get_java_dependency_installation_commands(build_tool: JavaBuildTool) -> str:
"""Generate commands to install Java dependencies."""
if build_tool == JavaBuildTool.MAVEN:
return "mvn dependency:resolve"
if build_tool == JavaBuildTool.GRADLE:
return "./gradlew dependencies"
return "mvn dependency:resolve"
def get_java_test_command(build_tool: JavaBuildTool) -> str:
"""Get the test command for Java projects."""
if build_tool == JavaBuildTool.MAVEN:
return "mvn test"
if build_tool == JavaBuildTool.GRADLE:
return "./gradlew test"
return "mvn test"

View file

@ -0,0 +1,41 @@
name: Codeflash
on:
pull_request:
paths:
# So that this workflow only runs when code within the target module is modified
- '{{ codeflash_module_path }}'
workflow_dispatch:
concurrency:
# Any new push to the PR will cancel the previous run, so that only the latest code is optimized
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
jobs:
optimize:
name: Optimize new code
# Don't run codeflash on codeflash-ai[bot] commits, prevent duplicate optimizations
if: ${{ github.actor != 'codeflash-ai[bot]' }}
runs-on: ubuntu-latest
env:
CODEFLASH_API_KEY: ${{ secrets.CODEFLASH_API_KEY }}
{{ working_directory }}
steps:
- name: Checkout
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Set up JDK 17
uses: actions/setup-java@v4
with:
java-version: '17'
distribution: 'temurin'
cache: '{{ java_build_tool }}'
- name: Install Dependencies
run: {{ install_dependencies_command }}
- name: Install Codeflash
run: pip install codeflash
- name: Codeflash Optimization
run: codeflash

View file

@ -4,6 +4,7 @@ import ast
from collections import defaultdict
from functools import lru_cache
from itertools import chain
from pathlib import Path
from typing import TYPE_CHECKING, Optional, TypeVar
import libcst as cst
@ -732,12 +733,29 @@ def get_optimized_code_for_module(relative_path: Path, optimized_code: CodeStrin
module_optimized_code = file_to_code_context["None"]
logger.debug(f"Using code block with None file_path for {relative_path}")
else:
logger.warning(
f"Optimized code not found for {relative_path} In the context\n-------\n{optimized_code}\n-------\n"
"re-check your 'markdown code structure'"
f"existing files are {file_to_code_context.keys()}"
)
module_optimized_code = ""
# Fallback: try to match by just the filename (for Java/JS where the AI
# might return just the class name like "Algorithms.java" instead of
# the full path like "src/main/java/com/example/Algorithms.java")
target_filename = relative_path.name
for file_path_str, code in file_to_code_context.items():
if file_path_str and Path(file_path_str).name == target_filename:
module_optimized_code = code
logger.debug(f"Matched {file_path_str} to {relative_path} by filename")
break
if module_optimized_code is None:
# Also try matching if there's only one code file
if len(file_to_code_context) == 1:
only_key = next(iter(file_to_code_context.keys()))
module_optimized_code = file_to_code_context[only_key]
logger.debug(f"Using only code block {only_key} for {relative_path}")
else:
logger.warning(
f"Optimized code not found for {relative_path} In the context\n-------\n{optimized_code}\n-------\n"
"re-check your 'markdown code structure'"
f"existing files are {file_to_code_context.keys()}"
)
module_optimized_code = ""
return module_optimized_code

View file

@ -11,6 +11,7 @@ from codeflash.cli_cmds.console import logger
from codeflash.code_utils.code_utils import get_run_tmp_file, module_name_from_file_path
from codeflash.code_utils.formatter import sort_imports
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages import is_java, is_javascript
from codeflash.models.models import FunctionParent, TestingMode, VerificationType
if TYPE_CHECKING:
@ -709,6 +710,21 @@ def inject_profiling_into_existing_test(
tests_project_root: Path,
mode: TestingMode = TestingMode.BEHAVIOR,
) -> tuple[bool, str | None]:
# Route to language-specific implementations
if is_javascript():
from codeflash.languages.javascript.instrument import inject_profiling_into_existing_js_test
return inject_profiling_into_existing_js_test(
test_path, call_positions, function_to_optimize, tests_project_root, mode.value
)
if is_java():
from codeflash.languages.java.instrumentation import instrument_existing_test
return instrument_existing_test(
test_path, call_positions, function_to_optimize, tests_project_root, mode.value
)
if function_to_optimize.is_async:
return inject_async_profiling_into_existing_test(
test_path, call_positions, function_to_optimize, tests_project_root, mode

View file

@ -3,6 +3,13 @@
This module provides functionality to instrument Java code for:
1. Behavior capture - recording inputs/outputs for verification
2. Benchmarking - measuring execution time
Timing instrumentation adds System.nanoTime() calls around the function being tested
and prints timing markers in a format compatible with Python/JS implementations:
Start: !$######testModule:testClass:funcName:loopIndex:iterationId######$!
End: !######testModule:testClass:funcName:loopIndex:iterationId:durationNs######!
This allows codeflash to extract timing data from stdout for accurate benchmarking.
"""
from __future__ import annotations
@ -30,54 +37,21 @@ def _get_function_name(func: Any) -> str:
return func.function_name
raise AttributeError(f"Cannot get function name from {type(func)}")
# Template for behavior capture instrumentation
BEHAVIOR_CAPTURE_IMPORT = "import com.codeflash.CodeFlash;"
BEHAVIOR_CAPTURE_BEFORE = """
// CodeFlash behavior capture - start
long __codeflash_call_id_{call_id} = System.nanoTime();
CodeFlash.recordInput(__codeflash_call_id_{call_id}, "{method_id}", CodeFlash.serialize({args}));
long __codeflash_start_{call_id} = System.nanoTime();
"""
BEHAVIOR_CAPTURE_AFTER_RETURN = """
// CodeFlash behavior capture - end
long __codeflash_end_{call_id} = System.nanoTime();
CodeFlash.recordOutput(__codeflash_call_id_{call_id}, "{method_id}", CodeFlash.serialize(__codeflash_result_{call_id}), __codeflash_end_{call_id} - __codeflash_start_{call_id});
"""
BEHAVIOR_CAPTURE_AFTER_VOID = """
// CodeFlash behavior capture - end
long __codeflash_end_{call_id} = System.nanoTime();
CodeFlash.recordOutput(__codeflash_call_id_{call_id}, "{method_id}", "null", __codeflash_end_{call_id} - __codeflash_start_{call_id});
"""
# Template for benchmark instrumentation
BENCHMARK_IMPORT = """import com.codeflash.Blackhole;
import com.codeflash.BenchmarkContext;
import com.codeflash.BenchmarkResult;"""
BENCHMARK_WRAPPER_TEMPLATE = """
// CodeFlash benchmark wrapper
public void __codeflash_benchmark_{method_name}(int iterations) {{
// Warmup
for (int i = 0; i < Math.min(iterations / 10, 100); i++) {{
{warmup_call}
}}
// Measurement
long[] measurements = new long[iterations];
for (int i = 0; i < iterations; i++) {{
long start = System.nanoTime();
{measurement_call}
long end = System.nanoTime();
measurements[i] = end - start;
}}
BenchmarkResult result = new BenchmarkResult("{method_id}", measurements);
CodeFlash.recordBenchmarkResult("{method_id}", result);
}}
"""
def _get_qualified_name(func: Any) -> str:
"""Get the qualified name from either FunctionInfo or FunctionToOptimize."""
if hasattr(func, "qualified_name"):
return func.qualified_name
# Build qualified name from function_name and parents
if hasattr(func, "function_name"):
parts = []
if hasattr(func, "parents") and func.parents:
for parent in func.parents:
if hasattr(parent, "name"):
parts.append(parent.name)
parts.append(func.function_name)
return ".".join(parts)
return str(func)
def instrument_for_behavior(
@ -87,34 +61,361 @@ def instrument_for_behavior(
) -> str:
"""Add behavior instrumentation to capture inputs/outputs.
Wraps function calls to record arguments and return values
for behavioral verification.
For Java, we don't modify the test file for behavior capture.
Instead, we rely on JUnit test results (pass/fail) to verify correctness.
The test file is returned unchanged.
Args:
source: Source code to instrument.
functions: Functions to add behavior capture.
analyzer: Optional JavaAnalyzer instance.
Returns:
Source code (unchanged for Java).
"""
# For Java, we don't need to instrument tests for behavior capture.
# The JUnit test results (pass/fail) serve as the verification mechanism.
if functions:
func_name = _get_function_name(functions[0])
logger.debug("Java behavior testing for %s - using JUnit pass/fail results", func_name)
return source
def instrument_for_benchmarking(
test_source: str,
target_function: FunctionInfo,
analyzer: JavaAnalyzer | None = None,
) -> str:
"""Add timing instrumentation to test code.
For Java, we rely on Maven Surefire's timing information rather than
modifying the test code. The test file is returned unchanged.
Args:
test_source: Test source code to instrument.
target_function: Function being benchmarked.
analyzer: Optional JavaAnalyzer instance.
Returns:
Test source code (unchanged for Java).
"""
func_name = _get_function_name(target_function)
logger.debug("Java benchmarking for %s - using Maven Surefire timing", func_name)
return test_source
def instrument_existing_test(
test_path: Path,
call_positions: Sequence,
function_to_optimize: Any, # FunctionInfo or FunctionToOptimize
tests_project_root: Path,
mode: str, # "behavior" or "performance"
analyzer: JavaAnalyzer | None = None,
output_class_suffix: str | None = None, # Suffix for renamed class
) -> tuple[bool, str | None]:
"""Inject profiling code into an existing test file.
For Java, this:
1. Renames the class to match the new file name (Java requires class name = file name)
2. Adds timing instrumentation to test methods (for performance mode)
Args:
test_path: Path to the test file.
call_positions: List of code positions where the function is called.
function_to_optimize: The function being optimized.
tests_project_root: Root directory of tests.
mode: Testing mode - "behavior" or "performance".
analyzer: Optional JavaAnalyzer instance.
output_class_suffix: Optional suffix for the renamed class.
Returns:
Tuple of (success, modified_source).
"""
try:
source = test_path.read_text(encoding="utf-8")
except Exception as e:
logger.error("Failed to read test file %s: %s", test_path, e)
return False, f"Failed to read test file: {e}"
func_name = _get_function_name(function_to_optimize)
# Get the original class name from the file name
original_class_name = test_path.stem # e.g., "AlgorithmsTest"
# Determine the new class name based on mode
if mode == "behavior":
new_class_name = f"{original_class_name}__perfinstrumented"
else:
new_class_name = f"{original_class_name}__perfonlyinstrumented"
# Rename the class declaration in the source
# Pattern: "public class ClassName" or "class ClassName"
pattern = rf'\b(public\s+)?class\s+{re.escape(original_class_name)}\b'
replacement = rf'\1class {new_class_name}'
modified_source = re.sub(pattern, replacement, source)
# For performance mode, add timing instrumentation to test methods
if mode == "performance":
modified_source = _add_timing_instrumentation(
modified_source,
new_class_name,
func_name,
)
logger.debug(
"Java %s testing for %s: renamed class %s -> %s",
mode,
func_name,
original_class_name,
new_class_name,
)
return True, modified_source
def _add_timing_instrumentation(source: str, class_name: str, func_name: str) -> str:
"""Add timing instrumentation to test methods.
For each @Test method, this adds:
1. Start timing marker printed at the beginning
2. End timing marker printed at the end (in a finally block)
Timing markers format:
Start: !$######testModule:testClass:funcName:loopIndex:iterationId######$!
End: !######testModule:testClass:funcName:loopIndex:iterationId:durationNs######!
Args:
source: The test source code.
class_name: Name of the test class.
func_name: Name of the function being tested.
Returns:
Instrumented source code.
"""
analyzer = analyzer or get_java_analyzer()
# Find all @Test methods and add timing around their bodies
# Pattern matches: @Test (with optional parameters) followed by method declaration
# We process line by line for cleaner handling
if not functions:
return source
lines = source.split('\n')
result = []
i = 0
iteration_counter = 0
# Add import if not present
if BEHAVIOR_CAPTURE_IMPORT not in source:
source = _add_import(source, BEHAVIOR_CAPTURE_IMPORT)
while i < len(lines):
line = lines[i]
stripped = line.strip()
# Find and instrument each function
for func in functions:
source = _instrument_function_behavior(source, func, analyzer)
# Look for @Test annotation
if stripped.startswith('@Test'):
result.append(line)
i += 1
# Collect any additional annotations
while i < len(lines) and lines[i].strip().startswith('@'):
result.append(lines[i])
i += 1
# Now find the method signature and opening brace
method_lines = []
while i < len(lines):
method_lines.append(lines[i])
if '{' in lines[i]:
break
i += 1
# Add the method signature lines
for ml in method_lines:
result.append(ml)
i += 1
# We're now inside the method body
iteration_counter += 1
iter_id = iteration_counter
# Add timing start code
indent = " "
timing_start_code = [
f"{indent}// Codeflash timing instrumentation",
f'{indent}int _cf_loop{iter_id} = Integer.parseInt(System.getenv("CODEFLASH_LOOP_INDEX") != null ? System.getenv("CODEFLASH_LOOP_INDEX") : "1");',
f"{indent}int _cf_iter{iter_id} = {iter_id};",
f'{indent}String _cf_mod{iter_id} = "{class_name}";',
f'{indent}String _cf_cls{iter_id} = "{class_name}";',
f'{indent}String _cf_fn{iter_id} = "{func_name}";',
f'{indent}System.out.println("!$######" + _cf_mod{iter_id} + ":" + _cf_cls{iter_id} + ":" + _cf_fn{iter_id} + ":" + _cf_loop{iter_id} + ":" + _cf_iter{iter_id} + "######$!");',
f"{indent}long _cf_start{iter_id} = System.nanoTime();",
f"{indent}try {{",
]
result.extend(timing_start_code)
# Collect method body until we find matching closing brace
brace_depth = 1
body_lines = []
while i < len(lines) and brace_depth > 0:
body_line = lines[i]
# Count braces (simple approach - doesn't handle strings/comments perfectly)
for ch in body_line:
if ch == '{':
brace_depth += 1
elif ch == '}':
brace_depth -= 1
if brace_depth > 0:
body_lines.append(body_line)
i += 1
else:
# This line contains the closing brace, but we've hit depth 0
# Add indented body lines
for bl in body_lines:
result.append(" " + bl)
# Add finally block
timing_end_code = [
f"{indent}}} finally {{",
f"{indent} long _cf_end{iter_id} = System.nanoTime();",
f"{indent} long _cf_dur{iter_id} = _cf_end{iter_id} - _cf_start{iter_id};",
f'{indent} System.out.println("!######" + _cf_mod{iter_id} + ":" + _cf_cls{iter_id} + ":" + _cf_fn{iter_id} + ":" + _cf_loop{iter_id} + ":" + _cf_iter{iter_id} + ":" + _cf_dur{iter_id} + "######!");',
f"{indent}}}",
" }", # Method closing brace
]
result.extend(timing_end_code)
i += 1
else:
result.append(line)
i += 1
return '\n'.join(result)
def create_benchmark_test(
target_function: FunctionInfo,
test_setup_code: str,
invocation_code: str,
iterations: int = 1000,
) -> str:
"""Create a benchmark test for a function.
Args:
target_function: The function to benchmark.
test_setup_code: Code to set up the test (create instances, etc.).
invocation_code: Code that invokes the function.
iterations: Number of benchmark iterations.
Returns:
Complete benchmark test source code.
"""
method_name = _get_function_name(target_function)
method_id = _get_qualified_name(target_function)
class_name = getattr(target_function, "class_name", None) or "Target"
benchmark_code = f"""
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.DisplayName;
/**
* Benchmark test for {method_name}.
* Generated by CodeFlash.
*/
public class {class_name}Benchmark {{
@Test
@DisplayName("Benchmark {method_name}")
public void benchmark{method_name.capitalize()}() {{
{test_setup_code}
// Warmup phase
for (int i = 0; i < {iterations // 10}; i++) {{
{invocation_code};
}}
// Measurement phase
long startTime = System.nanoTime();
for (int i = 0; i < {iterations}; i++) {{
{invocation_code};
}}
long endTime = System.nanoTime();
long totalNanos = endTime - startTime;
long avgNanos = totalNanos / {iterations};
System.out.println("CODEFLASH_BENCHMARK:{method_id}:total_ns=" + totalNanos + ",avg_ns=" + avgNanos + ",iterations={iterations}");
}}
}}
"""
return benchmark_code
def remove_instrumentation(source: str) -> str:
"""Remove CodeFlash instrumentation from source code.
For Java, since we don't add instrumentation, this is a no-op.
Args:
source: Source code.
Returns:
Source unchanged.
"""
return source
def instrument_generated_java_test(
test_code: str,
function_name: str,
qualified_name: str,
mode: str, # "behavior" or "performance"
) -> str:
"""Instrument a generated Java test for behavior or performance testing.
Args:
test_code: The generated test source code.
function_name: Name of the function being tested.
qualified_name: Fully qualified name of the function.
mode: "behavior" for behavior capture or "performance" for timing.
Returns:
Instrumented test source code.
"""
# Extract class name from the test code
class_match = re.search(r'\bclass\s+(\w+)', test_code)
if not class_match:
logger.warning("Could not find class name in generated test")
return test_code
original_class_name = class_match.group(1)
# Rename class based on mode
if mode == "behavior":
new_class_name = f"{original_class_name}__perfinstrumented"
else:
new_class_name = f"{original_class_name}__perfonlyinstrumented"
# Rename the class in the source
modified_code = re.sub(
rf'\b(public\s+)?class\s+{re.escape(original_class_name)}\b',
rf'\1class {new_class_name}',
test_code,
)
# For performance mode, add timing instrumentation
if mode == "performance":
modified_code = _add_timing_instrumentation(
modified_code,
new_class_name,
function_name,
)
logger.debug("Instrumented generated Java test for %s (mode=%s)", function_name, mode)
return modified_code
def _add_import(source: str, import_statement: str) -> str:
"""Add an import statement to the source.
@ -142,213 +443,3 @@ def _add_import(source: str, import_statement: str) -> str:
lines.insert(insert_idx, import_statement + "\n")
return "".join(lines)
def _instrument_function_behavior(
source: str,
function: FunctionInfo,
analyzer: JavaAnalyzer,
) -> str:
"""Instrument a single function for behavior capture.
Args:
source: The source code.
function: The function to instrument.
analyzer: JavaAnalyzer instance.
Returns:
Source with function instrumented.
"""
source_bytes = source.encode("utf8")
tree = analyzer.parse(source_bytes)
# Find the method node
methods = analyzer.find_methods(source)
target_method = None
func_name = _get_function_name(function)
for method in methods:
if method.name == func_name:
class_name = getattr(function, "class_name", None)
if class_name is None or method.class_name == class_name:
target_method = method
break
if not target_method:
logger.warning("Could not find method %s for instrumentation", func_name)
return source
# For now, we'll add instrumentation as a simple wrapper
# A full implementation would use AST transformation
method_id = function.qualified_name
call_id = hash(method_id) % 10000
# Build instrumented version
# This is a simplified approach - a full implementation would
# parse the method body and instrument each return statement
logger.debug("Instrumented method %s for behavior capture", function.name)
return source
def instrument_for_benchmarking(
test_source: str,
target_function: FunctionInfo,
analyzer: JavaAnalyzer | None = None,
) -> str:
"""Add timing instrumentation to test code.
Args:
test_source: Test source code to instrument.
target_function: Function being benchmarked.
Returns:
Instrumented test source code.
"""
analyzer = analyzer or get_java_analyzer()
# Add imports if not present
if "import com.codeflash" not in test_source:
test_source = _add_import(test_source, BENCHMARK_IMPORT)
# Find calls to the target function in the test and wrap them
# This is a simplified implementation
logger.debug("Instrumented test for benchmarking %s", _get_function_name(target_function))
return test_source
def instrument_existing_test(
test_path: Path,
call_positions: Sequence,
function_to_optimize: FunctionInfo,
tests_project_root: Path,
mode: str, # "behavior" or "performance"
analyzer: JavaAnalyzer | None = None,
) -> tuple[bool, str | None]:
"""Inject profiling code into an existing test file.
Args:
test_path: Path to the test file.
call_positions: List of code positions where the function is called.
function_to_optimize: The function being optimized.
tests_project_root: Root directory of tests.
mode: Testing mode - "behavior" or "performance".
analyzer: Optional JavaAnalyzer instance.
Returns:
Tuple of (success, instrumented_code or error message).
"""
analyzer = analyzer or get_java_analyzer()
try:
source = test_path.read_text(encoding="utf-8")
except Exception as e:
return False, f"Failed to read test file: {e}"
try:
if mode == "behavior":
instrumented = instrument_for_behavior(source, [function_to_optimize], analyzer)
else:
instrumented = instrument_for_benchmarking(source, function_to_optimize, analyzer)
return True, instrumented
except Exception as e:
logger.exception("Failed to instrument test file: %s", e)
return False, str(e)
def create_benchmark_test(
target_function: FunctionInfo,
test_setup_code: str,
invocation_code: str,
iterations: int = 1000,
) -> str:
"""Create a benchmark test for a function.
Args:
target_function: The function to benchmark.
test_setup_code: Code to set up the test (create instances, etc.).
invocation_code: Code that invokes the function.
iterations: Number of benchmark iterations.
Returns:
Complete benchmark test source code.
"""
method_name = target_function.name
method_id = target_function.qualified_name
benchmark_code = f"""
import com.codeflash.Blackhole;
import com.codeflash.BenchmarkContext;
import com.codeflash.BenchmarkResult;
import com.codeflash.CodeFlash;
import org.junit.jupiter.api.Test;
public class {target_function.class_name or 'Target'}Benchmark {{
@Test
public void benchmark{method_name.capitalize()}() {{
{test_setup_code}
// Warmup phase
for (int i = 0; i < {iterations // 10}; i++) {{
Blackhole.consume({invocation_code});
}}
// Measurement phase
long[] measurements = new long[{iterations}];
for (int i = 0; i < {iterations}; i++) {{
long start = System.nanoTime();
Blackhole.consume({invocation_code});
long end = System.nanoTime();
measurements[i] = end - start;
}}
BenchmarkResult result = new BenchmarkResult("{method_id}", measurements);
CodeFlash.recordBenchmarkResult("{method_id}", result);
System.out.println("Benchmark complete: " + result);
}}
}}
"""
return benchmark_code
def remove_instrumentation(source: str) -> str:
"""Remove CodeFlash instrumentation from source code.
Args:
source: Instrumented source code.
Returns:
Source with instrumentation removed.
"""
lines = source.splitlines(keepends=True)
result_lines = []
skip_until_end = False
for line in lines:
stripped = line.strip()
# Skip CodeFlash instrumentation blocks
if "// CodeFlash" in stripped and "start" in stripped:
skip_until_end = True
continue
if skip_until_end:
if "// CodeFlash" in stripped and "end" in stripped:
skip_until_end = False
continue
# Skip CodeFlash imports
if "import com.codeflash" in stripped:
continue
result_lines.append(line)
return "".join(result_lines)

View file

@ -0,0 +1,386 @@
package codeflash.runtime;
import java.io.File;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
/**
* Codeflash Helper - Test Instrumentation for Java
*
* This class provides timing instrumentation for Java tests, mirroring the
* behavior of the JavaScript codeflash package.
*
* Usage in instrumented tests:
* import codeflash.runtime.CodeflashHelper;
*
* // For behavior verification (writes to SQLite):
* Object result = CodeflashHelper.capture("testModule", "testClass", "testFunc",
* "funcName", () -> targetMethod(arg1, arg2));
*
* // For performance benchmarking:
* Object result = CodeflashHelper.capturePerf("testModule", "testClass", "testFunc",
* "funcName", () -> targetMethod(arg1, arg2));
*
* Environment Variables:
* CODEFLASH_OUTPUT_FILE - Path to write results SQLite file
* CODEFLASH_LOOP_INDEX - Current benchmark loop iteration (default: 1)
* CODEFLASH_TEST_ITERATION - Test iteration number (default: 0)
* CODEFLASH_MODE - "behavior" or "performance"
*/
public class CodeflashHelper {
private static final String OUTPUT_FILE = System.getenv("CODEFLASH_OUTPUT_FILE");
private static final int LOOP_INDEX = parseIntOrDefault(System.getenv("CODEFLASH_LOOP_INDEX"), 1);
private static final String MODE = System.getenv("CODEFLASH_MODE");
// Track invocation counts per test method for unique iteration IDs
private static final ConcurrentHashMap<String, AtomicInteger> invocationCounts = new ConcurrentHashMap<>();
// Database connection (lazily initialized)
private static Connection dbConnection = null;
private static boolean dbInitialized = false;
/**
* Functional interface for wrapping void method calls.
*/
@FunctionalInterface
public interface VoidCallable {
void call() throws Exception;
}
/**
* Functional interface for wrapping method calls that return a value.
*/
@FunctionalInterface
public interface Callable<T> {
T call() throws Exception;
}
/**
* Capture behavior and timing for a method call that returns a value.
*/
public static <T> T capture(
String testModulePath,
String testClassName,
String testFunctionName,
String functionGettingTested,
Callable<T> callable
) throws Exception {
String invocationKey = testModulePath + ":" + testClassName + ":" + testFunctionName + ":" + functionGettingTested;
int iterationId = getNextIterationId(invocationKey);
long startTime = System.nanoTime();
T result;
try {
result = callable.call();
} finally {
long endTime = System.nanoTime();
long durationNs = endTime - startTime;
// Write to SQLite for behavior verification
writeResultToSqlite(
testModulePath,
testClassName,
testFunctionName,
functionGettingTested,
LOOP_INDEX,
iterationId,
durationNs,
null, // return_value - TODO: serialize if needed
"output"
);
// Print timing marker for stdout parsing (backup method)
printTimingMarker(testModulePath, testClassName, functionGettingTested, LOOP_INDEX, iterationId, durationNs);
}
return result;
}
/**
* Capture behavior and timing for a void method call.
*/
public static void captureVoid(
String testModulePath,
String testClassName,
String testFunctionName,
String functionGettingTested,
VoidCallable callable
) throws Exception {
String invocationKey = testModulePath + ":" + testClassName + ":" + testFunctionName + ":" + functionGettingTested;
int iterationId = getNextIterationId(invocationKey);
long startTime = System.nanoTime();
try {
callable.call();
} finally {
long endTime = System.nanoTime();
long durationNs = endTime - startTime;
// Write to SQLite
writeResultToSqlite(
testModulePath,
testClassName,
testFunctionName,
functionGettingTested,
LOOP_INDEX,
iterationId,
durationNs,
null,
"output"
);
// Print timing marker
printTimingMarker(testModulePath, testClassName, functionGettingTested, LOOP_INDEX, iterationId, durationNs);
}
}
/**
* Capture timing for performance benchmarking (method with return value).
*/
public static <T> T capturePerf(
String testModulePath,
String testClassName,
String testFunctionName,
String functionGettingTested,
Callable<T> callable
) throws Exception {
String invocationKey = testModulePath + ":" + testClassName + ":" + testFunctionName + ":" + functionGettingTested;
int iterationId = getNextIterationId(invocationKey);
// Print start marker
printStartMarker(testModulePath, testClassName, functionGettingTested, LOOP_INDEX, iterationId);
long startTime = System.nanoTime();
T result;
try {
result = callable.call();
} finally {
long endTime = System.nanoTime();
long durationNs = endTime - startTime;
// Write to SQLite for performance data
writeResultToSqlite(
testModulePath,
testClassName,
testFunctionName,
functionGettingTested,
LOOP_INDEX,
iterationId,
durationNs,
null,
"output"
);
// Print end marker with timing
printTimingMarker(testModulePath, testClassName, functionGettingTested, LOOP_INDEX, iterationId, durationNs);
}
return result;
}
/**
* Capture timing for performance benchmarking (void method).
*/
public static void capturePerfVoid(
String testModulePath,
String testClassName,
String testFunctionName,
String functionGettingTested,
VoidCallable callable
) throws Exception {
String invocationKey = testModulePath + ":" + testClassName + ":" + testFunctionName + ":" + functionGettingTested;
int iterationId = getNextIterationId(invocationKey);
// Print start marker
printStartMarker(testModulePath, testClassName, functionGettingTested, LOOP_INDEX, iterationId);
long startTime = System.nanoTime();
try {
callable.call();
} finally {
long endTime = System.nanoTime();
long durationNs = endTime - startTime;
// Write to SQLite
writeResultToSqlite(
testModulePath,
testClassName,
testFunctionName,
functionGettingTested,
LOOP_INDEX,
iterationId,
durationNs,
null,
"output"
);
// Print end marker with timing
printTimingMarker(testModulePath, testClassName, functionGettingTested, LOOP_INDEX, iterationId, durationNs);
}
}
/**
* Get the next iteration ID for a given invocation key.
*/
private static int getNextIterationId(String invocationKey) {
return invocationCounts.computeIfAbsent(invocationKey, k -> new AtomicInteger(0)).incrementAndGet();
}
/**
* Print timing marker to stdout (format matches Python/JS).
* Format: !######testModule:testClass:funcName:loopIndex:iterationId:durationNs######!
*/
private static void printTimingMarker(
String testModule,
String testClass,
String funcName,
int loopIndex,
int iterationId,
long durationNs
) {
System.out.println("!######" + testModule + ":" + testClass + ":" + funcName + ":" +
loopIndex + ":" + iterationId + ":" + durationNs + "######!");
}
/**
* Print start marker for performance tests.
* Format: !$######testModule:testClass:funcName:loopIndex:iterationId######$!
*/
private static void printStartMarker(
String testModule,
String testClass,
String funcName,
int loopIndex,
int iterationId
) {
System.out.println("!$######" + testModule + ":" + testClass + ":" + funcName + ":" +
loopIndex + ":" + iterationId + "######$!");
}
/**
* Write test result to SQLite database.
*/
private static synchronized void writeResultToSqlite(
String testModulePath,
String testClassName,
String testFunctionName,
String functionGettingTested,
int loopIndex,
int iterationId,
long runtime,
byte[] returnValue,
String verificationType
) {
if (OUTPUT_FILE == null || OUTPUT_FILE.isEmpty()) {
return;
}
try {
ensureDbInitialized();
if (dbConnection == null) {
return;
}
String sql = "INSERT INTO test_results " +
"(test_module_path, test_class_name, test_function_name, function_getting_tested, " +
"loop_index, iteration_id, runtime, return_value, verification_type) " +
"VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)";
try (PreparedStatement stmt = dbConnection.prepareStatement(sql)) {
stmt.setString(1, testModulePath);
stmt.setString(2, testClassName);
stmt.setString(3, testFunctionName);
stmt.setString(4, functionGettingTested);
stmt.setInt(5, loopIndex);
stmt.setInt(6, iterationId);
stmt.setLong(7, runtime);
stmt.setBytes(8, returnValue);
stmt.setString(9, verificationType);
stmt.executeUpdate();
}
} catch (SQLException e) {
System.err.println("CodeflashHelper: Failed to write to SQLite: " + e.getMessage());
}
}
/**
* Ensure the database is initialized.
*/
private static void ensureDbInitialized() {
if (dbInitialized) {
return;
}
dbInitialized = true;
if (OUTPUT_FILE == null || OUTPUT_FILE.isEmpty()) {
return;
}
try {
// Load SQLite JDBC driver
Class.forName("org.sqlite.JDBC");
// Create parent directories if needed
File dbFile = new File(OUTPUT_FILE);
File parentDir = dbFile.getParentFile();
if (parentDir != null && !parentDir.exists()) {
parentDir.mkdirs();
}
// Connect to database
dbConnection = DriverManager.getConnection("jdbc:sqlite:" + OUTPUT_FILE);
// Create table if not exists
String createTableSql = "CREATE TABLE IF NOT EXISTS test_results (" +
"test_module_path TEXT, " +
"test_class_name TEXT, " +
"test_function_name TEXT, " +
"function_getting_tested TEXT, " +
"loop_index INTEGER, " +
"iteration_id INTEGER, " +
"runtime INTEGER, " +
"return_value BLOB, " +
"verification_type TEXT" +
")";
try (Statement stmt = dbConnection.createStatement()) {
stmt.execute(createTableSql);
}
// Register shutdown hook to close connection
Runtime.getRuntime().addShutdownHook(new Thread(() -> {
try {
if (dbConnection != null && !dbConnection.isClosed()) {
dbConnection.close();
}
} catch (SQLException e) {
// Ignore
}
}));
} catch (ClassNotFoundException e) {
System.err.println("CodeflashHelper: SQLite JDBC driver not found. " +
"Add sqlite-jdbc to your dependencies. Timing will still be captured via stdout.");
} catch (SQLException e) {
System.err.println("CodeflashHelper: Failed to initialize SQLite: " + e.getMessage());
}
}
/**
* Parse int with default value.
*/
private static int parseIntOrDefault(String value, int defaultValue) {
if (value == null || value.isEmpty()) {
return defaultValue;
}
try {
return Integer.parseInt(value);
} catch (NumberFormatException e) {
return defaultValue;
}
}
}

View file

@ -98,6 +98,12 @@ class JavaSupport(LanguageSupport):
"""Find all optimizable functions in a Java file."""
return discover_functions(file_path, filter_criteria, self._analyzer)
def discover_functions_from_source(
self, source: str, file_path: Path | None = None, filter_criteria: FunctionFilterCriteria | None = None
) -> list[FunctionInfo]:
"""Find all optimizable functions in Java source code."""
return discover_functions_from_source(source, file_path, filter_criteria, self._analyzer)
def discover_tests(
self, test_root: Path, source_functions: Sequence[FunctionInfo]
) -> dict[str, list[TestInfo]]:

View file

@ -8,6 +8,7 @@ from __future__ import annotations
import logging
import os
import shutil
import subprocess
import tempfile
import uuid
@ -57,6 +58,7 @@ def run_behavioral_tests(
"""Run behavioral tests for Java code.
This runs tests and captures behavior (inputs/outputs) for verification.
For Java, verification is based on JUnit test pass/fail results.
Args:
test_paths: TestFiles object or list of test file paths.
@ -68,20 +70,17 @@ def run_behavioral_tests(
candidate_index: Index of the candidate being tested.
Returns:
Tuple of (result_file_path, subprocess_result, coverage_path, config_path).
Tuple of (result_xml_path, subprocess_result, coverage_path, config_path).
"""
project_root = project_root or cwd
# Generate unique result file path
result_id = uuid.uuid4().hex[:8]
result_file = Path(tempfile.gettempdir()) / f"codeflash_java_behavior_{result_id}.db"
# Set environment variables for CodeFlash runtime
# Set environment variables for timing instrumentation
run_env = os.environ.copy()
run_env.update(test_env)
run_env["CODEFLASH_RESULT_FILE"] = str(result_file)
run_env["CODEFLASH_LOOP_INDEX"] = "1" # Single loop for behavior tests
run_env["CODEFLASH_MODE"] = "behavior"
run_env["CODEFLASH_TEST_ITERATION"] = str(candidate_index)
# Run Maven tests
result = _run_maven_tests(
@ -89,9 +88,14 @@ def run_behavioral_tests(
test_paths,
run_env,
timeout=timeout or 300,
mode="behavior",
)
return result_file, result, None, None
# Find or create the JUnit XML results file
surefire_dir = project_root / "target" / "surefire-reports"
result_xml_path = _get_combined_junit_xml(surefire_dir, candidate_index)
return result_xml_path, result, None, None
def run_benchmarking_tests(
@ -101,12 +105,15 @@ def run_benchmarking_tests(
timeout: int | None = None,
project_root: Path | None = None,
min_loops: int = 5,
max_loops: int = 100_000,
max_loops: int = 100,
target_duration_seconds: float = 10.0,
) -> tuple[Path, Any]:
"""Run benchmarking tests for Java code.
This runs tests with performance measurement.
This runs tests multiple times with performance measurement.
The instrumented tests print timing markers that are parsed from stdout:
Start: !$######testModule:testClass:funcName:loopIndex:iterationId######$!
End: !######testModule:testClass:funcName:loopIndex:iterationId:durationNs######!
Args:
test_paths: TestFiles object or list of test file paths.
@ -119,33 +126,182 @@ def run_benchmarking_tests(
target_duration_seconds: Target duration for benchmarking in seconds.
Returns:
Tuple of (result_file_path, subprocess_result).
Tuple of (result_file_path, subprocess_result with aggregated stdout).
"""
import time
project_root = project_root or cwd
# Generate unique result file path
result_id = uuid.uuid4().hex[:8]
result_file = Path(tempfile.gettempdir()) / f"codeflash_java_benchmark_{result_id}.db"
# Collect stdout from all loops
all_stdout = []
all_stderr = []
total_start_time = time.time()
loop_count = 0
last_result = None
# Set environment variables
run_env = os.environ.copy()
run_env.update(test_env)
run_env["CODEFLASH_RESULT_FILE"] = str(result_file)
run_env["CODEFLASH_MODE"] = "benchmark"
run_env["CODEFLASH_MIN_LOOPS"] = str(min_loops)
run_env["CODEFLASH_MAX_LOOPS"] = str(max_loops)
run_env["CODEFLASH_TARGET_DURATION"] = str(target_duration_seconds)
# Run multiple loops until we hit target duration or max loops
for loop_idx in range(1, max_loops + 1):
# Set environment variables for this loop
run_env = os.environ.copy()
run_env.update(test_env)
run_env["CODEFLASH_LOOP_INDEX"] = str(loop_idx)
run_env["CODEFLASH_MODE"] = "performance"
run_env["CODEFLASH_TEST_ITERATION"] = "0"
# Run Maven tests
result = _run_maven_tests(
project_root,
test_paths,
run_env,
timeout=timeout or 600, # Longer timeout for benchmarks
# Run Maven tests for this loop
result = _run_maven_tests(
project_root,
test_paths,
run_env,
timeout=timeout or 120, # Per-loop timeout
mode="performance",
)
last_result = result
loop_count = loop_idx
# Collect stdout/stderr
if result.stdout:
all_stdout.append(result.stdout)
if result.stderr:
all_stderr.append(result.stderr)
# Check if we've hit the target duration
elapsed = time.time() - total_start_time
if loop_idx >= min_loops and elapsed >= target_duration_seconds:
logger.debug(
"Stopping benchmark after %d loops (%.2fs elapsed, target: %.2fs)",
loop_idx,
elapsed,
target_duration_seconds,
)
break
# Check if tests failed - don't continue looping
if result.returncode != 0:
logger.warning("Tests failed in loop %d, stopping benchmark", loop_idx)
break
# Create a combined result with all stdout
combined_stdout = "\n".join(all_stdout)
combined_stderr = "\n".join(all_stderr)
logger.debug(
"Completed %d benchmark loops in %.2fs",
loop_count,
time.time() - total_start_time,
)
return result_file, result
# Create a combined subprocess result
combined_result = subprocess.CompletedProcess(
args=last_result.args if last_result else ["mvn", "test"],
returncode=last_result.returncode if last_result else -1,
stdout=combined_stdout,
stderr=combined_stderr,
)
# Find or create the JUnit XML results file (from last run)
surefire_dir = project_root / "target" / "surefire-reports"
result_xml_path = _get_combined_junit_xml(surefire_dir, -1) # Use -1 for benchmark
return result_xml_path, combined_result
def _get_combined_junit_xml(surefire_dir: Path, candidate_index: int) -> Path:
"""Get or create a combined JUnit XML file from Surefire reports.
Args:
surefire_dir: Directory containing Surefire reports.
candidate_index: Index for unique naming.
Returns:
Path to the combined JUnit XML file.
"""
# Create a temp file for the combined results
result_id = uuid.uuid4().hex[:8]
result_xml_path = Path(tempfile.gettempdir()) / f"codeflash_java_results_{candidate_index}_{result_id}.xml"
if not surefire_dir.exists():
# Create an empty results file
_write_empty_junit_xml(result_xml_path)
return result_xml_path
# Find all TEST-*.xml files
xml_files = list(surefire_dir.glob("TEST-*.xml"))
if not xml_files:
_write_empty_junit_xml(result_xml_path)
return result_xml_path
if len(xml_files) == 1:
# Copy the single file
shutil.copy(xml_files[0], result_xml_path)
return result_xml_path
# Combine multiple XML files into one
_combine_junit_xml_files(xml_files, result_xml_path)
return result_xml_path
def _write_empty_junit_xml(path: Path) -> None:
"""Write an empty JUnit XML results file."""
xml_content = '''<?xml version="1.0" encoding="UTF-8"?>
<testsuite name="NoTests" tests="0" failures="0" errors="0" skipped="0" time="0">
</testsuite>
'''
path.write_text(xml_content, encoding="utf-8")
def _combine_junit_xml_files(xml_files: list[Path], output_path: Path) -> None:
"""Combine multiple JUnit XML files into one.
Args:
xml_files: List of XML files to combine.
output_path: Path for the combined output.
"""
total_tests = 0
total_failures = 0
total_errors = 0
total_skipped = 0
total_time = 0.0
all_testcases = []
for xml_file in xml_files:
try:
tree = ET.parse(xml_file)
root = tree.getroot()
# Get testsuite attributes
total_tests += int(root.get("tests", 0))
total_failures += int(root.get("failures", 0))
total_errors += int(root.get("errors", 0))
total_skipped += int(root.get("skipped", 0))
total_time += float(root.get("time", 0))
# Collect all testcases
for testcase in root.findall(".//testcase"):
all_testcases.append(testcase)
except Exception as e:
logger.warning("Failed to parse %s: %s", xml_file, e)
# Create combined XML
combined_root = ET.Element("testsuite")
combined_root.set("name", "CombinedTests")
combined_root.set("tests", str(total_tests))
combined_root.set("failures", str(total_failures))
combined_root.set("errors", str(total_errors))
combined_root.set("skipped", str(total_skipped))
combined_root.set("time", str(total_time))
for testcase in all_testcases:
combined_root.append(testcase)
tree = ET.ElementTree(combined_root)
tree.write(output_path, encoding="unicode", xml_declaration=True)
def _run_maven_tests(
@ -153,6 +309,7 @@ def _run_maven_tests(
test_paths: Any,
env: dict[str, str],
timeout: int = 300,
mode: str = "behavior",
) -> subprocess.CompletedProcess:
"""Run Maven tests with Surefire.
@ -161,6 +318,7 @@ def _run_maven_tests(
test_paths: Test files or classes to run.
env: Environment variables.
timeout: Maximum execution time in seconds.
mode: Testing mode - "behavior" or "performance".
Returns:
CompletedProcess with test results.
@ -177,7 +335,7 @@ def _run_maven_tests(
)
# Build test filter
test_filter = _build_test_filter(test_paths)
test_filter = _build_test_filter(test_paths, mode=mode)
# Build Maven command
cmd = [mvn, "test", "-fae"] # Fail at end to run all tests
@ -185,6 +343,8 @@ def _run_maven_tests(
if test_filter:
cmd.append(f"-Dtest={test_filter}")
logger.debug("Running Maven command: %s in %s", " ".join(cmd), project_root)
try:
result = subprocess.run(
cmd,
@ -215,11 +375,12 @@ def _run_maven_tests(
)
def _build_test_filter(test_paths: Any) -> str:
def _build_test_filter(test_paths: Any, mode: str = "behavior") -> str:
"""Build a Maven Surefire test filter from test paths.
Args:
test_paths: Test files, classes, or methods to include.
mode: Testing mode - "behavior" or "performance".
Returns:
Surefire test filter string.
@ -243,7 +404,21 @@ def _build_test_filter(test_paths: Any) -> str:
# Handle TestFiles object (has test_files attribute)
if hasattr(test_paths, "test_files"):
return _build_test_filter(list(test_paths.test_files))
filters = []
for test_file in test_paths.test_files:
# For performance mode, use benchmarking_file_path
if mode == "performance":
if hasattr(test_file, "benchmarking_file_path") and test_file.benchmarking_file_path:
class_name = _path_to_class_name(test_file.benchmarking_file_path)
if class_name:
filters.append(class_name)
else:
# For behavior mode, use instrumented_behavior_file_path
if hasattr(test_file, "instrumented_behavior_file_path") and test_file.instrumented_behavior_file_path:
class_name = _path_to_class_name(test_file.instrumented_behavior_file_path)
if class_name:
filters.append(class_name)
return ",".join(filters) if filters else ""
return ""
@ -263,19 +438,31 @@ def _path_to_class_name(path: Path) -> str | None:
# Try to extract package from path
# e.g., src/test/java/com/example/CalculatorTest.java -> com.example.CalculatorTest
parts = path.parts
parts = list(path.parts)
# Find 'java' in the path and take everything after
try:
java_idx = parts.index("java")
class_parts = parts[java_idx + 1 :]
# Look for standard Maven/Gradle source directories
# Find 'java' that comes after 'main' or 'test'
java_idx = None
for i, part in enumerate(parts):
if part == "java" and i > 0 and parts[i - 1] in ("main", "test"):
java_idx = i
break
# If no standard Maven structure, find the last 'java' in path
if java_idx is None:
for i in range(len(parts) - 1, -1, -1):
if parts[i] == "java":
java_idx = i
break
if java_idx is not None:
class_parts = parts[java_idx + 1:]
# Remove .java extension from last part
class_parts = list(class_parts)
class_parts[-1] = class_parts[-1].replace(".java", "")
return ".".join(class_parts)
except ValueError:
# No 'java' directory, just use the file name
return path.stem
# Fallback: just use the file name
return path.stem
def run_tests(

View file

@ -76,7 +76,7 @@ from codeflash.context import code_context_extractor
from codeflash.context.unused_definition_remover import detect_unused_helper_functions, revert_unused_helper_functions
from codeflash.discovery.functions_to_optimize import was_function_previously_optimized
from codeflash.either import Failure, Success, is_successful
from codeflash.languages import is_python
from codeflash.languages import is_java, is_python
from codeflash.languages.base import FunctionInfo, Language
from codeflash.languages.current import current_language_support, is_typescript
from codeflash.languages.javascript.module_system import detect_module_system
@ -577,17 +577,29 @@ class FunctionOptimizer:
logger.debug(f"[PIPELINE] Processing {count_tests} generated tests")
for i, generated_test in enumerate(generated_tests.generated_tests):
behavior_path = generated_test.behavior_file_path
perf_path = generated_test.perf_file_path
# For Java, fix paths to match package structure
if is_java():
behavior_path, perf_path = self._fix_java_test_paths(
generated_test.instrumented_behavior_test_source,
generated_test.instrumented_perf_test_source,
)
generated_test.behavior_file_path = behavior_path
generated_test.perf_file_path = perf_path
logger.debug(
f"[PIPELINE] Test {i + 1}: behavior_path={generated_test.behavior_file_path}, perf_path={generated_test.perf_file_path}"
f"[PIPELINE] Test {i + 1}: behavior_path={behavior_path}, perf_path={perf_path}"
)
with generated_test.behavior_file_path.open("w", encoding="utf8") as f:
with behavior_path.open("w", encoding="utf8") as f:
f.write(generated_test.instrumented_behavior_test_source)
logger.debug(f"[PIPELINE] Wrote behavioral test to {generated_test.behavior_file_path}")
logger.debug(f"[PIPELINE] Wrote behavioral test to {behavior_path}")
with generated_test.perf_file_path.open("w", encoding="utf8") as f:
with perf_path.open("w", encoding="utf8") as f:
f.write(generated_test.instrumented_perf_test_source)
logger.debug(f"[PIPELINE] Wrote perf test to {generated_test.perf_file_path}")
logger.debug(f"[PIPELINE] Wrote perf test to {perf_path}")
# File paths are expected to be absolute - resolved at their source (CLI, TestConfig, etc.)
test_file_obj = TestFile(
@ -640,6 +652,55 @@ class FunctionOptimizer:
)
)
def _fix_java_test_paths(
self, behavior_source: str, perf_source: str
) -> tuple[Path, Path]:
"""Fix Java test file paths to match package structure.
Java requires test files to be in directories matching their package.
This method extracts the package and class from the generated tests
and returns correct paths.
Args:
behavior_source: Source code of the behavior test.
perf_source: Source code of the performance test.
Returns:
Tuple of (behavior_path, perf_path) with correct package structure.
"""
import re
# Extract package from behavior source
package_match = re.search(r'^\s*package\s+([\w.]+)\s*;', behavior_source, re.MULTILINE)
package_name = package_match.group(1) if package_match else ""
# Extract class name from behavior source
class_match = re.search(r'\bclass\s+(\w+)', behavior_source)
behavior_class = class_match.group(1) if class_match else "GeneratedTest"
# Extract class name from perf source
perf_class_match = re.search(r'\bclass\s+(\w+)', perf_source)
perf_class = perf_class_match.group(1) if perf_class_match else "GeneratedPerfTest"
# Build paths with package structure
test_dir = self.test_cfg.tests_root
if package_name:
package_path = package_name.replace(".", "/")
behavior_path = test_dir / package_path / f"{behavior_class}.java"
perf_path = test_dir / package_path / f"{perf_class}.java"
else:
behavior_path = test_dir / f"{behavior_class}.java"
perf_path = test_dir / f"{perf_class}.java"
# Create directories if needed
behavior_path.parent.mkdir(parents=True, exist_ok=True)
perf_path.parent.mkdir(parents=True, exist_ok=True)
logger.debug(f"[JAVA] Fixed paths: behavior={behavior_path}, perf={perf_path}")
return behavior_path, perf_path
# note: this isn't called by the lsp, only called by cli
def optimize_function(self) -> Result[BestOptimization, str]:
initialization_result = self.can_be_optimized()

View file

@ -204,7 +204,15 @@ def quantity_of_tests_critic(candidate_result: OptimizedCandidateResult | Origin
def coverage_critic(original_code_coverage: CoverageData | None) -> bool:
"""Check if the coverage meets the threshold."""
"""Check if the coverage meets the threshold.
For languages without coverage support (like Java), returns True if no coverage data is available.
"""
from codeflash.languages import is_java, is_javascript
if original_code_coverage:
return original_code_coverage.coverage >= COVERAGE_THRESHOLD
# For Java/JavaScript, coverage is not implemented yet, so skip the check
if is_java() or is_javascript():
return True
return False

View file

@ -21,7 +21,7 @@ from codeflash.code_utils.code_utils import (
module_name_from_file_path,
)
from codeflash.discovery.discover_unit_tests import discover_parameters_unittest
from codeflash.languages import is_javascript
from codeflash.languages import is_java, is_javascript
from codeflash.models.models import (
ConcurrencyMetrics,
FunctionTestInvocation,
@ -128,7 +128,7 @@ def parse_concurrency_metrics(test_results: TestResults, function_name: str) ->
def resolve_test_file_from_class_path(test_class_path: str, base_dir: Path) -> Path | None:
"""Resolve test file path from pytest's test class path.
"""Resolve test file path from pytest's test class path or Java class path.
This function handles various cases where pytest's classname in JUnit XML
includes parent directories that may already be part of base_dir.
@ -136,6 +136,7 @@ def resolve_test_file_from_class_path(test_class_path: str, base_dir: Path) -> P
Args:
test_class_path: The full class path from pytest (e.g., "project.tests.test_file.TestClass")
or a file path from Jest (e.g., "tests/test_file.test.js")
or a Java class path (e.g., "com.example.AlgorithmsTest")
base_dir: The base directory for tests (tests project root)
Returns:
@ -147,6 +148,35 @@ def resolve_test_file_from_class_path(test_class_path: str, base_dir: Path) -> P
>>> # Should find: /path/to/tests/unittest/test_file.py
"""
# Handle Java class paths (convert dots to path and add .java extension)
# Java class paths look like "com.example.TestClass" and should map to
# src/test/java/com/example/TestClass.java
if is_java():
# Convert dots to path separators
relative_path = test_class_path.replace(".", "/") + ".java"
# Try various locations
# 1. Directly under base_dir
potential_path = base_dir / relative_path
if potential_path.exists():
return potential_path
# 2. Under src/test/java relative to project root
project_root = base_dir.parent if base_dir.name == "java" else base_dir
while project_root.name not in ("", "/") and not (project_root / "pom.xml").exists():
project_root = project_root.parent
if (project_root / "pom.xml").exists():
potential_path = project_root / "src" / "test" / "java" / relative_path
if potential_path.exists():
return potential_path
# 3. Search for the file in base_dir and its subdirectories
file_name = test_class_path.split(".")[-1] + ".java"
for java_file in base_dir.rglob(file_name):
return java_file
return None
# Handle file paths (contain slashes and extensions like .js/.ts)
if "/" in test_class_path or "\\" in test_class_path:
# This is a file path, not a Python module path
@ -997,6 +1027,19 @@ def parse_test_xml(
end_matches[groups] = match
if not begin_matches or not begin_matches:
# For Java tests, use the JUnit XML time attribute for runtime
runtime_from_xml = None
if is_java():
try:
# JUnit XML time is in seconds, convert to nanoseconds
# Use a minimum of 1000ns (1 microsecond) for any successful test
# to avoid 0 runtime being treated as "no runtime"
test_time = float(testcase.time) if hasattr(testcase, 'time') and testcase.time else 0.0
runtime_from_xml = max(int(test_time * 1_000_000_000), 1000)
except (ValueError, TypeError):
# If we can't get time from XML, use 1 microsecond as minimum
runtime_from_xml = 1000
test_results.add(
FunctionTestInvocation(
loop_index=loop_index,
@ -1008,7 +1051,7 @@ def parse_test_xml(
iteration_id="",
),
file_name=test_file_path,
runtime=None,
runtime=runtime_from_xml,
test_framework=test_config.test_framework,
did_pass=result,
test_type=test_type,

View file

@ -9,9 +9,16 @@ from pydantic.dataclasses import dataclass
from codeflash.languages import current_language_support, is_java, is_javascript
def get_test_file_path(test_dir: Path, function_name: str, iteration: int = 0, test_type: str = "unit") -> Path:
def get_test_file_path(
test_dir: Path,
function_name: str,
iteration: int = 0,
test_type: str = "unit",
package_name: str | None = None,
class_name: str | None = None,
) -> Path:
assert test_type in {"unit", "inspired", "replay", "perf"}
function_name = function_name.replace(".", "_")
function_name_safe = function_name.replace(".", "_")
# Use appropriate file extension based on language
if is_javascript():
extension = current_language_support().get_test_file_suffix()
@ -19,9 +26,25 @@ def get_test_file_path(test_dir: Path, function_name: str, iteration: int = 0, t
extension = ".java"
else:
extension = ".py"
path = test_dir / f"test_{function_name}__{test_type}_test_{iteration}{extension}"
if is_java() and package_name:
# For Java, create package directory structure
# e.g., com.example -> com/example/
package_path = package_name.replace(".", "/")
java_class_name = class_name or f"{function_name_safe.title()}Test"
# Add suffix to avoid conflicts
if test_type == "perf":
java_class_name = f"{java_class_name}__perfonlyinstrumented"
elif test_type == "unit":
java_class_name = f"{java_class_name}__perfinstrumented"
path = test_dir / package_path / f"{java_class_name}{extension}"
# Create package directory if needed
path.parent.mkdir(parents=True, exist_ok=True)
else:
path = test_dir / f"test_{function_name_safe}__{test_type}_test_{iteration}{extension}"
if path.exists():
return get_test_file_path(test_dir, function_name, iteration + 1, test_type)
return get_test_file_path(test_dir, function_name, iteration + 1, test_type, package_name, class_name)
return path

View file

@ -7,7 +7,7 @@ from typing import TYPE_CHECKING
from codeflash.cli_cmds.console import logger
from codeflash.code_utils.code_utils import get_run_tmp_file, module_name_from_file_path
from codeflash.languages import is_javascript
from codeflash.languages import is_java, is_javascript
from codeflash.verification.verification_utils import ModifyInspiredTests, delete_multiple_if_name_main
if TYPE_CHECKING:
@ -98,6 +98,29 @@ def generate_tests(
)
logger.debug(f"Instrumented JS/TS tests locally for {func_name}")
elif is_java():
from codeflash.languages.java.instrumentation import instrument_generated_java_test
func_name = function_to_optimize.function_name
qualified_name = function_to_optimize.qualified_name
# Instrument for behavior verification (renames class)
instrumented_behavior_test_source = instrument_generated_java_test(
test_code=generated_test_source,
function_name=func_name,
qualified_name=qualified_name,
mode="behavior",
)
# Instrument for performance measurement (adds timing markers)
instrumented_perf_test_source = instrument_generated_java_test(
test_code=generated_test_source,
function_name=func_name,
qualified_name=qualified_name,
mode="performance",
)
logger.debug(f"Instrumented Java tests locally for {func_name}")
else:
# Python: instrumentation is done by aiservice, just replace temp dir placeholders
instrumented_behavior_test_source = instrumented_behavior_test_source.replace(

File diff suppressed because it is too large Load diff

17
uv.lock
View file

@ -438,6 +438,7 @@ dependencies = [
{ name = "tomlkit" },
{ name = "tree-sitter", version = "0.23.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" },
{ name = "tree-sitter", version = "0.25.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.10'" },
{ name = "tree-sitter-java" },
{ name = "tree-sitter-javascript", version = "0.23.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" },
{ name = "tree-sitter-javascript", version = "0.25.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.10'" },
{ name = "tree-sitter-typescript" },
@ -526,6 +527,7 @@ requires-dist = [
{ name = "sentry-sdk", specifier = ">=1.40.6,<3.0.0" },
{ name = "tomlkit", specifier = ">=0.11.7" },
{ name = "tree-sitter", specifier = ">=0.23.0" },
{ name = "tree-sitter-java", specifier = ">=0.23.0" },
{ name = "tree-sitter-javascript", specifier = ">=0.23.0" },
{ name = "tree-sitter-typescript", specifier = ">=0.23.0" },
{ name = "unidiff", specifier = ">=0.7.4" },
@ -5222,6 +5224,21 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/a6/6e/e64621037357acb83d912276ffd30a859ef117f9c680f2e3cb955f47c680/tree_sitter-0.25.2-cp314-cp314-win_arm64.whl", hash = "sha256:b8d4429954a3beb3e844e2872610d2a4800ba4eb42bb1990c6a4b1949b18459f", size = 117470, upload-time = "2025-09-25T17:37:58.431Z" },
]
[[package]]
name = "tree-sitter-java"
version = "0.23.5"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/fa/dc/eb9c8f96304e5d8ae1663126d89967a622a80937ad2909903569ccb7ec8f/tree_sitter_java-0.23.5.tar.gz", hash = "sha256:f5cd57b8f1270a7f0438878750d02ccc79421d45cca65ff284f1527e9ef02e38", size = 138121, upload-time = "2024-12-21T18:24:26.936Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/67/21/b3399780b440e1567a11d384d0ebb1aea9b642d0d98becf30fa55c0e3a3b/tree_sitter_java-0.23.5-cp39-abi3-macosx_10_9_x86_64.whl", hash = "sha256:355ce0308672d6f7013ec913dee4a0613666f4cda9044a7824240d17f38209df", size = 58926, upload-time = "2024-12-21T18:24:12.53Z" },
{ url = "https://files.pythonhosted.org/packages/57/ef/6406b444e2a93bc72a04e802f4107e9ecf04b8de4a5528830726d210599c/tree_sitter_java-0.23.5-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:24acd59c4720dedad80d548fe4237e43ef2b7a4e94c8549b0ca6e4c4d7bf6e69", size = 62288, upload-time = "2024-12-21T18:24:14.634Z" },
{ url = "https://files.pythonhosted.org/packages/4e/6c/74b1c150d4f69c291ab0b78d5dd1b59712559bbe7e7daf6d8466d483463f/tree_sitter_java-0.23.5-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9401e7271f0b333df39fc8a8336a0caf1b891d9a2b89ddee99fae66b794fc5b7", size = 85533, upload-time = "2024-12-21T18:24:16.695Z" },
{ url = "https://files.pythonhosted.org/packages/29/09/e0d08f5c212062fd046db35c1015a2621c2631bc8b4aae5740d7adb276ad/tree_sitter_java-0.23.5-cp39-abi3-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:370b204b9500b847f6d0c5ad584045831cee69e9a3e4d878535d39e4a7e4c4f1", size = 84033, upload-time = "2024-12-21T18:24:18.758Z" },
{ url = "https://files.pythonhosted.org/packages/43/56/7d06b23ddd09bde816a131aa504ee11a1bbe87c6b62ab9b2ed23849a3382/tree_sitter_java-0.23.5-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:aae84449e330363b55b14a2af0585e4e0dae75eb64ea509b7e5b0e1de536846a", size = 82564, upload-time = "2024-12-21T18:24:20.493Z" },
{ url = "https://files.pythonhosted.org/packages/da/d6/0528c7e1e88a18221dbd8ccee3825bf274b1fa300f745fd74eb343878043/tree_sitter_java-0.23.5-cp39-abi3-win_amd64.whl", hash = "sha256:1ee45e790f8d31d416bc84a09dac2e2c6bc343e89b8a2e1d550513498eedfde7", size = 60650, upload-time = "2024-12-21T18:24:22.902Z" },
{ url = "https://files.pythonhosted.org/packages/72/57/5bab54d23179350356515526fff3cc0f3ac23bfbc1a1d518a15978d4880e/tree_sitter_java-0.23.5-cp39-abi3-win_arm64.whl", hash = "sha256:402efe136104c5603b429dc26c7e75ae14faaca54cfd319ecc41c8f2534750f4", size = 59059, upload-time = "2024-12-21T18:24:24.934Z" },
]
[[package]]
name = "tree-sitter-javascript"
version = "0.23.1"