feat: sync CLI from main with Java additions

Adds subagent mode, config display, and main's CLI refactoring.
Preserves omni-java's NullHighlighter, Java test root detection,
and Java project root detection (pom.xml/build.gradle).
This commit is contained in:
Kevin Turcios 2026-03-02 15:24:33 -05:00
parent d518ad2d91
commit 2299d26ae5
3 changed files with 234 additions and 61 deletions

View file

@ -130,9 +130,18 @@ def parse_args() -> Namespace:
"--reset-config", action="store_true", help="Remove codeflash configuration from project config file."
)
parser.add_argument("-y", "--yes", action="store_true", help="Skip confirmation prompts (useful for CI/scripts).")
parser.add_argument(
"--subagent",
action="store_true",
help="Subagent mode: skip all interactive prompts with sensible defaults. Designed for AI agent integrations.",
)
args, unknown_args = parser.parse_known_args()
sys.argv[:] = [sys.argv[0], *unknown_args]
if args.subagent:
args.yes = True
args.no_pr = True
args.worktree = True
return process_and_validate_cmd_args(args)
@ -237,7 +246,18 @@ def process_pyproject_config(args: Namespace) -> Namespace:
set_current_test_framework(pyproject_config["test_framework"])
if args.tests_root is None:
if is_js_ts_project:
if is_java_project:
# Try standard Maven/Gradle test directories
for test_dir in ["src/test/java", "test", "tests"]:
test_path = Path(args.module_root).parent / test_dir if "/" in test_dir else Path(test_dir)
if not test_path.is_absolute():
test_path = Path.cwd() / test_path
if test_path.is_dir():
args.tests_root = str(test_path)
break
if args.tests_root is None:
args.tests_root = str(Path.cwd() / "src" / "test" / "java")
elif is_js_ts_project:
# Try common JS test directories at project root first
for test_dir in ["test", "tests", "__tests__"]:
if Path(test_dir).is_dir():
@ -256,17 +276,6 @@ def process_pyproject_config(args: Namespace) -> Namespace:
# In such cases, the user should explicitly configure testsRoot in package.json
if args.tests_root is None:
args.tests_root = args.module_root
elif is_java_project:
# Try standard Maven/Gradle test directories
for test_dir in ["src/test/java", "test", "tests"]:
test_path = Path(args.module_root).parent / test_dir if "/" in test_dir else Path(test_dir)
if not test_path.is_absolute():
test_path = Path.cwd() / test_path
if test_path.is_dir():
args.tests_root = str(test_path)
break
if args.tests_root is None:
args.tests_root = str(Path.cwd() / "src" / "test" / "java")
else:
raise AssertionError("--tests-root must be specified")
assert Path(args.tests_root).is_dir(), f"--tests-root {args.tests_root} must be a valid directory"
@ -327,7 +336,6 @@ def project_root_from_module_root(module_root: Path, pyproject_file_path: Path)
return current.resolve()
if (current / "build.gradle").exists() or (current / "build.gradle.kts").exists():
return current.resolve()
# Check for config file (pyproject.toml for Python, codeflash.toml for other languages)
if (current / "codeflash.toml").exists():
return current.resolve()
current = current.parent
@ -378,32 +386,52 @@ def _handle_show_config() -> None:
from codeflash.setup.detector import detect_project, has_existing_config
project_root = Path.cwd()
detected = detect_project(project_root)
config_exists, _ = has_existing_config(project_root)
# Check if config exists or is auto-detected
config_exists, config_file = has_existing_config(project_root)
status = "Saved config" if config_exists else "Auto-detected (not saved)"
if config_exists:
from codeflash.code_utils.config_parser import parse_config_file
console.print()
console.print(f"[bold]Codeflash Configuration[/bold] ({status})")
if config_exists and config_file:
console.print(f"[dim]Config file: {project_root / config_file}[/dim]")
console.print()
config, config_file_path = parse_config_file()
status = "Saved config"
table = Table(show_header=True, header_style="bold cyan")
table.add_column("Setting", style="dim")
table.add_column("Value")
console.print()
console.print(f"[bold]Codeflash Configuration[/bold] ({status})")
console.print(f"[dim]Config file: {config_file_path}[/dim]")
console.print()
table.add_row("Language", detected.language)
table.add_row("Project root", str(detected.project_root))
table.add_row("Module root", str(detected.module_root))
table.add_row("Tests root", str(detected.tests_root) if detected.tests_root else "(not detected)")
table.add_row("Test runner", detected.test_runner or "(not detected)")
table.add_row("Formatter", ", ".join(detected.formatter_cmds) if detected.formatter_cmds else "(not detected)")
table.add_row(
"Ignore paths", ", ".join(str(p) for p in detected.ignore_paths) if detected.ignore_paths else "(none)"
)
table.add_row("Confidence", f"{detected.confidence:.0%}")
table = Table(show_header=True, header_style="bold cyan")
table.add_column("Setting", style="dim")
table.add_column("Value")
table.add_row("Project root", str(project_root))
table.add_row("Module root", config.get("module_root", "(not set)"))
table.add_row("Tests root", config.get("tests_root", "(not set)"))
table.add_row("Test runner", config.get("test_framework", config.get("pytest_cmd", "(not set)")))
table.add_row("Formatter", ", ".join(config["formatter_cmds"]) if config.get("formatter_cmds") else "(not set)")
ignore_paths = config.get("ignore_paths", [])
table.add_row("Ignore paths", ", ".join(str(p) for p in ignore_paths) if ignore_paths else "(none)")
else:
detected = detect_project(project_root)
status = "Auto-detected (not saved)"
console.print()
console.print(f"[bold]Codeflash Configuration[/bold] ({status})")
console.print()
table = Table(show_header=True, header_style="bold cyan")
table.add_column("Setting", style="dim")
table.add_column("Value")
table.add_row("Language", detected.language)
table.add_row("Project root", str(detected.project_root))
table.add_row("Module root", str(detected.module_root))
table.add_row("Tests root", str(detected.tests_root) if detected.tests_root else "(not detected)")
table.add_row("Test runner", detected.test_runner or "(not detected)")
table.add_row("Formatter", ", ".join(detected.formatter_cmds) if detected.formatter_cmds else "(not detected)")
table.add_row(
"Ignore paths", ", ".join(str(p) for p in detected.ignore_paths) if detected.ignore_paths else "(none)"
)
table.add_row("Confidence", f"{detected.confidence:.0%}")
console.print(table)
console.print()
@ -436,7 +464,7 @@ def _handle_reset_config(confirm: bool = True) -> None:
console.print("[bold]This will remove Codeflash configuration from your project.[/bold]")
console.print()
config_file = {"python": "pyproject.toml", "java": "codeflash.toml"}.get(detected.language, "package.json")
config_file = "pyproject.toml" if detected.language == "python" else "package.json"
console.print(f" Config file: {project_root / config_file}")
console.print()

