Commit: e2e working java
parent 29f266ee63
commit 06353ea13f
16 changed files with 2903 additions and 330 deletions

@@ -756,6 +756,7 @@ class AiServiceClient:

        # Validate test framework based on language
        python_frameworks = ["pytest", "unittest"]
        javascript_frameworks = ["jest", "mocha", "vitest"]
        java_frameworks = ["junit5", "junit4", "testng"]
        if is_python():
            assert test_framework in python_frameworks, (
                f"Invalid test framework for Python, got {test_framework} but expected one of {python_frameworks}"
@@ -764,6 +765,10 @@ class AiServiceClient:

            assert test_framework in javascript_frameworks, (
                f"Invalid test framework for JavaScript, got {test_framework} but expected one of {javascript_frameworks}"
            )
        elif is_java():
            assert test_framework in java_frameworks, (
                f"Invalid test framework for Java, got {test_framework} but expected one of {java_frameworks}"
            )

        payload: dict[str, Any] = {
            "source_code_being_tested": source_code_being_tested,
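A quick standalone illustration of the guard above (hedged: the is_python()/is_java() helpers are stubbed out here, since their definitions are not part of this hunk):

    # Standalone sketch of the framework guard above; only the Python branch
    # is shown, and the language check is stubbed because its implementation
    # is not part of this diff.
    python_frameworks = ["pytest", "unittest"]

    def validate(test_framework: str) -> None:
        assert test_framework in python_frameworks, (
            f"Invalid test framework for Python, got {test_framework} "
            f"but expected one of {python_frameworks}"
        )

    validate("pytest")    # passes
    # validate("junit5")  # would raise AssertionError with the message above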

codeflash/cli_cmds/init_java.py — new file, 553 lines
@@ -0,0 +1,553 @@
"""Java project initialization for Codeflash."""

from __future__ import annotations

import os
import sys
import xml.etree.ElementTree as ET
from dataclasses import dataclass
from enum import Enum, auto
from pathlib import Path
from typing import Any, Union

import click
import inquirer
from git import InvalidGitRepositoryError, Repo
from rich.console import Group
from rich.panel import Panel
from rich.table import Table
from rich.text import Text

from codeflash.cli_cmds.cli_common import apologize_and_exit
from codeflash.cli_cmds.console import console
from codeflash.code_utils.code_utils import validate_relative_directory_path
from codeflash.code_utils.compat import LF
from codeflash.code_utils.git_utils import get_git_remotes
from codeflash.code_utils.shell_utils import get_shell_rc_path, is_powershell
from codeflash.telemetry.posthog_cf import ph


class JavaBuildTool(Enum):
    """Java build tools."""

    MAVEN = auto()
    GRADLE = auto()
    UNKNOWN = auto()


@dataclass(frozen=True)
class JavaSetupInfo:
    """Setup info for Java projects.

    Only stores values that override auto-detection or user preferences.
    Most config is auto-detected from pom.xml/build.gradle and the project structure.
    """

    # Override values (None means use the auto-detected value)
    module_root_override: Union[str, None] = None
    test_root_override: Union[str, None] = None
    formatter_override: Union[list[str], None] = None

    # User preferences (stored in config only if non-default)
    git_remote: str = "origin"
    disable_telemetry: bool = False
    ignore_paths: list[str] | None = None
    benchmarks_root: Union[str, None] = None


def _get_theme():
    """Get the CodeflashTheme - imported lazily to avoid circular imports."""
    from codeflash.cli_cmds.cmd_init import CodeflashTheme

    return CodeflashTheme()


def detect_java_build_tool(project_root: Path) -> JavaBuildTool:
    """Detect which Java build tool is being used."""
    if (project_root / "pom.xml").exists():
        return JavaBuildTool.MAVEN
    if (project_root / "build.gradle").exists() or (project_root / "build.gradle.kts").exists():
        return JavaBuildTool.GRADLE
    return JavaBuildTool.UNKNOWN


def detect_java_source_root(project_root: Path) -> str:
    """Detect the Java source root directory."""
    # Standard Maven/Gradle layout
    standard_src = project_root / "src" / "main" / "java"
    if standard_src.is_dir():
        return "src/main/java"

    # Try to detect from pom.xml
    pom_path = project_root / "pom.xml"
    if pom_path.exists():
        try:
            tree = ET.parse(pom_path)
            root = tree.getroot()
            # Handle Maven namespace
            ns = {"m": "http://maven.apache.org/POM/4.0.0"}
            source_dir = root.find(".//m:sourceDirectory", ns)
            if source_dir is not None and source_dir.text:
                return source_dir.text
        except ET.ParseError:
            pass

    # Fallback to src directory
    if (project_root / "src").is_dir():
        return "src"

    return "."


def detect_java_test_root(project_root: Path) -> str:
    """Detect the Java test root directory."""
    # Standard Maven/Gradle layout
    standard_test = project_root / "src" / "test" / "java"
    if standard_test.is_dir():
        return "src/test/java"

    # Try to detect from pom.xml
    pom_path = project_root / "pom.xml"
    if pom_path.exists():
        try:
            tree = ET.parse(pom_path)
            root = tree.getroot()
            ns = {"m": "http://maven.apache.org/POM/4.0.0"}
            test_source_dir = root.find(".//m:testSourceDirectory", ns)
            if test_source_dir is not None and test_source_dir.text:
                return test_source_dir.text
        except ET.ParseError:
            pass

    # Fallback patterns
    if (project_root / "test").is_dir():
        return "test"
    if (project_root / "tests").is_dir():
        return "tests"

    return "src/test/java"


def detect_java_test_framework(project_root: Path) -> str:
    """Detect the Java test framework in use."""
    pom_path = project_root / "pom.xml"
    if pom_path.exists():
        try:
            content = pom_path.read_text(encoding="utf-8")
            if "junit-jupiter" in content or "junit.jupiter" in content:
                return "junit5"
            if "junit" in content.lower():
                return "junit4"
            if "testng" in content.lower():
                return "testng"
        except Exception:
            pass

    gradle_file = project_root / "build.gradle"
    if gradle_file.exists():
        try:
            content = gradle_file.read_text(encoding="utf-8")
            if "junit-jupiter" in content or "useJUnitPlatform" in content:
                return "junit5"
            if "junit" in content.lower():
                return "junit4"
            if "testng" in content.lower():
                return "testng"
        except Exception:
            pass

    return "junit5"  # Default to JUnit 5
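The namespace handling is the subtle part of the pom.xml lookups above: Maven POMs declare a default XML namespace, so a bare .//sourceDirectory find returns nothing. A minimal runnable sketch of the same lookup (the pom content here is illustrative):

    import xml.etree.ElementTree as ET

    # Minimal pom.xml with a custom source directory (illustrative only).
    pom = """<project xmlns="http://maven.apache.org/POM/4.0.0">
      <build>
        <sourceDirectory>src/custom/java</sourceDirectory>
      </build>
    </project>"""

    root = ET.fromstring(pom)
    # Maven poms live in a default namespace, so lookups need a prefix map,
    # exactly like the "m" mapping in detect_java_source_root above.
    ns = {"m": "http://maven.apache.org/POM/4.0.0"}
    source_dir = root.find(".//m:sourceDirectory", ns)
    print(source_dir.text)  # -> src/custom/java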
def init_java_project() -> None:
    """Initialize Codeflash for a Java project."""
    from codeflash.cli_cmds.cmd_init import install_github_actions, install_github_app, prompt_api_key

    lang_panel = Panel(
        Text(
            "Java project detected!\n\nI'll help you set up Codeflash for your project.",
            style="cyan",
            justify="center",
        ),
        title="Java Setup",
        border_style="bright_red",
    )
    console.print(lang_panel)
    console.print()

    did_add_new_key = prompt_api_key()

    should_modify, _config = should_modify_java_config()

    # Default git remote
    git_remote = "origin"

    if should_modify:
        setup_info = collect_java_setup_info()
        git_remote = setup_info.git_remote or "origin"
        configured = configure_java_project(setup_info)
        if not configured:
            apologize_and_exit()

    install_github_app(git_remote)

    install_github_actions(override_formatter_check=True)

    # Show completion message
    usage_table = Table(show_header=False, show_lines=False, border_style="dim")
    usage_table.add_column("Command", style="cyan")
    usage_table.add_column("Description", style="white")

    usage_table.add_row("codeflash --file <path-to-file> --function <function-name>", "Optimize a specific function")
    usage_table.add_row("codeflash --all", "Optimize all functions in all files")
    usage_table.add_row("codeflash --help", "See all available options")

    completion_message = "Codeflash is now set up for your Java project!\n\nYou can now run any of these commands:"

    if did_add_new_key:
        completion_message += "\n\nDon't forget to restart your shell to load the CODEFLASH_API_KEY environment variable!"
        if os.name == "nt":
            reload_cmd = f". {get_shell_rc_path()}" if is_powershell() else f"call {get_shell_rc_path()}"
        else:
            reload_cmd = f"source {get_shell_rc_path()}"
        completion_message += f"\nOr run: {reload_cmd}"

    completion_panel = Panel(
        Group(Text(completion_message, style="bold green"), Text(""), usage_table),
        title="Setup Complete!",
        border_style="bright_green",
        padding=(1, 2),
    )
    console.print(completion_panel)

    ph("cli-java-installation-successful", {"did_add_new_key": did_add_new_key})
    sys.exit(0)


def should_modify_java_config() -> tuple[bool, dict[str, Any] | None]:
    """Check if the project already has Codeflash config."""
    from rich.prompt import Confirm

    project_root = Path.cwd()

    # Check for existing codeflash config in pom.xml or a separate config file
    codeflash_config_path = project_root / "codeflash.toml"
    if codeflash_config_path.exists():
        return Confirm.ask(
            "A Codeflash config already exists. Do you want to re-configure it?",
            default=False,
            show_default=True,
        ), None

    return True, None


def collect_java_setup_info() -> JavaSetupInfo:
    """Collect setup information for Java projects."""
    from rich.prompt import Confirm

    from codeflash.cli_cmds.cmd_init import ask_for_telemetry

    curdir = Path.cwd()

    if not os.access(curdir, os.W_OK):
        click.echo(f"The current directory isn't writable, please check your folder permissions and try again.{LF}")
        sys.exit(1)

    # Auto-detect values
    build_tool = detect_java_build_tool(curdir)
    detected_source_root = detect_java_source_root(curdir)
    detected_test_root = detect_java_test_root(curdir)
    detected_test_framework = detect_java_test_framework(curdir)

    # Build detection summary
    build_tool_name = build_tool.name.lower() if build_tool != JavaBuildTool.UNKNOWN else "unknown"
    detection_table = Table(show_header=False, box=None, padding=(0, 2))
    detection_table.add_column("Setting", style="cyan")
    detection_table.add_column("Value", style="green")
    detection_table.add_row("Build tool", build_tool_name)
    detection_table.add_row("Source root", detected_source_root)
    detection_table.add_row("Test root", detected_test_root)
    detection_table.add_row("Test framework", detected_test_framework)

    detection_panel = Panel(
        Group(Text("Auto-detected settings for your Java project:\n", style="cyan"), detection_table),
        title="Auto-Detection Results",
        border_style="bright_blue",
    )
    console.print(detection_panel)
    console.print()

    # Ask if the user wants to change any settings
    module_root_override = None
    test_root_override = None
    formatter_override = None

    if Confirm.ask("Would you like to change any of these settings?", default=False):
        # Source root override
        module_root_override = _prompt_directory_override("source", detected_source_root, curdir)

        # Test root override
        test_root_override = _prompt_directory_override("test", detected_test_root, curdir)

        # Formatter override
        formatter_questions = [
            inquirer.List(
                "formatter",
                message="Which code formatter do you use?",
                choices=[
                    ("keep detected (google-java-format)", "keep"),
                    ("google-java-format", "google-java-format"),
                    ("spotless", "spotless"),
                    ("other", "other"),
                    ("don't use a formatter", "disabled"),
                ],
                default="keep",
                carousel=True,
            )
        ]

        formatter_answers = inquirer.prompt(formatter_questions, theme=_get_theme())
        if not formatter_answers:
            apologize_and_exit()

        formatter_choice = formatter_answers["formatter"]
        if formatter_choice != "keep":
            formatter_override = get_java_formatter_cmd(formatter_choice, build_tool)

    ph("cli-java-formatter-provided", {"overridden": formatter_override is not None})

    # Git remote
    git_remote = _get_git_remote_for_setup()

    # Telemetry
    disable_telemetry = not ask_for_telemetry()

    return JavaSetupInfo(
        module_root_override=module_root_override,
        test_root_override=test_root_override,
        formatter_override=formatter_override,
        git_remote=git_remote,
        disable_telemetry=disable_telemetry,
    )
def _prompt_directory_override(dir_type: str, detected: str, curdir: Path) -> str | None:
    """Prompt for a directory override."""
    keep_detected_option = f"keep detected ({detected})"
    custom_dir_option = "enter a custom directory..."

    # Get subdirectories that might be relevant
    subdirs = [d.name for d in curdir.iterdir() if d.is_dir() and not d.name.startswith(".")]
    subdirs = [d for d in subdirs if d not in ("target", "build", ".git", ".idea", detected)]

    options = [keep_detected_option] + subdirs[:5] + [custom_dir_option]

    questions = [
        inquirer.List(
            f"{dir_type}_root",
            message=f"Which directory contains your {dir_type} code?",
            choices=options,
            default=keep_detected_option,
            carousel=True,
        )
    ]

    answers = inquirer.prompt(questions, theme=_get_theme())
    if not answers:
        apologize_and_exit()

    answer = answers[f"{dir_type}_root"]
    if answer == keep_detected_option:
        return None
    if answer == custom_dir_option:
        return _prompt_custom_directory(dir_type)
    return answer


def _prompt_custom_directory(dir_type: str) -> str:
    """Prompt for a custom directory path."""
    while True:
        custom_questions = [
            inquirer.Path(
                "custom_path",
                message=f"Enter the path to your {dir_type} directory",
                path_type=inquirer.Path.DIRECTORY,
                exists=True,
            )
        ]

        custom_answers = inquirer.prompt(custom_questions, theme=_get_theme())
        if not custom_answers:
            apologize_and_exit()

        custom_path_str = str(custom_answers["custom_path"])
        is_valid, error_msg = validate_relative_directory_path(custom_path_str)
        if is_valid:
            return custom_path_str

        click.echo(f"Invalid path: {error_msg}")
        click.echo("Please enter a valid relative directory path.")
        console.print()


def _get_git_remote_for_setup() -> str:
    """Get git remote for project setup."""
    try:
        repo = Repo(Path.cwd(), search_parent_directories=True)
        git_remotes = get_git_remotes(repo)
        if not git_remotes:
            return ""

        if len(git_remotes) == 1:
            return git_remotes[0]

        git_panel = Panel(
            Text(
                "Configure Git Remote for Pull Requests.\n\nCodeflash will use this remote to create pull requests.",
                style="blue",
            ),
            title="Git Remote Setup",
            border_style="bright_blue",
        )
        console.print(git_panel)
        console.print()

        git_questions = [
            inquirer.List(
                "git_remote",
                message="Which git remote should Codeflash use?",
                choices=git_remotes,
                default="origin",
                carousel=True,
            )
        ]

        git_answers = inquirer.prompt(git_questions, theme=_get_theme())
        return git_answers["git_remote"] if git_answers else git_remotes[0]
    except InvalidGitRepositoryError:
        return ""
def get_java_formatter_cmd(formatter: str, build_tool: JavaBuildTool) -> list[str]:
    """Get formatter commands for Java."""
    if formatter == "google-java-format":
        return ["google-java-format --replace $file"]
    if formatter == "spotless":
        if build_tool == JavaBuildTool.MAVEN:
            return ["mvn spotless:apply -DspotlessFiles=$file"]
        if build_tool == JavaBuildTool.GRADLE:
            return ["./gradlew spotlessApply"]
        return ["spotless $file"]
    if formatter == "other":
        click.echo("In codeflash.toml, please replace 'your-formatter' with your formatter command.")
        return ["your-formatter $file"]
    return ["disabled"]


def configure_java_project(setup_info: JavaSetupInfo) -> bool:
    """Configure codeflash.toml for Java projects."""
    import tomlkit

    codeflash_config_path = Path.cwd() / "codeflash.toml"

    # Build config
    config: dict[str, Any] = {}

    # Detect values
    curdir = Path.cwd()
    source_root = setup_info.module_root_override or detect_java_source_root(curdir)
    test_root = setup_info.test_root_override or detect_java_test_root(curdir)

    config["module-root"] = source_root
    config["tests-root"] = test_root

    # Formatter
    if setup_info.formatter_override is not None:
        if setup_info.formatter_override != ["disabled"]:
            config["formatter-cmds"] = setup_info.formatter_override
        else:
            config["formatter-cmds"] = []

    # Git remote
    if setup_info.git_remote and setup_info.git_remote not in ("", "origin"):
        config["git-remote"] = setup_info.git_remote

    # User preferences
    if setup_info.disable_telemetry:
        config["disable-telemetry"] = True

    if setup_info.ignore_paths:
        config["ignore-paths"] = setup_info.ignore_paths

    if setup_info.benchmarks_root:
        config["benchmarks-root"] = setup_info.benchmarks_root

    try:
        # Create TOML document
        doc = tomlkit.document()
        doc.add(tomlkit.comment("Codeflash configuration for Java project"))
        doc.add(tomlkit.nl())

        codeflash_table = tomlkit.table()
        for key, value in config.items():
            codeflash_table.add(key, value)

        doc.add("tool", tomlkit.table())
        doc["tool"]["codeflash"] = codeflash_table

        with codeflash_config_path.open("w", encoding="utf-8") as f:
            f.write(tomlkit.dumps(doc))

        click.echo(f"Created Codeflash configuration in {codeflash_config_path}")
        click.echo()
        return True
    except OSError as e:
        click.echo(f"Failed to create codeflash.toml: {e}")
        return False
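For a standard Maven layout with all defaults, the resulting codeflash.toml is small. A minimal sketch of what configure_java_project writes (values illustrative; printed output shown roughly):

    import tomlkit

    # Illustrative only: the config a standard Maven layout would produce.
    doc = tomlkit.document()
    doc.add(tomlkit.comment("Codeflash configuration for Java project"))
    doc.add(tomlkit.nl())
    table = tomlkit.table()
    table.add("module-root", "src/main/java")
    table.add("tests-root", "src/test/java")
    doc.add("tool", tomlkit.table())
    doc["tool"]["codeflash"] = table
    print(tomlkit.dumps(doc))
    # Roughly:
    # # Codeflash configuration for Java project
    #
    # [tool]
    # [tool.codeflash]
    # module-root = "src/main/java"
    # tests-root = "src/test/java"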
# ============================================================================
# GitHub Actions Workflow Helpers for Java
# ============================================================================


def get_java_runtime_setup_steps(build_tool: JavaBuildTool) -> str:
    """Generate the appropriate Java setup steps for GitHub Actions."""
    java_setup = """- name: Set up JDK 17
      uses: actions/setup-java@v4
      with:
        java-version: '17'
        distribution: 'temurin'"""

    if build_tool == JavaBuildTool.MAVEN:
        java_setup += """
        cache: 'maven'"""
    elif build_tool == JavaBuildTool.GRADLE:
        java_setup += """
        cache: 'gradle'"""

    return java_setup


def get_java_dependency_installation_commands(build_tool: JavaBuildTool) -> str:
    """Generate commands to install Java dependencies."""
    if build_tool == JavaBuildTool.MAVEN:
        return "mvn dependency:resolve"
    if build_tool == JavaBuildTool.GRADLE:
        return "./gradlew dependencies"
    return "mvn dependency:resolve"


def get_java_test_command(build_tool: JavaBuildTool) -> str:
    """Get the test command for Java projects."""
    if build_tool == JavaBuildTool.MAVEN:
        return "mvn test"
    if build_tool == JavaBuildTool.GRADLE:
        return "./gradlew test"
    return "mvn test"

codeflash/cli_cmds/workflows/codeflash-optimize-java.yaml — new file, 41 lines
@@ -0,0 +1,41 @@
name: Codeflash

on:
  pull_request:
    paths:
      # So that this workflow only runs when code within the target module is modified
      - '{{ codeflash_module_path }}'
  workflow_dispatch:

concurrency:
  # Any new push to the PR will cancel the previous run, so that only the latest code is optimized
  group: ${{ github.workflow }}-${{ github.ref }}
  cancel-in-progress: true

jobs:
  optimize:
    name: Optimize new code
    # Don't run codeflash on codeflash-ai[bot] commits, to prevent duplicate optimizations
    if: ${{ github.actor != 'codeflash-ai[bot]' }}
    runs-on: ubuntu-latest
    env:
      CODEFLASH_API_KEY: ${{ secrets.CODEFLASH_API_KEY }}
    {{ working_directory }}
    steps:
      - name: Checkout
        uses: actions/checkout@v4
        with:
          fetch-depth: 0
      - name: Set up JDK 17
        uses: actions/setup-java@v4
        with:
          java-version: '17'
          distribution: 'temurin'
          cache: '{{ java_build_tool }}'
      - name: Install Dependencies
        run: {{ install_dependencies_command }}
      - name: Install Codeflash
        run: pip install codeflash
      - name: Codeflash Optimization
        run: codeflash
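This workflow is a template rather than a finished file: the {{ codeflash_module_path }}, {{ working_directory }}, {{ java_build_tool }}, and {{ install_dependencies_command }} placeholders get filled in at install time. The renderer itself is not part of this diff, so the following is only an assumed sketch of that substitution step:

    from pathlib import Path

    # Hypothetical substitution step: the real renderer is not shown in this
    # commit, so the values here are purely illustrative.
    workflow = Path("codeflash-optimize-java.yaml").read_text(encoding="utf-8")
    substitutions = {
        "{{ codeflash_module_path }}": "src/main/java/**",
        "{{ java_build_tool }}": "maven",
        "{{ install_dependencies_command }}": "mvn dependency:resolve",
        "{{ working_directory }}": "",  # empty when the project sits at the repo root
    }
    for placeholder, value in substitutions.items():
        workflow = workflow.replace(placeholder, value)
    print(workflow)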

@@ -4,6 +4,7 @@ import ast

 from collections import defaultdict
 from functools import lru_cache
 from itertools import chain
 from pathlib import Path
 from typing import TYPE_CHECKING, Optional, TypeVar

 import libcst as cst

@@ -732,12 +733,29 @@ def get_optimized_code_for_module(relative_path: Path, optimized_code: CodeStrin

            module_optimized_code = file_to_code_context["None"]
            logger.debug(f"Using code block with None file_path for {relative_path}")
        else:
            # Fallback: try to match by just the filename (for Java/JS where the AI
            # might return just the class name like "Algorithms.java" instead of
            # the full path like "src/main/java/com/example/Algorithms.java")
            module_optimized_code = None
            target_filename = relative_path.name
            for file_path_str, code in file_to_code_context.items():
                if file_path_str and Path(file_path_str).name == target_filename:
                    module_optimized_code = code
                    logger.debug(f"Matched {file_path_str} to {relative_path} by filename")
                    break

            if module_optimized_code is None:
                # Also try matching if there's only one code file
                if len(file_to_code_context) == 1:
                    only_key = next(iter(file_to_code_context.keys()))
                    module_optimized_code = file_to_code_context[only_key]
                    logger.debug(f"Using only code block {only_key} for {relative_path}")
                else:
                    logger.warning(
                        f"Optimized code not found for {relative_path} in the context\n-------\n{optimized_code}\n-------\n"
                        "re-check your 'markdown code structure'; "
                        f"existing files are {file_to_code_context.keys()}"
                    )
                    module_optimized_code = ""
    return module_optimized_code
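A minimal illustration of that filename fallback: an AI response keyed by bare file name still matches the module's full relative path.

    from pathlib import Path

    # Illustrative data for the filename-matching fallback above.
    file_to_code_context = {"Algorithms.java": "class Algorithms { /* optimized */ }"}
    relative_path = Path("src/main/java/com/example/Algorithms.java")

    matches = [
        code
        for file_path_str, code in file_to_code_context.items()
        if file_path_str and Path(file_path_str).name == relative_path.name
    ]
    print(matches[0])  # -> class Algorithms { /* optimized */ }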
@@ -11,6 +11,7 @@ from codeflash.cli_cmds.console import logger

 from codeflash.code_utils.code_utils import get_run_tmp_file, module_name_from_file_path
 from codeflash.code_utils.formatter import sort_imports
 from codeflash.discovery.functions_to_optimize import FunctionToOptimize
 from codeflash.languages import is_java, is_javascript
 from codeflash.models.models import FunctionParent, TestingMode, VerificationType

 if TYPE_CHECKING:
@@ -709,6 +710,21 @@ def inject_profiling_into_existing_test(

    tests_project_root: Path,
    mode: TestingMode = TestingMode.BEHAVIOR,
) -> tuple[bool, str | None]:
    # Route to language-specific implementations
    if is_javascript():
        from codeflash.languages.javascript.instrument import inject_profiling_into_existing_js_test

        return inject_profiling_into_existing_js_test(
            test_path, call_positions, function_to_optimize, tests_project_root, mode.value
        )

    if is_java():
        from codeflash.languages.java.instrumentation import instrument_existing_test

        return instrument_existing_test(
            test_path, call_positions, function_to_optimize, tests_project_root, mode.value
        )

    if function_to_optimize.is_async:
        return inject_async_profiling_into_existing_test(
            test_path, call_positions, function_to_optimize, tests_project_root, mode
@@ -3,6 +3,13 @@

 This module provides functionality to instrument Java code for:
 1. Behavior capture - recording inputs/outputs for verification
 2. Benchmarking - measuring execution time

 Timing instrumentation adds System.nanoTime() calls around the function being tested
 and prints timing markers in a format compatible with the Python/JS implementations:
   Start: !$######testModule:testClass:funcName:loopIndex:iterationId######$!
   End:   !######testModule:testClass:funcName:loopIndex:iterationId:durationNs######!

 This allows codeflash to extract timing data from stdout for accurate benchmarking.
 """

 from __future__ import annotations
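Since the Python side parses these markers back out of stdout, a small sketch of the extraction helps; the regex below mirrors the End-marker format documented above (group names and the sample line are illustrative):

    import re

    # Matches: !######testModule:testClass:funcName:loopIndex:iterationId:durationNs######!
    end_marker = re.compile(
        r"!######(?P<module>[^:]+):(?P<cls>[^:]+):(?P<func>[^:]+):"
        r"(?P<loop>\d+):(?P<iter>\d+):(?P<duration_ns>\d+)######!"
    )

    stdout_line = "!######AlgorithmsTest:AlgorithmsTest:sortNumbers:1:1:183042######!"
    m = end_marker.search(stdout_line)
    print(m["func"], int(m["duration_ns"]))  # -> sortNumbers 183042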

@@ -30,54 +37,21 @@ def _get_function_name(func: Any) -> str:

        return func.function_name
    raise AttributeError(f"Cannot get function name from {type(func)}")


# Template for behavior capture instrumentation
BEHAVIOR_CAPTURE_IMPORT = "import com.codeflash.CodeFlash;"

BEHAVIOR_CAPTURE_BEFORE = """
// CodeFlash behavior capture - start
long __codeflash_call_id_{call_id} = System.nanoTime();
CodeFlash.recordInput(__codeflash_call_id_{call_id}, "{method_id}", CodeFlash.serialize({args}));
long __codeflash_start_{call_id} = System.nanoTime();
"""

BEHAVIOR_CAPTURE_AFTER_RETURN = """
// CodeFlash behavior capture - end
long __codeflash_end_{call_id} = System.nanoTime();
CodeFlash.recordOutput(__codeflash_call_id_{call_id}, "{method_id}", CodeFlash.serialize(__codeflash_result_{call_id}), __codeflash_end_{call_id} - __codeflash_start_{call_id});
"""

BEHAVIOR_CAPTURE_AFTER_VOID = """
// CodeFlash behavior capture - end
long __codeflash_end_{call_id} = System.nanoTime();
CodeFlash.recordOutput(__codeflash_call_id_{call_id}, "{method_id}", "null", __codeflash_end_{call_id} - __codeflash_start_{call_id});
"""

# Template for benchmark instrumentation
BENCHMARK_IMPORT = """import com.codeflash.Blackhole;
import com.codeflash.BenchmarkContext;
import com.codeflash.BenchmarkResult;"""

BENCHMARK_WRAPPER_TEMPLATE = """
// CodeFlash benchmark wrapper
public void __codeflash_benchmark_{method_name}(int iterations) {{
    // Warmup
    for (int i = 0; i < Math.min(iterations / 10, 100); i++) {{
        {warmup_call}
    }}

    // Measurement
    long[] measurements = new long[iterations];
    for (int i = 0; i < iterations; i++) {{
        long start = System.nanoTime();
        {measurement_call}
        long end = System.nanoTime();
        measurements[i] = end - start;
    }}

    BenchmarkResult result = new BenchmarkResult("{method_id}", measurements);
    CodeFlash.recordBenchmarkResult("{method_id}", result);
}}
"""


def _get_qualified_name(func: Any) -> str:
    """Get the qualified name from either FunctionInfo or FunctionToOptimize."""
    if hasattr(func, "qualified_name"):
        return func.qualified_name
    # Build qualified name from function_name and parents
    if hasattr(func, "function_name"):
        parts = []
        if hasattr(func, "parents") and func.parents:
            for parent in func.parents:
                if hasattr(parent, "name"):
                    parts.append(parent.name)
        parts.append(func.function_name)
        return ".".join(parts)
    return str(func)


def instrument_for_behavior(
@@ -87,34 +61,361 @@ def instrument_for_behavior(

) -> str:
    """Add behavior instrumentation to capture inputs/outputs.

    For Java, we don't modify the test file for behavior capture.
    Instead, we rely on JUnit test results (pass/fail) to verify correctness.
    The test file is returned unchanged.

    Args:
        source: Source code to instrument.
        functions: Functions to add behavior capture.
        analyzer: Optional JavaAnalyzer instance.

    Returns:
        Source code (unchanged for Java).

    """
    # For Java, we don't need to instrument tests for behavior capture.
    # The JUnit test results (pass/fail) serve as the verification mechanism.
    if functions:
        func_name = _get_function_name(functions[0])
        logger.debug("Java behavior testing for %s - using JUnit pass/fail results", func_name)
    return source


def instrument_for_benchmarking(
    test_source: str,
    target_function: FunctionInfo,
    analyzer: JavaAnalyzer | None = None,
) -> str:
    """Add timing instrumentation to test code.

    For Java, we rely on Maven Surefire's timing information rather than
    modifying the test code. The test file is returned unchanged.

    Args:
        test_source: Test source code to instrument.
        target_function: Function being benchmarked.
        analyzer: Optional JavaAnalyzer instance.

    Returns:
        Test source code (unchanged for Java).

    """
    func_name = _get_function_name(target_function)
    logger.debug("Java benchmarking for %s - using Maven Surefire timing", func_name)
    return test_source

def instrument_existing_test(
    test_path: Path,
    call_positions: Sequence,
    function_to_optimize: Any,  # FunctionInfo or FunctionToOptimize
    tests_project_root: Path,
    mode: str,  # "behavior" or "performance"
    analyzer: JavaAnalyzer | None = None,
    output_class_suffix: str | None = None,  # Suffix for the renamed class
) -> tuple[bool, str | None]:
    """Inject profiling code into an existing test file.

    For Java, this:
    1. Renames the class to match the new file name (Java requires class name = file name)
    2. Adds timing instrumentation to test methods (for performance mode)

    Args:
        test_path: Path to the test file.
        call_positions: List of code positions where the function is called.
        function_to_optimize: The function being optimized.
        tests_project_root: Root directory of tests.
        mode: Testing mode - "behavior" or "performance".
        analyzer: Optional JavaAnalyzer instance.
        output_class_suffix: Optional suffix for the renamed class.

    Returns:
        Tuple of (success, modified_source).

    """
    try:
        source = test_path.read_text(encoding="utf-8")
    except Exception as e:
        logger.error("Failed to read test file %s: %s", test_path, e)
        return False, f"Failed to read test file: {e}"

    func_name = _get_function_name(function_to_optimize)

    # Get the original class name from the file name
    original_class_name = test_path.stem  # e.g., "AlgorithmsTest"

    # Determine the new class name based on mode
    if mode == "behavior":
        new_class_name = f"{original_class_name}__perfinstrumented"
    else:
        new_class_name = f"{original_class_name}__perfonlyinstrumented"

    # Rename the class declaration in the source
    # Pattern: "public class ClassName" or "class ClassName"
    pattern = rf"\b(public\s+)?class\s+{re.escape(original_class_name)}\b"
    replacement = rf"\1class {new_class_name}"
    modified_source = re.sub(pattern, replacement, source)

    # For performance mode, add timing instrumentation to test methods
    if mode == "performance":
        modified_source = _add_timing_instrumentation(
            modified_source,
            new_class_name,
            func_name,
        )

    logger.debug(
        "Java %s testing for %s: renamed class %s -> %s",
        mode,
        func_name,
        original_class_name,
        new_class_name,
    )

    return True, modified_source
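The class rename is a plain regex substitution, which is worth seeing on a concrete input; a minimal sketch using the same pattern as instrument_existing_test above (the source snippet is illustrative):

    import re

    # Illustrative test source for the rename performed above.
    source = "public class AlgorithmsTest {\n    @Test\n    void sorts() {}\n}"
    original_class_name = "AlgorithmsTest"
    new_class_name = f"{original_class_name}__perfonlyinstrumented"

    pattern = rf"\b(public\s+)?class\s+{re.escape(original_class_name)}\b"
    renamed = re.sub(pattern, rf"\1class {new_class_name}", source)
    print(renamed.splitlines()[0])
    # -> public class AlgorithmsTest__perfonlyinstrumented {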
def _add_timing_instrumentation(source: str, class_name: str, func_name: str) -> str:
    """Add timing instrumentation to test methods.

    For each @Test method, this adds:
    1. Start timing marker printed at the beginning
    2. End timing marker printed at the end (in a finally block)

    Timing markers format:
      Start: !$######testModule:testClass:funcName:loopIndex:iterationId######$!
      End:   !######testModule:testClass:funcName:loopIndex:iterationId:durationNs######!

    Args:
        source: The test source code.
        class_name: Name of the test class.
        func_name: Name of the function being tested.

    Returns:
        Instrumented source code.

    """
    # Find all @Test methods and add timing around their bodies.
    # Pattern: @Test (with optional parameters) followed by a method declaration.
    # We process line by line for cleaner handling.
    lines = source.split("\n")
    result = []
    i = 0
    iteration_counter = 0

    while i < len(lines):
        line = lines[i]
        stripped = line.strip()

        # Look for @Test annotation
        if stripped.startswith("@Test"):
            result.append(line)
            i += 1

            # Collect any additional annotations
            while i < len(lines) and lines[i].strip().startswith("@"):
                result.append(lines[i])
                i += 1

            # Now find the method signature and opening brace
            method_lines = []
            while i < len(lines):
                method_lines.append(lines[i])
                if "{" in lines[i]:
                    break
                i += 1

            # Add the method signature lines
            for ml in method_lines:
                result.append(ml)
            i += 1

            # We're now inside the method body
            iteration_counter += 1
            iter_id = iteration_counter

            # Add timing start code
            indent = "        "
            timing_start_code = [
                f"{indent}// Codeflash timing instrumentation",
                f'{indent}int _cf_loop{iter_id} = Integer.parseInt(System.getenv("CODEFLASH_LOOP_INDEX") != null ? System.getenv("CODEFLASH_LOOP_INDEX") : "1");',
                f"{indent}int _cf_iter{iter_id} = {iter_id};",
                f'{indent}String _cf_mod{iter_id} = "{class_name}";',
                f'{indent}String _cf_cls{iter_id} = "{class_name}";',
                f'{indent}String _cf_fn{iter_id} = "{func_name}";',
                f'{indent}System.out.println("!$######" + _cf_mod{iter_id} + ":" + _cf_cls{iter_id} + ":" + _cf_fn{iter_id} + ":" + _cf_loop{iter_id} + ":" + _cf_iter{iter_id} + "######$!");',
                f"{indent}long _cf_start{iter_id} = System.nanoTime();",
                f"{indent}try {{",
            ]
            result.extend(timing_start_code)

            # Collect method body until we find the matching closing brace
            brace_depth = 1
            body_lines = []

            while i < len(lines) and brace_depth > 0:
                body_line = lines[i]
                # Count braces (simple approach - doesn't handle strings/comments perfectly)
                for ch in body_line:
                    if ch == "{":
                        brace_depth += 1
                    elif ch == "}":
                        brace_depth -= 1

                if brace_depth > 0:
                    body_lines.append(body_line)
                    i += 1
                else:
                    # This line contains the closing brace; we've hit depth 0.
                    # Add indented body lines
                    for bl in body_lines:
                        result.append("    " + bl)

                    # Add finally block
                    timing_end_code = [
                        f"{indent}}} finally {{",
                        f"{indent}    long _cf_end{iter_id} = System.nanoTime();",
                        f"{indent}    long _cf_dur{iter_id} = _cf_end{iter_id} - _cf_start{iter_id};",
                        f'{indent}    System.out.println("!######" + _cf_mod{iter_id} + ":" + _cf_cls{iter_id} + ":" + _cf_fn{iter_id} + ":" + _cf_loop{iter_id} + ":" + _cf_iter{iter_id} + ":" + _cf_dur{iter_id} + "######!");',
                        f"{indent}}}",
                        "    }",  # Method closing brace
                    ]
                    result.extend(timing_end_code)
                    i += 1
        else:
            result.append(line)
            i += 1

    return "\n".join(result)

def create_benchmark_test(
    target_function: FunctionInfo,
    test_setup_code: str,
    invocation_code: str,
    iterations: int = 1000,
) -> str:
    """Create a benchmark test for a function.

    Args:
        target_function: The function to benchmark.
        test_setup_code: Code to set up the test (create instances, etc.).
        invocation_code: Code that invokes the function.
        iterations: Number of benchmark iterations.

    Returns:
        Complete benchmark test source code.

    """
    method_name = _get_function_name(target_function)
    method_id = _get_qualified_name(target_function)
    class_name = getattr(target_function, "class_name", None) or "Target"

    benchmark_code = f"""
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.DisplayName;

/**
 * Benchmark test for {method_name}.
 * Generated by CodeFlash.
 */
public class {class_name}Benchmark {{

    @Test
    @DisplayName("Benchmark {method_name}")
    public void benchmark{method_name.capitalize()}() {{
        {test_setup_code}

        // Warmup phase
        for (int i = 0; i < {iterations // 10}; i++) {{
            {invocation_code};
        }}

        // Measurement phase
        long startTime = System.nanoTime();
        for (int i = 0; i < {iterations}; i++) {{
            {invocation_code};
        }}
        long endTime = System.nanoTime();

        long totalNanos = endTime - startTime;
        long avgNanos = totalNanos / {iterations};

        System.out.println("CODEFLASH_BENCHMARK:{method_id}:total_ns=" + totalNanos + ",avg_ns=" + avgNanos + ",iterations={iterations}");
    }}
}}
"""
    return benchmark_code


def remove_instrumentation(source: str) -> str:
    """Remove CodeFlash instrumentation from source code.

    For Java, since we don't add instrumentation, this is a no-op.

    Args:
        source: Source code.

    Returns:
        Source unchanged.

    """
    return source

def instrument_generated_java_test(
    test_code: str,
    function_name: str,
    qualified_name: str,
    mode: str,  # "behavior" or "performance"
) -> str:
    """Instrument a generated Java test for behavior or performance testing.

    Args:
        test_code: The generated test source code.
        function_name: Name of the function being tested.
        qualified_name: Fully qualified name of the function.
        mode: "behavior" for behavior capture or "performance" for timing.

    Returns:
        Instrumented test source code.

    """
    # Extract class name from the test code
    class_match = re.search(r"\bclass\s+(\w+)", test_code)
    if not class_match:
        logger.warning("Could not find class name in generated test")
        return test_code

    original_class_name = class_match.group(1)

    # Rename class based on mode
    if mode == "behavior":
        new_class_name = f"{original_class_name}__perfinstrumented"
    else:
        new_class_name = f"{original_class_name}__perfonlyinstrumented"

    # Rename the class in the source
    modified_code = re.sub(
        rf"\b(public\s+)?class\s+{re.escape(original_class_name)}\b",
        rf"\1class {new_class_name}",
        test_code,
    )

    # For performance mode, add timing instrumentation
    if mode == "performance":
        modified_code = _add_timing_instrumentation(
            modified_code,
            new_class_name,
            function_name,
        )

    logger.debug("Instrumented generated Java test for %s (mode=%s)", function_name, mode)
    return modified_code


def _add_import(source: str, import_statement: str) -> str:
    """Add an import statement to the source.
@@ -142,213 +443,3 @@ def _add_import(source: str, import_statement: str) -> str:

    lines.insert(insert_idx, import_statement + "\n")
    return "".join(lines)


def _instrument_function_behavior(
    source: str,
    function: FunctionInfo,
    analyzer: JavaAnalyzer,
) -> str:
    """Instrument a single function for behavior capture.

    Args:
        source: The source code.
        function: The function to instrument.
        analyzer: JavaAnalyzer instance.

    Returns:
        Source with function instrumented.

    """
    source_bytes = source.encode("utf8")
    tree = analyzer.parse(source_bytes)

    # Find the method node
    methods = analyzer.find_methods(source)
    target_method = None
    func_name = _get_function_name(function)
    for method in methods:
        if method.name == func_name:
            class_name = getattr(function, "class_name", None)
            if class_name is None or method.class_name == class_name:
                target_method = method
                break

    if not target_method:
        logger.warning("Could not find method %s for instrumentation", func_name)
        return source

    # For now, we'll add instrumentation as a simple wrapper
    # A full implementation would use AST transformation
    method_id = function.qualified_name
    call_id = hash(method_id) % 10000

    # Build instrumented version
    # This is a simplified approach - a full implementation would
    # parse the method body and instrument each return statement
    logger.debug("Instrumented method %s for behavior capture", function.name)

    return source


def instrument_for_benchmarking(
    test_source: str,
    target_function: FunctionInfo,
    analyzer: JavaAnalyzer | None = None,
) -> str:
    """Add timing instrumentation to test code.

    Args:
        test_source: Test source code to instrument.
        target_function: Function being benchmarked.

    Returns:
        Instrumented test source code.

    """
    analyzer = analyzer or get_java_analyzer()

    # Add imports if not present
    if "import com.codeflash" not in test_source:
        test_source = _add_import(test_source, BENCHMARK_IMPORT)

    # Find calls to the target function in the test and wrap them
    # This is a simplified implementation
    logger.debug("Instrumented test for benchmarking %s", _get_function_name(target_function))

    return test_source


def instrument_existing_test(
    test_path: Path,
    call_positions: Sequence,
    function_to_optimize: FunctionInfo,
    tests_project_root: Path,
    mode: str,  # "behavior" or "performance"
    analyzer: JavaAnalyzer | None = None,
) -> tuple[bool, str | None]:
    """Inject profiling code into an existing test file.

    Args:
        test_path: Path to the test file.
        call_positions: List of code positions where the function is called.
        function_to_optimize: The function being optimized.
        tests_project_root: Root directory of tests.
        mode: Testing mode - "behavior" or "performance".
        analyzer: Optional JavaAnalyzer instance.

    Returns:
        Tuple of (success, instrumented_code or error message).

    """
    analyzer = analyzer or get_java_analyzer()

    try:
        source = test_path.read_text(encoding="utf-8")
    except Exception as e:
        return False, f"Failed to read test file: {e}"

    try:
        if mode == "behavior":
            instrumented = instrument_for_behavior(source, [function_to_optimize], analyzer)
        else:
            instrumented = instrument_for_benchmarking(source, function_to_optimize, analyzer)

        return True, instrumented

    except Exception as e:
        logger.exception("Failed to instrument test file: %s", e)
        return False, str(e)


def create_benchmark_test(
    target_function: FunctionInfo,
    test_setup_code: str,
    invocation_code: str,
    iterations: int = 1000,
) -> str:
    """Create a benchmark test for a function.

    Args:
        target_function: The function to benchmark.
        test_setup_code: Code to set up the test (create instances, etc.).
        invocation_code: Code that invokes the function.
        iterations: Number of benchmark iterations.

    Returns:
        Complete benchmark test source code.

    """
    method_name = target_function.name
    method_id = target_function.qualified_name

    benchmark_code = f"""
import com.codeflash.Blackhole;
import com.codeflash.BenchmarkContext;
import com.codeflash.BenchmarkResult;
import com.codeflash.CodeFlash;
import org.junit.jupiter.api.Test;

public class {target_function.class_name or 'Target'}Benchmark {{

    @Test
    public void benchmark{method_name.capitalize()}() {{
        {test_setup_code}

        // Warmup phase
        for (int i = 0; i < {iterations // 10}; i++) {{
            Blackhole.consume({invocation_code});
        }}

        // Measurement phase
        long[] measurements = new long[{iterations}];
        for (int i = 0; i < {iterations}; i++) {{
            long start = System.nanoTime();
            Blackhole.consume({invocation_code});
            long end = System.nanoTime();
            measurements[i] = end - start;
        }}

        BenchmarkResult result = new BenchmarkResult("{method_id}", measurements);
        CodeFlash.recordBenchmarkResult("{method_id}", result);

        System.out.println("Benchmark complete: " + result);
    }}
}}
"""
    return benchmark_code


def remove_instrumentation(source: str) -> str:
    """Remove CodeFlash instrumentation from source code.

    Args:
        source: Instrumented source code.

    Returns:
        Source with instrumentation removed.

    """
    lines = source.splitlines(keepends=True)
    result_lines = []
    skip_until_end = False

    for line in lines:
        stripped = line.strip()

        # Skip CodeFlash instrumentation blocks
        if "// CodeFlash" in stripped and "start" in stripped:
            skip_until_end = True
            continue
        if skip_until_end:
            if "// CodeFlash" in stripped and "end" in stripped:
                skip_until_end = False
            continue

        # Skip CodeFlash imports
        if "import com.codeflash" in stripped:
            continue

        result_lines.append(line)

    return "".join(result_lines)

codeflash/languages/java/resources/CodeflashHelper.java — new file, 386 lines
@@ -0,0 +1,386 @@
package codeflash.runtime;

import java.io.File;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * Codeflash Helper - Test Instrumentation for Java
 *
 * This class provides timing instrumentation for Java tests, mirroring the
 * behavior of the JavaScript codeflash package.
 *
 * Usage in instrumented tests:
 *   import codeflash.runtime.CodeflashHelper;
 *
 *   // For behavior verification (writes to SQLite):
 *   Object result = CodeflashHelper.capture("testModule", "testClass", "testFunc",
 *       "funcName", () -> targetMethod(arg1, arg2));
 *
 *   // For performance benchmarking:
 *   Object result = CodeflashHelper.capturePerf("testModule", "testClass", "testFunc",
 *       "funcName", () -> targetMethod(arg1, arg2));
 *
 * Environment Variables:
 *   CODEFLASH_OUTPUT_FILE    - Path to write the results SQLite file
 *   CODEFLASH_LOOP_INDEX     - Current benchmark loop iteration (default: 1)
 *   CODEFLASH_TEST_ITERATION - Test iteration number (default: 0)
 *   CODEFLASH_MODE           - "behavior" or "performance"
 */
public class CodeflashHelper {

    private static final String OUTPUT_FILE = System.getenv("CODEFLASH_OUTPUT_FILE");
    private static final int LOOP_INDEX = parseIntOrDefault(System.getenv("CODEFLASH_LOOP_INDEX"), 1);
    private static final String MODE = System.getenv("CODEFLASH_MODE");

    // Track invocation counts per test method for unique iteration IDs
    private static final ConcurrentHashMap<String, AtomicInteger> invocationCounts = new ConcurrentHashMap<>();

    // Database connection (lazily initialized)
    private static Connection dbConnection = null;
    private static boolean dbInitialized = false;

    /**
     * Functional interface for wrapping void method calls.
     */
    @FunctionalInterface
    public interface VoidCallable {
        void call() throws Exception;
    }

    /**
     * Functional interface for wrapping method calls that return a value.
     */
    @FunctionalInterface
    public interface Callable<T> {
        T call() throws Exception;
    }
    /**
     * Capture behavior and timing for a method call that returns a value.
     */
    public static <T> T capture(
            String testModulePath,
            String testClassName,
            String testFunctionName,
            String functionGettingTested,
            Callable<T> callable
    ) throws Exception {
        String invocationKey = testModulePath + ":" + testClassName + ":" + testFunctionName + ":" + functionGettingTested;
        int iterationId = getNextIterationId(invocationKey);

        long startTime = System.nanoTime();
        T result;
        try {
            result = callable.call();
        } finally {
            long endTime = System.nanoTime();
            long durationNs = endTime - startTime;

            // Write to SQLite for behavior verification
            writeResultToSqlite(
                testModulePath,
                testClassName,
                testFunctionName,
                functionGettingTested,
                LOOP_INDEX,
                iterationId,
                durationNs,
                null, // return_value - TODO: serialize if needed
                "output"
            );

            // Print timing marker for stdout parsing (backup method)
            printTimingMarker(testModulePath, testClassName, functionGettingTested, LOOP_INDEX, iterationId, durationNs);
        }
        return result;
    }

    /**
     * Capture behavior and timing for a void method call.
     */
    public static void captureVoid(
            String testModulePath,
            String testClassName,
            String testFunctionName,
            String functionGettingTested,
            VoidCallable callable
    ) throws Exception {
        String invocationKey = testModulePath + ":" + testClassName + ":" + testFunctionName + ":" + functionGettingTested;
        int iterationId = getNextIterationId(invocationKey);

        long startTime = System.nanoTime();
        try {
            callable.call();
        } finally {
            long endTime = System.nanoTime();
            long durationNs = endTime - startTime;

            // Write to SQLite
            writeResultToSqlite(
                testModulePath,
                testClassName,
                testFunctionName,
                functionGettingTested,
                LOOP_INDEX,
                iterationId,
                durationNs,
                null,
                "output"
            );

            // Print timing marker
            printTimingMarker(testModulePath, testClassName, functionGettingTested, LOOP_INDEX, iterationId, durationNs);
        }
    }

    /**
     * Capture timing for performance benchmarking (method with return value).
     */
    public static <T> T capturePerf(
            String testModulePath,
            String testClassName,
            String testFunctionName,
            String functionGettingTested,
            Callable<T> callable
    ) throws Exception {
        String invocationKey = testModulePath + ":" + testClassName + ":" + testFunctionName + ":" + functionGettingTested;
        int iterationId = getNextIterationId(invocationKey);

        // Print start marker
        printStartMarker(testModulePath, testClassName, functionGettingTested, LOOP_INDEX, iterationId);

        long startTime = System.nanoTime();
        T result;
        try {
            result = callable.call();
        } finally {
            long endTime = System.nanoTime();
            long durationNs = endTime - startTime;

            // Write to SQLite for performance data
            writeResultToSqlite(
                testModulePath,
                testClassName,
                testFunctionName,
                functionGettingTested,
                LOOP_INDEX,
                iterationId,
                durationNs,
                null,
                "output"
            );

            // Print end marker with timing
            printTimingMarker(testModulePath, testClassName, functionGettingTested, LOOP_INDEX, iterationId, durationNs);
        }
        return result;
    }

    /**
     * Capture timing for performance benchmarking (void method).
     */
    public static void capturePerfVoid(
            String testModulePath,
            String testClassName,
            String testFunctionName,
            String functionGettingTested,
            VoidCallable callable
    ) throws Exception {
        String invocationKey = testModulePath + ":" + testClassName + ":" + testFunctionName + ":" + functionGettingTested;
        int iterationId = getNextIterationId(invocationKey);

        // Print start marker
        printStartMarker(testModulePath, testClassName, functionGettingTested, LOOP_INDEX, iterationId);

        long startTime = System.nanoTime();
        try {
            callable.call();
        } finally {
            long endTime = System.nanoTime();
            long durationNs = endTime - startTime;

            // Write to SQLite
            writeResultToSqlite(
                testModulePath,
                testClassName,
                testFunctionName,
                functionGettingTested,
                LOOP_INDEX,
                iterationId,
                durationNs,
                null,
                "output"
            );

            // Print end marker with timing
            printTimingMarker(testModulePath, testClassName, functionGettingTested, LOOP_INDEX, iterationId, durationNs);
        }
    }

    /**
     * Get the next iteration ID for a given invocation key.
     */
    private static int getNextIterationId(String invocationKey) {
        return invocationCounts.computeIfAbsent(invocationKey, k -> new AtomicInteger(0)).incrementAndGet();
    }

    /**
     * Print timing marker to stdout (format matches Python/JS).
     * Format: !######testModule:testClass:funcName:loopIndex:iterationId:durationNs######!
     */
    private static void printTimingMarker(
            String testModule,
            String testClass,
            String funcName,
            int loopIndex,
            int iterationId,
            long durationNs
    ) {
        System.out.println("!######" + testModule + ":" + testClass + ":" + funcName + ":" +
            loopIndex + ":" + iterationId + ":" + durationNs + "######!");
    }

    /**
     * Print start marker for performance tests.
     * Format: !$######testModule:testClass:funcName:loopIndex:iterationId######$!
     */
    private static void printStartMarker(
            String testModule,
            String testClass,
            String funcName,
            int loopIndex,
            int iterationId
    ) {
        System.out.println("!$######" + testModule + ":" + testClass + ":" + funcName + ":" +
            loopIndex + ":" + iterationId + "######$!");
    }

    /**
     * Write test result to SQLite database.
     */
    private static synchronized void writeResultToSqlite(
            String testModulePath,
            String testClassName,
            String testFunctionName,
            String functionGettingTested,
            int loopIndex,
            int iterationId,
            long runtime,
            byte[] returnValue,
            String verificationType
    ) {
        if (OUTPUT_FILE == null || OUTPUT_FILE.isEmpty()) {
            return;
        }

        try {
            ensureDbInitialized();
            if (dbConnection == null) {
                return;
            }

            String sql = "INSERT INTO test_results " +
                "(test_module_path, test_class_name, test_function_name, function_getting_tested, " +
                "loop_index, iteration_id, runtime, return_value, verification_type) " +
                "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)";

            try (PreparedStatement stmt = dbConnection.prepareStatement(sql)) {
                stmt.setString(1, testModulePath);
                stmt.setString(2, testClassName);
                stmt.setString(3, testFunctionName);
                stmt.setString(4, functionGettingTested);
                stmt.setInt(5, loopIndex);
                stmt.setInt(6, iterationId);
                stmt.setLong(7, runtime);
                stmt.setBytes(8, returnValue);
                stmt.setString(9, verificationType);
                stmt.executeUpdate();
            }
        } catch (SQLException e) {
            System.err.println("CodeflashHelper: Failed to write to SQLite: " + e.getMessage());
        }
    }

    /**
     * Ensure the database is initialized.
     */
    private static void ensureDbInitialized() {
        if (dbInitialized) {
            return;
        }
        dbInitialized = true;

        if (OUTPUT_FILE == null || OUTPUT_FILE.isEmpty()) {
            return;
        }

        try {
            // Load SQLite JDBC driver
            Class.forName("org.sqlite.JDBC");

            // Create parent directories if needed
            File dbFile = new File(OUTPUT_FILE);
            File parentDir = dbFile.getParentFile();
            if (parentDir != null && !parentDir.exists()) {
                parentDir.mkdirs();
            }

            // Connect to database
            dbConnection = DriverManager.getConnection("jdbc:sqlite:" + OUTPUT_FILE);

            // Create table if not exists
            String createTableSql = "CREATE TABLE IF NOT EXISTS test_results (" +
                "test_module_path TEXT, " +
                "test_class_name TEXT, " +
                "test_function_name TEXT, " +
                "function_getting_tested TEXT, " +
|
||||
"loop_index INTEGER, " +
|
||||
"iteration_id INTEGER, " +
|
||||
"runtime INTEGER, " +
|
||||
"return_value BLOB, " +
|
||||
"verification_type TEXT" +
|
||||
")";
|
||||
|
||||
try (Statement stmt = dbConnection.createStatement()) {
|
||||
stmt.execute(createTableSql);
|
||||
}
|
||||
|
||||
// Register shutdown hook to close connection
|
||||
Runtime.getRuntime().addShutdownHook(new Thread(() -> {
|
||||
try {
|
||||
if (dbConnection != null && !dbConnection.isClosed()) {
|
||||
dbConnection.close();
|
||||
}
|
||||
} catch (SQLException e) {
|
||||
// Ignore
|
||||
}
|
||||
}));
|
||||
|
||||
} catch (ClassNotFoundException e) {
|
||||
System.err.println("CodeflashHelper: SQLite JDBC driver not found. " +
|
||||
"Add sqlite-jdbc to your dependencies. Timing will still be captured via stdout.");
|
||||
} catch (SQLException e) {
|
||||
System.err.println("CodeflashHelper: Failed to initialize SQLite: " + e.getMessage());
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Parse int with default value.
|
||||
*/
|
||||
private static int parseIntOrDefault(String value, int defaultValue) {
|
||||
if (value == null || value.isEmpty()) {
|
||||
return defaultValue;
|
||||
}
|
||||
try {
|
||||
return Integer.parseInt(value);
|
||||
} catch (NumberFormatException e) {
|
||||
return defaultValue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
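For orientation, these markers are what the Python runner greps out of Maven's stdout. A minimal parsing sketch (the regex and helper name below are illustrative, not the parser this commit ships):

import re

# End marker printed by printTimingMarker:
#   !######testModule:testClass:funcName:loopIndex:iterationId:durationNs######!
END_MARKER = re.compile(r"!######(.+?):(.+?):(.+?):(\d+):(\d+):(\d+)######!")

def parse_timing_markers(stdout: str) -> list[tuple[str, str, str, int, int, int]]:
    """Extract (module, test_class, func, loop_index, iteration_id, duration_ns) tuples."""
    return [(m, c, f, int(loop), int(it), int(ns)) for m, c, f, loop, it, ns in END_MARKER.findall(stdout)]

sample = "!######com.example.FooTest:FooTest:compute:1:2:53417######!"
assert parse_timing_markers(sample) == [("com.example.FooTest", "FooTest", "compute", 1, 2, 53417)]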
@@ -98,6 +98,12 @@ class JavaSupport(LanguageSupport):
         """Find all optimizable functions in a Java file."""
         return discover_functions(file_path, filter_criteria, self._analyzer)

+    def discover_functions_from_source(
+        self, source: str, file_path: Path | None = None, filter_criteria: FunctionFilterCriteria | None = None
+    ) -> list[FunctionInfo]:
+        """Find all optimizable functions in Java source code."""
+        return discover_functions_from_source(source, file_path, filter_criteria, self._analyzer)
+
     def discover_tests(
         self, test_root: Path, source_functions: Sequence[FunctionInfo]
     ) -> dict[str, list[TestInfo]]:
@@ -8,6 +8,7 @@ from __future__ import annotations

 import logging
 import os
+import shutil
 import subprocess
 import tempfile
 import uuid
@@ -57,6 +58,7 @@ def run_behavioral_tests(
     """Run behavioral tests for Java code.

     This runs tests and captures behavior (inputs/outputs) for verification.
+    For Java, verification is based on JUnit test pass/fail results.

     Args:
         test_paths: TestFiles object or list of test file paths.
@@ -68,20 +70,17 @@ def run_behavioral_tests(
         candidate_index: Index of the candidate being tested.

     Returns:
-        Tuple of (result_file_path, subprocess_result, coverage_path, config_path).
+        Tuple of (result_xml_path, subprocess_result, coverage_path, config_path).

     """
     project_root = project_root or cwd

-    # Generate unique result file path
-    result_id = uuid.uuid4().hex[:8]
-    result_file = Path(tempfile.gettempdir()) / f"codeflash_java_behavior_{result_id}.db"
-    # Set environment variables for CodeFlash runtime
+    # Set environment variables for timing instrumentation
     run_env = os.environ.copy()
     run_env.update(test_env)
-    run_env["CODEFLASH_RESULT_FILE"] = str(result_file)
+    run_env["CODEFLASH_LOOP_INDEX"] = "1"  # Single loop for behavior tests
     run_env["CODEFLASH_MODE"] = "behavior"
     run_env["CODEFLASH_TEST_ITERATION"] = str(candidate_index)

     # Run Maven tests
     result = _run_maven_tests(
@@ -89,9 +88,14 @@ def run_behavioral_tests(
         test_paths,
         run_env,
         timeout=timeout or 300,
+        mode="behavior",
     )

-    return result_file, result, None, None
+    # Find or create the JUnit XML results file
+    surefire_dir = project_root / "target" / "surefire-reports"
+    result_xml_path = _get_combined_junit_xml(surefire_dir, candidate_index)
+
+    return result_xml_path, result, None, None


 def run_benchmarking_tests(
@@ -101,12 +105,15 @@ def run_benchmarking_tests(
     timeout: int | None = None,
     project_root: Path | None = None,
     min_loops: int = 5,
-    max_loops: int = 100_000,
+    max_loops: int = 100,
     target_duration_seconds: float = 10.0,
 ) -> tuple[Path, Any]:
     """Run benchmarking tests for Java code.

-    This runs tests with performance measurement.
+    This runs tests multiple times with performance measurement.
+    The instrumented tests print timing markers that are parsed from stdout:
+        Start: !$######testModule:testClass:funcName:loopIndex:iterationId######$!
+        End:   !######testModule:testClass:funcName:loopIndex:iterationId:durationNs######!

     Args:
         test_paths: TestFiles object or list of test file paths.
@@ -119,33 +126,182 @@ def run_benchmarking_tests(
         target_duration_seconds: Target duration for benchmarking in seconds.

     Returns:
-        Tuple of (result_file_path, subprocess_result).
+        Tuple of (result_file_path, subprocess_result with aggregated stdout).

     """
+    import time
+
     project_root = project_root or cwd

-    # Generate unique result file path
-    result_id = uuid.uuid4().hex[:8]
-    result_file = Path(tempfile.gettempdir()) / f"codeflash_java_benchmark_{result_id}.db"
+    # Collect stdout from all loops
+    all_stdout = []
+    all_stderr = []
+    total_start_time = time.time()
+    loop_count = 0
+    last_result = None

-    # Set environment variables
-    run_env = os.environ.copy()
-    run_env.update(test_env)
-    run_env["CODEFLASH_RESULT_FILE"] = str(result_file)
-    run_env["CODEFLASH_MODE"] = "benchmark"
-    run_env["CODEFLASH_MIN_LOOPS"] = str(min_loops)
-    run_env["CODEFLASH_MAX_LOOPS"] = str(max_loops)
-    run_env["CODEFLASH_TARGET_DURATION"] = str(target_duration_seconds)
+    # Run multiple loops until we hit target duration or max loops
+    for loop_idx in range(1, max_loops + 1):
+        # Set environment variables for this loop
+        run_env = os.environ.copy()
+        run_env.update(test_env)
+        run_env["CODEFLASH_LOOP_INDEX"] = str(loop_idx)
+        run_env["CODEFLASH_MODE"] = "performance"
+        run_env["CODEFLASH_TEST_ITERATION"] = "0"

-    # Run Maven tests
-    result = _run_maven_tests(
-        project_root,
-        test_paths,
-        run_env,
-        timeout=timeout or 600,  # Longer timeout for benchmarks
-    )
+        # Run Maven tests for this loop
+        result = _run_maven_tests(
+            project_root,
+            test_paths,
+            run_env,
+            timeout=timeout or 120,  # Per-loop timeout
+            mode="performance",
+        )

-    return result_file, result
+        last_result = result
+        loop_count = loop_idx
+
+        # Collect stdout/stderr
+        if result.stdout:
+            all_stdout.append(result.stdout)
+        if result.stderr:
+            all_stderr.append(result.stderr)
+
+        # Check if we've hit the target duration
+        elapsed = time.time() - total_start_time
+        if loop_idx >= min_loops and elapsed >= target_duration_seconds:
+            logger.debug(
+                "Stopping benchmark after %d loops (%.2fs elapsed, target: %.2fs)",
+                loop_idx,
+                elapsed,
+                target_duration_seconds,
+            )
+            break
+
+        # Check if tests failed - don't continue looping
+        if result.returncode != 0:
+            logger.warning("Tests failed in loop %d, stopping benchmark", loop_idx)
+            break
+
+    # Create a combined result with all stdout
+    combined_stdout = "\n".join(all_stdout)
+    combined_stderr = "\n".join(all_stderr)
+
+    logger.debug(
+        "Completed %d benchmark loops in %.2fs",
+        loop_count,
+        time.time() - total_start_time,
+    )
+
+    # Create a combined subprocess result
+    combined_result = subprocess.CompletedProcess(
+        args=last_result.args if last_result else ["mvn", "test"],
+        returncode=last_result.returncode if last_result else -1,
+        stdout=combined_stdout,
+        stderr=combined_stderr,
+    )
+
+    # Find or create the JUnit XML results file (from last run)
+    surefire_dir = project_root / "target" / "surefire-reports"
+    result_xml_path = _get_combined_junit_xml(surefire_dir, -1)  # Use -1 for benchmark
+
+    return result_xml_path, combined_result
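The stop rule above has two gates, a loop floor and a time budget; restated standalone to make the behavior concrete (the function name is ours, not the codebase's):

def should_stop(loop_idx: int, elapsed: float, min_loops: int = 5, target_duration_seconds: float = 10.0) -> bool:
    # Mirrors the break condition in run_benchmarking_tests: both gates must pass.
    return loop_idx >= min_loops and elapsed >= target_duration_seconds

assert not should_stop(4, 12.0)  # over the time budget, but under the loop floor
assert should_stop(5, 15.0)      # both gates satisfied -> benchmark stops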
+
+
+def _get_combined_junit_xml(surefire_dir: Path, candidate_index: int) -> Path:
+    """Get or create a combined JUnit XML file from Surefire reports.
+
+    Args:
+        surefire_dir: Directory containing Surefire reports.
+        candidate_index: Index for unique naming.
+
+    Returns:
+        Path to the combined JUnit XML file.
+
+    """
+    # Create a temp file for the combined results
+    result_id = uuid.uuid4().hex[:8]
+    result_xml_path = Path(tempfile.gettempdir()) / f"codeflash_java_results_{candidate_index}_{result_id}.xml"
+
+    if not surefire_dir.exists():
+        # Create an empty results file
+        _write_empty_junit_xml(result_xml_path)
+        return result_xml_path
+
+    # Find all TEST-*.xml files
+    xml_files = list(surefire_dir.glob("TEST-*.xml"))
+
+    if not xml_files:
+        _write_empty_junit_xml(result_xml_path)
+        return result_xml_path
+
+    if len(xml_files) == 1:
+        # Copy the single file
+        shutil.copy(xml_files[0], result_xml_path)
+        return result_xml_path
+
+    # Combine multiple XML files into one
+    _combine_junit_xml_files(xml_files, result_xml_path)
+    return result_xml_path
+
+
+def _write_empty_junit_xml(path: Path) -> None:
+    """Write an empty JUnit XML results file."""
+    xml_content = '''<?xml version="1.0" encoding="UTF-8"?>
+<testsuite name="NoTests" tests="0" failures="0" errors="0" skipped="0" time="0">
+</testsuite>
+'''
+    path.write_text(xml_content, encoding="utf-8")
+
+
+def _combine_junit_xml_files(xml_files: list[Path], output_path: Path) -> None:
+    """Combine multiple JUnit XML files into one.
+
+    Args:
+        xml_files: List of XML files to combine.
+        output_path: Path for the combined output.
+
+    """
+    total_tests = 0
+    total_failures = 0
+    total_errors = 0
+    total_skipped = 0
+    total_time = 0.0
+    all_testcases = []
+
+    for xml_file in xml_files:
+        try:
+            tree = ET.parse(xml_file)
+            root = tree.getroot()
+
+            # Get testsuite attributes
+            total_tests += int(root.get("tests", 0))
+            total_failures += int(root.get("failures", 0))
+            total_errors += int(root.get("errors", 0))
+            total_skipped += int(root.get("skipped", 0))
+            total_time += float(root.get("time", 0))
+
+            # Collect all testcases
+            for testcase in root.findall(".//testcase"):
+                all_testcases.append(testcase)
+
+        except Exception as e:
+            logger.warning("Failed to parse %s: %s", xml_file, e)
+
+    # Create combined XML
+    combined_root = ET.Element("testsuite")
+    combined_root.set("name", "CombinedTests")
+    combined_root.set("tests", str(total_tests))
+    combined_root.set("failures", str(total_failures))
+    combined_root.set("errors", str(total_errors))
+    combined_root.set("skipped", str(total_skipped))
+    combined_root.set("time", str(total_time))
+
+    for testcase in all_testcases:
+        combined_root.append(testcase)
+
+    tree = ET.ElementTree(combined_root)
+    tree.write(output_path, encoding="unicode", xml_declaration=True)
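A quick usage sketch for the combiner above, with two invented Surefire reports; suite-level counters are simply summed:

import tempfile
from pathlib import Path

reports = Path(tempfile.mkdtemp())
(reports / "TEST-FooTest.xml").write_text(
    '<testsuite name="FooTest" tests="2" failures="1" errors="0" skipped="0" time="0.12"/>'
)
(reports / "TEST-BarTest.xml").write_text(
    '<testsuite name="BarTest" tests="3" failures="0" errors="0" skipped="1" time="0.30"/>'
)
combined = reports / "combined.xml"
_combine_junit_xml_files(sorted(reports.glob("TEST-*.xml")), combined)
# combined.xml now reports tests="5" failures="1" errors="0" skipped="1" time~=0.42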
+
+
 def _run_maven_tests(
@@ -153,6 +309,7 @@ def _run_maven_tests(
     test_paths: Any,
     env: dict[str, str],
     timeout: int = 300,
+    mode: str = "behavior",
 ) -> subprocess.CompletedProcess:
     """Run Maven tests with Surefire.

@@ -161,6 +318,7 @@ def _run_maven_tests(
         test_paths: Test files or classes to run.
         env: Environment variables.
         timeout: Maximum execution time in seconds.
+        mode: Testing mode - "behavior" or "performance".

     Returns:
         CompletedProcess with test results.
@@ -177,7 +335,7 @@ def _run_maven_tests(
     )

     # Build test filter
-    test_filter = _build_test_filter(test_paths)
+    test_filter = _build_test_filter(test_paths, mode=mode)

     # Build Maven command
     cmd = [mvn, "test", "-fae"]  # Fail at end to run all tests
@@ -185,6 +343,8 @@ def _run_maven_tests(
     if test_filter:
         cmd.append(f"-Dtest={test_filter}")

+    logger.debug("Running Maven command: %s in %s", " ".join(cmd), project_root)
+
     try:
         result = subprocess.run(
             cmd,
@@ -215,11 +375,12 @@ def _run_maven_tests(
     )


-def _build_test_filter(test_paths: Any) -> str:
+def _build_test_filter(test_paths: Any, mode: str = "behavior") -> str:
     """Build a Maven Surefire test filter from test paths.

     Args:
         test_paths: Test files, classes, or methods to include.
+        mode: Testing mode - "behavior" or "performance".

     Returns:
         Surefire test filter string.
@@ -243,7 +404,21 @@ def _build_test_filter(test_paths: Any) -> str:

     # Handle TestFiles object (has test_files attribute)
     if hasattr(test_paths, "test_files"):
-        return _build_test_filter(list(test_paths.test_files))
+        filters = []
+        for test_file in test_paths.test_files:
+            # For performance mode, use benchmarking_file_path
+            if mode == "performance":
+                if hasattr(test_file, "benchmarking_file_path") and test_file.benchmarking_file_path:
+                    class_name = _path_to_class_name(test_file.benchmarking_file_path)
+                    if class_name:
+                        filters.append(class_name)
+            else:
+                # For behavior mode, use instrumented_behavior_file_path
+                if hasattr(test_file, "instrumented_behavior_file_path") and test_file.instrumented_behavior_file_path:
+                    class_name = _path_to_class_name(test_file.instrumented_behavior_file_path)
+                    if class_name:
+                        filters.append(class_name)
+        return ",".join(filters) if filters else ""

     return ""
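The returned filter feeds Surefire's -Dtest flag, so a two-class run ends up looking roughly like this (class names invented):

# From _run_maven_tests above, with a two-class filter:
filters = ["com.example.FooTest__perfinstrumented", "com.example.BarTest__perfinstrumented"]
cmd = ["mvn", "test", "-fae", f"-Dtest={','.join(filters)}"]
# -> mvn test -fae -Dtest=com.example.FooTest__perfinstrumented,com.example.BarTest__perfinstrumented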
@@ -263,19 +438,31 @@ def _path_to_class_name(path: Path) -> str | None:

     # Try to extract package from path
     # e.g., src/test/java/com/example/CalculatorTest.java -> com.example.CalculatorTest
-    parts = path.parts
+    parts = list(path.parts)

-    # Find 'java' in the path and take everything after
-    try:
-        java_idx = parts.index("java")
-        class_parts = parts[java_idx + 1 :]
+    # Look for standard Maven/Gradle source directories
+    # Find 'java' that comes after 'main' or 'test'
+    java_idx = None
+    for i, part in enumerate(parts):
+        if part == "java" and i > 0 and parts[i - 1] in ("main", "test"):
+            java_idx = i
+            break
+
+    # If no standard Maven structure, find the last 'java' in path
+    if java_idx is None:
+        for i in range(len(parts) - 1, -1, -1):
+            if parts[i] == "java":
+                java_idx = i
+                break
+
+    if java_idx is not None:
+        class_parts = parts[java_idx + 1:]
         # Remove .java extension from last part
         class_parts = list(class_parts)
         class_parts[-1] = class_parts[-1].replace(".java", "")
         return ".".join(class_parts)
-    except ValueError:
-        # No 'java' directory, just use the file name
-        return path.stem
+
+    # Fallback: just use the file name
+    return path.stem


 def run_tests(
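Assuming the function behaves as shown above, the mapping works out like this (illustrative assertions):

from pathlib import Path

# Standard Maven layout: 'java' right after 'test' wins.
assert _path_to_class_name(Path("src/test/java/com/example/CalculatorTest.java")) == "com.example.CalculatorTest"

# No 'java' segment at all: fall back to the bare file name.
assert _path_to_class_name(Path("tests/CalculatorTest.java")) == "CalculatorTest"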
@@ -76,7 +76,7 @@ from codeflash.context import code_context_extractor
 from codeflash.context.unused_definition_remover import detect_unused_helper_functions, revert_unused_helper_functions
 from codeflash.discovery.functions_to_optimize import was_function_previously_optimized
 from codeflash.either import Failure, Success, is_successful
-from codeflash.languages import is_python
+from codeflash.languages import is_java, is_python
 from codeflash.languages.base import FunctionInfo, Language
 from codeflash.languages.current import current_language_support, is_typescript
 from codeflash.languages.javascript.module_system import detect_module_system
@@ -577,17 +577,29 @@ class FunctionOptimizer:

         logger.debug(f"[PIPELINE] Processing {count_tests} generated tests")
         for i, generated_test in enumerate(generated_tests.generated_tests):
+            behavior_path = generated_test.behavior_file_path
+            perf_path = generated_test.perf_file_path
+
+            # For Java, fix paths to match package structure
+            if is_java():
+                behavior_path, perf_path = self._fix_java_test_paths(
+                    generated_test.instrumented_behavior_test_source,
+                    generated_test.instrumented_perf_test_source,
+                )
+                generated_test.behavior_file_path = behavior_path
+                generated_test.perf_file_path = perf_path
+
             logger.debug(
-                f"[PIPELINE] Test {i + 1}: behavior_path={generated_test.behavior_file_path}, perf_path={generated_test.perf_file_path}"
+                f"[PIPELINE] Test {i + 1}: behavior_path={behavior_path}, perf_path={perf_path}"
             )

-            with generated_test.behavior_file_path.open("w", encoding="utf8") as f:
+            with behavior_path.open("w", encoding="utf8") as f:
                 f.write(generated_test.instrumented_behavior_test_source)
-            logger.debug(f"[PIPELINE] Wrote behavioral test to {generated_test.behavior_file_path}")
+            logger.debug(f"[PIPELINE] Wrote behavioral test to {behavior_path}")

-            with generated_test.perf_file_path.open("w", encoding="utf8") as f:
+            with perf_path.open("w", encoding="utf8") as f:
                 f.write(generated_test.instrumented_perf_test_source)
-            logger.debug(f"[PIPELINE] Wrote perf test to {generated_test.perf_file_path}")
+            logger.debug(f"[PIPELINE] Wrote perf test to {perf_path}")

             # File paths are expected to be absolute - resolved at their source (CLI, TestConfig, etc.)
             test_file_obj = TestFile(
@@ -640,6 +652,55 @@ class FunctionOptimizer:
             )
         )

+    def _fix_java_test_paths(
+        self, behavior_source: str, perf_source: str
+    ) -> tuple[Path, Path]:
+        """Fix Java test file paths to match package structure.
+
+        Java requires test files to be in directories matching their package.
+        This method extracts the package and class from the generated tests
+        and returns correct paths.
+
+        Args:
+            behavior_source: Source code of the behavior test.
+            perf_source: Source code of the performance test.
+
+        Returns:
+            Tuple of (behavior_path, perf_path) with correct package structure.
+
+        """
+        import re
+
+        # Extract package from behavior source
+        package_match = re.search(r'^\s*package\s+([\w.]+)\s*;', behavior_source, re.MULTILINE)
+        package_name = package_match.group(1) if package_match else ""
+
+        # Extract class name from behavior source
+        class_match = re.search(r'\bclass\s+(\w+)', behavior_source)
+        behavior_class = class_match.group(1) if class_match else "GeneratedTest"
+
+        # Extract class name from perf source
+        perf_class_match = re.search(r'\bclass\s+(\w+)', perf_source)
+        perf_class = perf_class_match.group(1) if perf_class_match else "GeneratedPerfTest"
+
+        # Build paths with package structure
+        test_dir = self.test_cfg.tests_root
+
+        if package_name:
+            package_path = package_name.replace(".", "/")
+            behavior_path = test_dir / package_path / f"{behavior_class}.java"
+            perf_path = test_dir / package_path / f"{perf_class}.java"
+        else:
+            behavior_path = test_dir / f"{behavior_class}.java"
+            perf_path = test_dir / f"{perf_class}.java"
+
+        # Create directories if needed
+        behavior_path.parent.mkdir(parents=True, exist_ok=True)
+        perf_path.parent.mkdir(parents=True, exist_ok=True)
+
+        logger.debug(f"[JAVA] Fixed paths: behavior={behavior_path}, perf={perf_path}")
+        return behavior_path, perf_path
+
     # note: this isn't called by the lsp, only called by cli
     def optimize_function(self) -> Result[BestOptimization, str]:
         initialization_result = self.can_be_optimized()
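The two extraction regexes are easy to sanity-check in isolation; a small sketch with an invented test source:

import re

source = '''package com.example;

class CalculatorTest__perfinstrumented {
}
'''
package_name = re.search(r'^\s*package\s+([\w.]+)\s*;', source, re.MULTILINE).group(1)
class_name = re.search(r'\bclass\s+(\w+)', source).group(1)
assert (package_name, class_name) == ("com.example", "CalculatorTest__perfinstrumented")
# -> tests_root / "com/example" / "CalculatorTest__perfinstrumented.java"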
@@ -204,7 +204,15 @@ def quantity_of_tests_critic(candidate_result: OptimizedCandidateResult | Origin


 def coverage_critic(original_code_coverage: CoverageData | None) -> bool:
-    """Check if the coverage meets the threshold."""
+    """Check if the coverage meets the threshold.
+
+    For languages without coverage support (like Java), returns True if no coverage data is available.
+    """
+    from codeflash.languages import is_java, is_javascript
+
     if original_code_coverage:
         return original_code_coverage.coverage >= COVERAGE_THRESHOLD
+    # For Java/JavaScript, coverage is not implemented yet, so skip the check
+    if is_java() or is_javascript():
+        return True
     return False
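In effect: with coverage data the threshold applies as before; without it, Java and JavaScript pass while Python fails. A sketch of the decision table, assuming the language checks behave as named:

# coverage_critic(cov)  -> cov.coverage >= COVERAGE_THRESHOLD   (any language, data present)
# coverage_critic(None) -> True   if is_java() or is_javascript()
# coverage_critic(None) -> False  otherwise (Python is expected to produce coverage data)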
@@ -21,7 +21,7 @@ from codeflash.code_utils.code_utils import (
     module_name_from_file_path,
 )
 from codeflash.discovery.discover_unit_tests import discover_parameters_unittest
-from codeflash.languages import is_javascript
+from codeflash.languages import is_java, is_javascript
 from codeflash.models.models import (
     ConcurrencyMetrics,
     FunctionTestInvocation,
@@ -128,7 +128,7 @@ def parse_concurrency_metrics(test_results: TestResults, function_name: str) ->


 def resolve_test_file_from_class_path(test_class_path: str, base_dir: Path) -> Path | None:
-    """Resolve test file path from pytest's test class path.
+    """Resolve test file path from pytest's test class path or Java class path.

     This function handles various cases where pytest's classname in JUnit XML
     includes parent directories that may already be part of base_dir.
@@ -136,6 +136,7 @@ def resolve_test_file_from_class_path(test_class_path: str, base_dir: Path) -> P
     Args:
         test_class_path: The full class path from pytest (e.g., "project.tests.test_file.TestClass")
             or a file path from Jest (e.g., "tests/test_file.test.js")
+            or a Java class path (e.g., "com.example.AlgorithmsTest")
         base_dir: The base directory for tests (tests project root)

     Returns:
@@ -147,6 +148,35 @@ def resolve_test_file_from_class_path(test_class_path: str, base_dir: Path) -> P
     >>> # Should find: /path/to/tests/unittest/test_file.py

     """
+    # Handle Java class paths (convert dots to path and add .java extension)
+    # Java class paths look like "com.example.TestClass" and should map to
+    # src/test/java/com/example/TestClass.java
+    if is_java():
+        # Convert dots to path separators
+        relative_path = test_class_path.replace(".", "/") + ".java"
+
+        # Try various locations
+        # 1. Directly under base_dir
+        potential_path = base_dir / relative_path
+        if potential_path.exists():
+            return potential_path
+
+        # 2. Under src/test/java relative to project root
+        project_root = base_dir.parent if base_dir.name == "java" else base_dir
+        while project_root.name not in ("", "/") and not (project_root / "pom.xml").exists():
+            project_root = project_root.parent
+        if (project_root / "pom.xml").exists():
+            potential_path = project_root / "src" / "test" / "java" / relative_path
+            if potential_path.exists():
+                return potential_path
+
+        # 3. Search for the file in base_dir and its subdirectories
+        file_name = test_class_path.split(".")[-1] + ".java"
+        for java_file in base_dir.rglob(file_name):
+            return java_file
+
+        return None
+
     # Handle file paths (contain slashes and extensions like .js/.ts)
     if "/" in test_class_path or "\\" in test_class_path:
         # This is a file path, not a Python module path
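A concrete walk through case 1 of the Java branch above, with an invented temp layout:

import tempfile
from pathlib import Path

base_dir = Path(tempfile.mkdtemp()) / "src" / "test" / "java"
test_file = base_dir / "com" / "example" / "AlgorithmsTest.java"
test_file.parent.mkdir(parents=True)
test_file.write_text("class AlgorithmsTest {}")

# "com.example.AlgorithmsTest" -> "com/example/AlgorithmsTest.java", found directly under base_dir
relative_path = "com.example.AlgorithmsTest".replace(".", "/") + ".java"
assert (base_dir / relative_path) == test_file
assert (base_dir / relative_path).exists()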
@@ -997,6 +1027,19 @@ def parse_test_xml(
             end_matches[groups] = match

     if not begin_matches or not end_matches:
+        # For Java tests, use the JUnit XML time attribute for runtime
+        runtime_from_xml = None
+        if is_java():
+            try:
+                # JUnit XML time is in seconds, convert to nanoseconds
+                # Use a minimum of 1000ns (1 microsecond) for any successful test
+                # to avoid 0 runtime being treated as "no runtime"
+                test_time = float(testcase.time) if hasattr(testcase, 'time') and testcase.time else 0.0
+                runtime_from_xml = max(int(test_time * 1_000_000_000), 1000)
+            except (ValueError, TypeError):
+                # If we can't get time from XML, use 1 microsecond as minimum
+                runtime_from_xml = 1000
+
         test_results.add(
             FunctionTestInvocation(
                 loop_index=loop_index,
@@ -1008,7 +1051,7 @@ def parse_test_xml(
                 iteration_id="",
             ),
             file_name=test_file_path,
-            runtime=None,
+            runtime=runtime_from_xml,
             test_framework=test_config.test_framework,
             did_pass=result,
             test_type=test_type,
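The fallback conversion is worth a worked example, since the 1 microsecond floor is easy to miss (the helper name is ours):

def junit_time_to_ns(time_s: float) -> int:
    # Same conversion as the parse_test_xml fallback above.
    return max(int(time_s * 1_000_000_000), 1000)

assert junit_time_to_ns(0.005) == 5_000_000  # a 5 ms test
assert junit_time_to_ns(0.0) == 1000         # zero-time tests still register a runtime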
@@ -9,9 +9,16 @@ from pydantic.dataclasses import dataclass
 from codeflash.languages import current_language_support, is_java, is_javascript


-def get_test_file_path(test_dir: Path, function_name: str, iteration: int = 0, test_type: str = "unit") -> Path:
+def get_test_file_path(
+    test_dir: Path,
+    function_name: str,
+    iteration: int = 0,
+    test_type: str = "unit",
+    package_name: str | None = None,
+    class_name: str | None = None,
+) -> Path:
     assert test_type in {"unit", "inspired", "replay", "perf"}
-    function_name = function_name.replace(".", "_")
+    function_name_safe = function_name.replace(".", "_")
     # Use appropriate file extension based on language
     if is_javascript():
         extension = current_language_support().get_test_file_suffix()
@@ -19,9 +26,25 @@ def get_test_file_path(test_dir: Path, function_name: str, iteration: int = 0, t
         extension = ".java"
     else:
         extension = ".py"
-    path = test_dir / f"test_{function_name}__{test_type}_test_{iteration}{extension}"
+
+    if is_java() and package_name:
+        # For Java, create package directory structure
+        # e.g., com.example -> com/example/
+        package_path = package_name.replace(".", "/")
+        java_class_name = class_name or f"{function_name_safe.title()}Test"
+        # Add suffix to avoid conflicts
+        if test_type == "perf":
+            java_class_name = f"{java_class_name}__perfonlyinstrumented"
+        elif test_type == "unit":
+            java_class_name = f"{java_class_name}__perfinstrumented"
+        path = test_dir / package_path / f"{java_class_name}{extension}"
+        # Create package directory if needed
+        path.parent.mkdir(parents=True, exist_ok=True)
+    else:
+        path = test_dir / f"test_{function_name_safe}__{test_type}_test_{iteration}{extension}"
+
     if path.exists():
-        return get_test_file_path(test_dir, function_name, iteration + 1, test_type)
+        return get_test_file_path(test_dir, function_name, iteration + 1, test_type, package_name, class_name)
     return path
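A standalone mirror of the Java path arithmetic above (the helper is ours, not the codebase's):

from pathlib import Path

def java_test_path(test_dir: Path, package_name: str, class_name: str, test_type: str) -> Path:
    # Package dots become directories; the instrumentation suffix depends on test_type.
    suffix = "__perfonlyinstrumented" if test_type == "perf" else "__perfinstrumented"
    return test_dir / package_name.replace(".", "/") / f"{class_name}{suffix}.java"

assert java_test_path(Path("src/test/java"), "com.example", "ComputeTest", "unit") == Path(
    "src/test/java/com/example/ComputeTest__perfinstrumented.java"
)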
@@ -7,7 +7,7 @@ from typing import TYPE_CHECKING

 from codeflash.cli_cmds.console import logger
 from codeflash.code_utils.code_utils import get_run_tmp_file, module_name_from_file_path
-from codeflash.languages import is_javascript
+from codeflash.languages import is_java, is_javascript
 from codeflash.verification.verification_utils import ModifyInspiredTests, delete_multiple_if_name_main

 if TYPE_CHECKING:
@@ -98,6 +98,29 @@ def generate_tests(
         )

         logger.debug(f"Instrumented JS/TS tests locally for {func_name}")
+    elif is_java():
+        from codeflash.languages.java.instrumentation import instrument_generated_java_test
+
+        func_name = function_to_optimize.function_name
+        qualified_name = function_to_optimize.qualified_name
+
+        # Instrument for behavior verification (renames class)
+        instrumented_behavior_test_source = instrument_generated_java_test(
+            test_code=generated_test_source,
+            function_name=func_name,
+            qualified_name=qualified_name,
+            mode="behavior",
+        )
+
+        # Instrument for performance measurement (adds timing markers)
+        instrumented_perf_test_source = instrument_generated_java_test(
+            test_code=generated_test_source,
+            function_name=func_name,
+            qualified_name=qualified_name,
+            mode="performance",
+        )
+
+        logger.debug(f"Instrumented Java tests locally for {func_name}")
     else:
         # Python: instrumentation is done by aiservice, just replace temp dir placeholders
         instrumented_behavior_test_source = instrumented_behavior_test_source.replace(
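A sketch of driving the new instrumentation entry point directly, mirroring the two passes above (inputs are placeholders):

from codeflash.languages.java.instrumentation import instrument_generated_java_test

generated_test_source = "..."  # JUnit source produced by the test generator

behavior_src = instrument_generated_java_test(
    test_code=generated_test_source,
    function_name="compute",
    qualified_name="com.example.Calculator.compute",
    mode="behavior",      # renames the test class for behavior verification
)
perf_src = instrument_generated_java_test(
    test_code=generated_test_source,
    function_name="compute",
    qualified_name="com.example.Calculator.compute",
    mode="performance",   # wraps calls with the timing markers parsed later
)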
1095 docs/java-support-architecture.md Normal file
File diff suppressed because it is too large
17 uv.lock
@@ -438,6 +438,7 @@ dependencies = [
     { name = "tomlkit" },
     { name = "tree-sitter", version = "0.23.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" },
     { name = "tree-sitter", version = "0.25.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.10'" },
+    { name = "tree-sitter-java" },
     { name = "tree-sitter-javascript", version = "0.23.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" },
     { name = "tree-sitter-javascript", version = "0.25.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.10'" },
     { name = "tree-sitter-typescript" },

@@ -526,6 +527,7 @@ requires-dist = [
     { name = "sentry-sdk", specifier = ">=1.40.6,<3.0.0" },
     { name = "tomlkit", specifier = ">=0.11.7" },
     { name = "tree-sitter", specifier = ">=0.23.0" },
+    { name = "tree-sitter-java", specifier = ">=0.23.0" },
     { name = "tree-sitter-javascript", specifier = ">=0.23.0" },
     { name = "tree-sitter-typescript", specifier = ">=0.23.0" },
     { name = "unidiff", specifier = ">=0.7.4" },

@@ -5222,6 +5224,21 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/a6/6e/e64621037357acb83d912276ffd30a859ef117f9c680f2e3cb955f47c680/tree_sitter-0.25.2-cp314-cp314-win_arm64.whl", hash = "sha256:b8d4429954a3beb3e844e2872610d2a4800ba4eb42bb1990c6a4b1949b18459f", size = 117470, upload-time = "2025-09-25T17:37:58.431Z" },
 ]

+[[package]]
+name = "tree-sitter-java"
+version = "0.23.5"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/fa/dc/eb9c8f96304e5d8ae1663126d89967a622a80937ad2909903569ccb7ec8f/tree_sitter_java-0.23.5.tar.gz", hash = "sha256:f5cd57b8f1270a7f0438878750d02ccc79421d45cca65ff284f1527e9ef02e38", size = 138121, upload-time = "2024-12-21T18:24:26.936Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/67/21/b3399780b440e1567a11d384d0ebb1aea9b642d0d98becf30fa55c0e3a3b/tree_sitter_java-0.23.5-cp39-abi3-macosx_10_9_x86_64.whl", hash = "sha256:355ce0308672d6f7013ec913dee4a0613666f4cda9044a7824240d17f38209df", size = 58926, upload-time = "2024-12-21T18:24:12.53Z" },
+    { url = "https://files.pythonhosted.org/packages/57/ef/6406b444e2a93bc72a04e802f4107e9ecf04b8de4a5528830726d210599c/tree_sitter_java-0.23.5-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:24acd59c4720dedad80d548fe4237e43ef2b7a4e94c8549b0ca6e4c4d7bf6e69", size = 62288, upload-time = "2024-12-21T18:24:14.634Z" },
+    { url = "https://files.pythonhosted.org/packages/4e/6c/74b1c150d4f69c291ab0b78d5dd1b59712559bbe7e7daf6d8466d483463f/tree_sitter_java-0.23.5-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9401e7271f0b333df39fc8a8336a0caf1b891d9a2b89ddee99fae66b794fc5b7", size = 85533, upload-time = "2024-12-21T18:24:16.695Z" },
+    { url = "https://files.pythonhosted.org/packages/29/09/e0d08f5c212062fd046db35c1015a2621c2631bc8b4aae5740d7adb276ad/tree_sitter_java-0.23.5-cp39-abi3-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:370b204b9500b847f6d0c5ad584045831cee69e9a3e4d878535d39e4a7e4c4f1", size = 84033, upload-time = "2024-12-21T18:24:18.758Z" },
+    { url = "https://files.pythonhosted.org/packages/43/56/7d06b23ddd09bde816a131aa504ee11a1bbe87c6b62ab9b2ed23849a3382/tree_sitter_java-0.23.5-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:aae84449e330363b55b14a2af0585e4e0dae75eb64ea509b7e5b0e1de536846a", size = 82564, upload-time = "2024-12-21T18:24:20.493Z" },
+    { url = "https://files.pythonhosted.org/packages/da/d6/0528c7e1e88a18221dbd8ccee3825bf274b1fa300f745fd74eb343878043/tree_sitter_java-0.23.5-cp39-abi3-win_amd64.whl", hash = "sha256:1ee45e790f8d31d416bc84a09dac2e2c6bc343e89b8a2e1d550513498eedfde7", size = 60650, upload-time = "2024-12-21T18:24:22.902Z" },
+    { url = "https://files.pythonhosted.org/packages/72/57/5bab54d23179350356515526fff3cc0f3ac23bfbc1a1d518a15978d4880e/tree_sitter_java-0.23.5-cp39-abi3-win_arm64.whl", hash = "sha256:402efe136104c5603b429dc26c7e75ae14faaca54cfd319ecc41c8f2534750f4", size = 59059, upload-time = "2024-12-21T18:24:24.934Z" },
+]
+
 [[package]]
 name = "tree-sitter-javascript"
 version = "0.23.1"