View file

@ -22,7 +22,7 @@ from rich.progress import (
from codeflash.cli_cmds.console_constants import SPINNER_TYPES
from codeflash.cli_cmds.logging_config import BARE_LOGGING_FORMAT
from codeflash.lsp.helpers import is_LSP_enabled
from codeflash.lsp.helpers import is_LSP_enabled, is_subagent_mode
from codeflash.lsp.lsp_logger import enhanced_log
from codeflash.lsp.lsp_message import LspCodeMessage, LspTextMessage
@ -35,42 +35,69 @@ if TYPE_CHECKING:
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.base import DependencyResolver, IndexResult
from codeflash.lsp.lsp_message import LspMessage
from codeflash.models.models import TestResults
DEBUG_MODE = logging.getLogger().getEffectiveLevel() == logging.DEBUG
console = Console(highlighter=NullHighlighter())
if is_LSP_enabled():
if is_LSP_enabled() or is_subagent_mode():
console.quiet = True
logging.basicConfig(
level=logging.INFO,
handlers=[
RichHandler(
rich_tracebacks=True,
markup=False,
highlighter=NullHighlighter(),
console=console,
show_path=False,
show_time=False,
)
],
format=BARE_LOGGING_FORMAT,
)
if is_subagent_mode():
import re
import sys
_lsp_prefix_re = re.compile(r"^(?:!?lsp,?|h[2-4]|loading)\|")
_subagent_drop_patterns = (
"Test log -",
"Test failed to load",
"Examining file ",
"Generated ",
"Add custom marker",
"Disabling all autouse",
"Reverting code and helpers",
)
class _AgentLogFilter(logging.Filter):
def filter(self, record: logging.LogRecord) -> bool:
record.msg = _lsp_prefix_re.sub("", str(record.msg))
msg = record.getMessage()
return not any(msg.startswith(p) for p in _subagent_drop_patterns)
_agent_handler = logging.StreamHandler(sys.stderr)
_agent_handler.addFilter(_AgentLogFilter())
logging.basicConfig(level=logging.INFO, handlers=[_agent_handler], format="%(levelname)s: %(message)s")
else:
logging.basicConfig(
level=logging.INFO,
handlers=[
RichHandler(
rich_tracebacks=True,
markup=False,
highlighter=NullHighlighter(),
console=console,
show_path=False,
show_time=False,
)
],
format=BARE_LOGGING_FORMAT,
)
logger = logging.getLogger("rich")
logging.getLogger("parso").setLevel(logging.WARNING)
# override the logger to reformat the messages for the lsp
for level in ("info", "debug", "warning", "error"):
real_fn = getattr(logger, level)
setattr(
logger,
level,
lambda msg, *args, _real_fn=real_fn, _level=level, **kwargs: enhanced_log(
msg, _real_fn, _level, *args, **kwargs
),
)
if not is_subagent_mode():
for level in ("info", "debug", "warning", "error"):
real_fn = getattr(logger, level)
setattr(
logger,
level,
lambda msg, *args, _real_fn=real_fn, _level=level, **kwargs: enhanced_log(
msg, _real_fn, _level, *args, **kwargs
),
)
class DummyTask:
@ -97,6 +124,8 @@ def paneled_text(
text: str, panel_args: dict[str, str | bool] | None = None, text_args: dict[str, str] | None = None
) -> None:
"""Print text in a panel."""
if is_subagent_mode():
return
from rich.panel import Panel
from rich.text import Text
@ -125,6 +154,8 @@ def code_print(
language: Programming language for syntax highlighting ('python', 'javascript', 'typescript')
"""
if is_subagent_mode():
return
if is_LSP_enabled():
lsp_log(
LspCodeMessage(code=code_str, file_name=file_name, function_name=function_name, message_id=lsp_message_id)
@ -162,6 +193,10 @@ def progress_bar(
"""
global _progress_bar_active
if is_subagent_mode():
yield DummyTask().id
return
if is_LSP_enabled():
lsp_log(LspTextMessage(text=message, takes_time=True))
yield
@ -193,6 +228,10 @@ def progress_bar(
@contextmanager
def test_files_progress_bar(total: int, description: str) -> Generator[tuple[Progress, TaskID], None, None]:
"""Progress bar for test files."""
if is_subagent_mode():
yield DummyProgress(), DummyTask().id
return
if is_LSP_enabled():
lsp_log(LspTextMessage(text=description, takes_time=True))
dummy_progress = DummyProgress()
@ -226,6 +265,10 @@ def call_graph_live_display(
from rich.text import Text
from rich.tree import Tree
if is_subagent_mode():
yield lambda _: None
return
if is_LSP_enabled():
lsp_log(LspTextMessage(text="Building call graph", takes_time=True))
yield lambda _: None
@ -333,6 +376,9 @@ def call_graph_summary(call_graph: DependencyResolver, file_to_funcs: dict[Path,
if not total_functions:
return
if is_subagent_mode():
return
# Build the mapping expected by the dependency resolver
file_items = file_to_funcs.items()
mapping = {file_path: {func.qualified_name for func in funcs} for file_path, funcs in file_items}
@ -359,3 +405,92 @@ def call_graph_summary(call_graph: DependencyResolver, file_to_funcs: dict[Path,
return
console.print(Panel(summary, title="Call Graph Summary", border_style="cyan"))
def subagent_log_optimization_result(
function_name: str,
file_path: Path,
perf_improvement_line: str,
original_runtime_ns: int,
best_runtime_ns: int,
raw_explanation: str,
original_code: dict[Path, str],
new_code: dict[Path, str],
review: str,
test_results: TestResults,
) -> None:
import sys
from xml.sax.saxutils import escape
from codeflash.code_utils.code_utils import unified_diff_strings
from codeflash.code_utils.time_utils import humanize_runtime
from codeflash.models.test_type import TestType
diff_parts = []
for path in original_code:
old = original_code.get(path, "")
new = new_code.get(path, "")
if old != new:
diff = unified_diff_strings(old, new, fromfile=str(path), tofile=str(path))
if diff:
diff_parts.append(diff)
diff_str = "\n".join(diff_parts)
original_runtime = humanize_runtime(original_runtime_ns)
optimized_runtime = humanize_runtime(best_runtime_ns)
report = test_results.get_test_pass_fail_report_by_type()
verification_rows = []
for test_type in TestType:
if test_type is TestType.INIT_STATE_TEST:
continue
name = test_type.to_name()
if not name:
continue
passed = report[test_type]["passed"]
failed = report[test_type]["failed"]
if passed == 0 and failed == 0:
status = "None Found"
elif failed > 0:
status = f"{failed} Failed, {passed} Passed"
else:
status = f"{passed} Passed"
verification_rows.append(f' <test type="{escape(name)}" status="{escape(status)}"/>')
xml = [
"<codeflash-optimization>",
f" <function>{escape(function_name)}</function>",
f" <file>{escape(str(file_path))}</file>",
f" <performance>{escape(perf_improvement_line)}</performance>",
f" <original-runtime>{escape(original_runtime)}</original-runtime>",
f" <optimized-runtime>{escape(optimized_runtime)}</optimized-runtime>",
]
if review:
xml.append(f" <reviewer-confidence>{escape(review)}</reviewer-confidence>")
xml.append(f" <explanation>{escape(raw_explanation.strip())}</explanation>")
xml.append(" <verification>")
xml.extend(verification_rows)
xml.append(" </verification>")
if diff_str:
xml.append(f" <diff>{escape(diff_str)}</diff>")
for path in new_code:
if new_code[path] != original_code.get(path, ""):
xml.append(f' <optimized-code file="{escape(str(path))}">{escape(new_code[path])}</optimized-code>')
xml.append(" <action>")
xml.append(" First, present a short summary: function name, performance improvement, and explanation.")
xml.append(" If you are running inside an IDE (VS Code, Cursor, Windsurf, etc.):")
xml.append(" For each optimized-code element, use the Write tool to write the optimized code to the file.")
xml.append(" This will surface the diff in the IDE's native diff view for the user to accept or reject.")
xml.append(" If you are running in a terminal without an IDE connection:")
xml.append(" Present this optimization to the user using AskUserQuestion.")
xml.append(" Show the function name, performance improvement, and explanation as the question text.")
xml.append(" Provide two options: 'Apply' (write the optimized code to the file) and 'Reject' (do nothing).")
xml.append(
" Put the full diff in the 'Apply' option's markdown preview so the user can review the exact changes."
)
xml.append(" If the user chooses 'Apply', write the content from optimized-code to the corresponding file.")
xml.append(" </action>")
xml.append("</codeflash-optimization>")
sys.stdout.write("\n".join(xml) + "\n")

View file

@ -5,8 +5,18 @@ BARE_LOGGING_FORMAT = "%(message)s"
def set_level(level: int, *, echo_setting: bool = True) -> None:
import logging
import sys
import time
from codeflash.lsp.helpers import is_subagent_mode
if is_subagent_mode():
logging.basicConfig(
level=level, handlers=[logging.StreamHandler(sys.stderr)], format="%(levelname)s: %(message)s", force=True
)
logging.getLogger().setLevel(level)
return
from rich.highlighter import NullHighlighter
from rich.logging import RichHandler