mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
feat: sync CLI from main with Java additions
Adds subagent mode, config display, and main's CLI refactoring. Preserves omni-java's NullHighlighter, Java test root detection, and Java project root detection (pom.xml/build.gradle).
This commit is contained in:
parent
d518ad2d91
commit
2299d26ae5
3 changed files with 234 additions and 61 deletions
|
|
@ -130,9 +130,18 @@ def parse_args() -> Namespace:
|
|||
"--reset-config", action="store_true", help="Remove codeflash configuration from project config file."
|
||||
)
|
||||
parser.add_argument("-y", "--yes", action="store_true", help="Skip confirmation prompts (useful for CI/scripts).")
|
||||
parser.add_argument(
|
||||
"--subagent",
|
||||
action="store_true",
|
||||
help="Subagent mode: skip all interactive prompts with sensible defaults. Designed for AI agent integrations.",
|
||||
)
|
||||
|
||||
args, unknown_args = parser.parse_known_args()
|
||||
sys.argv[:] = [sys.argv[0], *unknown_args]
|
||||
if args.subagent:
|
||||
args.yes = True
|
||||
args.no_pr = True
|
||||
args.worktree = True
|
||||
return process_and_validate_cmd_args(args)
|
||||
|
||||
|
||||
|
|
@ -237,7 +246,18 @@ def process_pyproject_config(args: Namespace) -> Namespace:
|
|||
set_current_test_framework(pyproject_config["test_framework"])
|
||||
|
||||
if args.tests_root is None:
|
||||
if is_js_ts_project:
|
||||
if is_java_project:
|
||||
# Try standard Maven/Gradle test directories
|
||||
for test_dir in ["src/test/java", "test", "tests"]:
|
||||
test_path = Path(args.module_root).parent / test_dir if "/" in test_dir else Path(test_dir)
|
||||
if not test_path.is_absolute():
|
||||
test_path = Path.cwd() / test_path
|
||||
if test_path.is_dir():
|
||||
args.tests_root = str(test_path)
|
||||
break
|
||||
if args.tests_root is None:
|
||||
args.tests_root = str(Path.cwd() / "src" / "test" / "java")
|
||||
elif is_js_ts_project:
|
||||
# Try common JS test directories at project root first
|
||||
for test_dir in ["test", "tests", "__tests__"]:
|
||||
if Path(test_dir).is_dir():
|
||||
|
|
@ -256,17 +276,6 @@ def process_pyproject_config(args: Namespace) -> Namespace:
|
|||
# In such cases, the user should explicitly configure testsRoot in package.json
|
||||
if args.tests_root is None:
|
||||
args.tests_root = args.module_root
|
||||
elif is_java_project:
|
||||
# Try standard Maven/Gradle test directories
|
||||
for test_dir in ["src/test/java", "test", "tests"]:
|
||||
test_path = Path(args.module_root).parent / test_dir if "/" in test_dir else Path(test_dir)
|
||||
if not test_path.is_absolute():
|
||||
test_path = Path.cwd() / test_path
|
||||
if test_path.is_dir():
|
||||
args.tests_root = str(test_path)
|
||||
break
|
||||
if args.tests_root is None:
|
||||
args.tests_root = str(Path.cwd() / "src" / "test" / "java")
|
||||
else:
|
||||
raise AssertionError("--tests-root must be specified")
|
||||
assert Path(args.tests_root).is_dir(), f"--tests-root {args.tests_root} must be a valid directory"
|
||||
|
|
@ -327,7 +336,6 @@ def project_root_from_module_root(module_root: Path, pyproject_file_path: Path)
|
|||
return current.resolve()
|
||||
if (current / "build.gradle").exists() or (current / "build.gradle.kts").exists():
|
||||
return current.resolve()
|
||||
# Check for config file (pyproject.toml for Python, codeflash.toml for other languages)
|
||||
if (current / "codeflash.toml").exists():
|
||||
return current.resolve()
|
||||
current = current.parent
|
||||
|
|
@ -378,32 +386,52 @@ def _handle_show_config() -> None:
|
|||
from codeflash.setup.detector import detect_project, has_existing_config
|
||||
|
||||
project_root = Path.cwd()
|
||||
detected = detect_project(project_root)
|
||||
config_exists, _ = has_existing_config(project_root)
|
||||
|
||||
# Check if config exists or is auto-detected
|
||||
config_exists, config_file = has_existing_config(project_root)
|
||||
status = "Saved config" if config_exists else "Auto-detected (not saved)"
|
||||
if config_exists:
|
||||
from codeflash.code_utils.config_parser import parse_config_file
|
||||
|
||||
console.print()
|
||||
console.print(f"[bold]Codeflash Configuration[/bold] ({status})")
|
||||
if config_exists and config_file:
|
||||
console.print(f"[dim]Config file: {project_root / config_file}[/dim]")
|
||||
console.print()
|
||||
config, config_file_path = parse_config_file()
|
||||
status = "Saved config"
|
||||
|
||||
table = Table(show_header=True, header_style="bold cyan")
|
||||
table.add_column("Setting", style="dim")
|
||||
table.add_column("Value")
|
||||
console.print()
|
||||
console.print(f"[bold]Codeflash Configuration[/bold] ({status})")
|
||||
console.print(f"[dim]Config file: {config_file_path}[/dim]")
|
||||
console.print()
|
||||
|
||||
table.add_row("Language", detected.language)
|
||||
table.add_row("Project root", str(detected.project_root))
|
||||
table.add_row("Module root", str(detected.module_root))
|
||||
table.add_row("Tests root", str(detected.tests_root) if detected.tests_root else "(not detected)")
|
||||
table.add_row("Test runner", detected.test_runner or "(not detected)")
|
||||
table.add_row("Formatter", ", ".join(detected.formatter_cmds) if detected.formatter_cmds else "(not detected)")
|
||||
table.add_row(
|
||||
"Ignore paths", ", ".join(str(p) for p in detected.ignore_paths) if detected.ignore_paths else "(none)"
|
||||
)
|
||||
table.add_row("Confidence", f"{detected.confidence:.0%}")
|
||||
table = Table(show_header=True, header_style="bold cyan")
|
||||
table.add_column("Setting", style="dim")
|
||||
table.add_column("Value")
|
||||
|
||||
table.add_row("Project root", str(project_root))
|
||||
table.add_row("Module root", config.get("module_root", "(not set)"))
|
||||
table.add_row("Tests root", config.get("tests_root", "(not set)"))
|
||||
table.add_row("Test runner", config.get("test_framework", config.get("pytest_cmd", "(not set)")))
|
||||
table.add_row("Formatter", ", ".join(config["formatter_cmds"]) if config.get("formatter_cmds") else "(not set)")
|
||||
ignore_paths = config.get("ignore_paths", [])
|
||||
table.add_row("Ignore paths", ", ".join(str(p) for p in ignore_paths) if ignore_paths else "(none)")
|
||||
else:
|
||||
detected = detect_project(project_root)
|
||||
status = "Auto-detected (not saved)"
|
||||
|
||||
console.print()
|
||||
console.print(f"[bold]Codeflash Configuration[/bold] ({status})")
|
||||
console.print()
|
||||
|
||||
table = Table(show_header=True, header_style="bold cyan")
|
||||
table.add_column("Setting", style="dim")
|
||||
table.add_column("Value")
|
||||
|
||||
table.add_row("Language", detected.language)
|
||||
table.add_row("Project root", str(detected.project_root))
|
||||
table.add_row("Module root", str(detected.module_root))
|
||||
table.add_row("Tests root", str(detected.tests_root) if detected.tests_root else "(not detected)")
|
||||
table.add_row("Test runner", detected.test_runner or "(not detected)")
|
||||
table.add_row("Formatter", ", ".join(detected.formatter_cmds) if detected.formatter_cmds else "(not detected)")
|
||||
table.add_row(
|
||||
"Ignore paths", ", ".join(str(p) for p in detected.ignore_paths) if detected.ignore_paths else "(none)"
|
||||
)
|
||||
table.add_row("Confidence", f"{detected.confidence:.0%}")
|
||||
|
||||
console.print(table)
|
||||
console.print()
|
||||
|
|
@ -436,7 +464,7 @@ def _handle_reset_config(confirm: bool = True) -> None:
|
|||
console.print("[bold]This will remove Codeflash configuration from your project.[/bold]")
|
||||
console.print()
|
||||
|
||||
config_file = {"python": "pyproject.toml", "java": "codeflash.toml"}.get(detected.language, "package.json")
|
||||
config_file = "pyproject.toml" if detected.language == "python" else "package.json"
|
||||
console.print(f" Config file: {project_root / config_file}")
|
||||
console.print()
|
||||
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ from rich.progress import (
|
|||
|
||||
from codeflash.cli_cmds.console_constants import SPINNER_TYPES
|
||||
from codeflash.cli_cmds.logging_config import BARE_LOGGING_FORMAT
|
||||
from codeflash.lsp.helpers import is_LSP_enabled
|
||||
from codeflash.lsp.helpers import is_LSP_enabled, is_subagent_mode
|
||||
from codeflash.lsp.lsp_logger import enhanced_log
|
||||
from codeflash.lsp.lsp_message import LspCodeMessage, LspTextMessage
|
||||
|
||||
|
|
@ -35,42 +35,69 @@ if TYPE_CHECKING:
|
|||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.languages.base import DependencyResolver, IndexResult
|
||||
from codeflash.lsp.lsp_message import LspMessage
|
||||
from codeflash.models.models import TestResults
|
||||
|
||||
DEBUG_MODE = logging.getLogger().getEffectiveLevel() == logging.DEBUG
|
||||
|
||||
console = Console(highlighter=NullHighlighter())
|
||||
|
||||
if is_LSP_enabled():
|
||||
if is_LSP_enabled() or is_subagent_mode():
|
||||
console.quiet = True
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
handlers=[
|
||||
RichHandler(
|
||||
rich_tracebacks=True,
|
||||
markup=False,
|
||||
highlighter=NullHighlighter(),
|
||||
console=console,
|
||||
show_path=False,
|
||||
show_time=False,
|
||||
)
|
||||
],
|
||||
format=BARE_LOGGING_FORMAT,
|
||||
)
|
||||
if is_subagent_mode():
|
||||
import re
|
||||
import sys
|
||||
|
||||
_lsp_prefix_re = re.compile(r"^(?:!?lsp,?|h[2-4]|loading)\|")
|
||||
_subagent_drop_patterns = (
|
||||
"Test log -",
|
||||
"Test failed to load",
|
||||
"Examining file ",
|
||||
"Generated ",
|
||||
"Add custom marker",
|
||||
"Disabling all autouse",
|
||||
"Reverting code and helpers",
|
||||
)
|
||||
|
||||
class _AgentLogFilter(logging.Filter):
|
||||
def filter(self, record: logging.LogRecord) -> bool:
|
||||
record.msg = _lsp_prefix_re.sub("", str(record.msg))
|
||||
msg = record.getMessage()
|
||||
return not any(msg.startswith(p) for p in _subagent_drop_patterns)
|
||||
|
||||
_agent_handler = logging.StreamHandler(sys.stderr)
|
||||
_agent_handler.addFilter(_AgentLogFilter())
|
||||
logging.basicConfig(level=logging.INFO, handlers=[_agent_handler], format="%(levelname)s: %(message)s")
|
||||
else:
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
handlers=[
|
||||
RichHandler(
|
||||
rich_tracebacks=True,
|
||||
markup=False,
|
||||
highlighter=NullHighlighter(),
|
||||
console=console,
|
||||
show_path=False,
|
||||
show_time=False,
|
||||
)
|
||||
],
|
||||
format=BARE_LOGGING_FORMAT,
|
||||
)
|
||||
|
||||
logger = logging.getLogger("rich")
|
||||
logging.getLogger("parso").setLevel(logging.WARNING)
|
||||
|
||||
# override the logger to reformat the messages for the lsp
|
||||
for level in ("info", "debug", "warning", "error"):
|
||||
real_fn = getattr(logger, level)
|
||||
setattr(
|
||||
logger,
|
||||
level,
|
||||
lambda msg, *args, _real_fn=real_fn, _level=level, **kwargs: enhanced_log(
|
||||
msg, _real_fn, _level, *args, **kwargs
|
||||
),
|
||||
)
|
||||
if not is_subagent_mode():
|
||||
for level in ("info", "debug", "warning", "error"):
|
||||
real_fn = getattr(logger, level)
|
||||
setattr(
|
||||
logger,
|
||||
level,
|
||||
lambda msg, *args, _real_fn=real_fn, _level=level, **kwargs: enhanced_log(
|
||||
msg, _real_fn, _level, *args, **kwargs
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class DummyTask:
|
||||
|
|
@ -97,6 +124,8 @@ def paneled_text(
|
|||
text: str, panel_args: dict[str, str | bool] | None = None, text_args: dict[str, str] | None = None
|
||||
) -> None:
|
||||
"""Print text in a panel."""
|
||||
if is_subagent_mode():
|
||||
return
|
||||
from rich.panel import Panel
|
||||
from rich.text import Text
|
||||
|
||||
|
|
@ -125,6 +154,8 @@ def code_print(
|
|||
language: Programming language for syntax highlighting ('python', 'javascript', 'typescript')
|
||||
|
||||
"""
|
||||
if is_subagent_mode():
|
||||
return
|
||||
if is_LSP_enabled():
|
||||
lsp_log(
|
||||
LspCodeMessage(code=code_str, file_name=file_name, function_name=function_name, message_id=lsp_message_id)
|
||||
|
|
@ -162,6 +193,10 @@ def progress_bar(
|
|||
"""
|
||||
global _progress_bar_active
|
||||
|
||||
if is_subagent_mode():
|
||||
yield DummyTask().id
|
||||
return
|
||||
|
||||
if is_LSP_enabled():
|
||||
lsp_log(LspTextMessage(text=message, takes_time=True))
|
||||
yield
|
||||
|
|
@ -193,6 +228,10 @@ def progress_bar(
|
|||
@contextmanager
|
||||
def test_files_progress_bar(total: int, description: str) -> Generator[tuple[Progress, TaskID], None, None]:
|
||||
"""Progress bar for test files."""
|
||||
if is_subagent_mode():
|
||||
yield DummyProgress(), DummyTask().id
|
||||
return
|
||||
|
||||
if is_LSP_enabled():
|
||||
lsp_log(LspTextMessage(text=description, takes_time=True))
|
||||
dummy_progress = DummyProgress()
|
||||
|
|
@ -226,6 +265,10 @@ def call_graph_live_display(
|
|||
from rich.text import Text
|
||||
from rich.tree import Tree
|
||||
|
||||
if is_subagent_mode():
|
||||
yield lambda _: None
|
||||
return
|
||||
|
||||
if is_LSP_enabled():
|
||||
lsp_log(LspTextMessage(text="Building call graph", takes_time=True))
|
||||
yield lambda _: None
|
||||
|
|
@ -333,6 +376,9 @@ def call_graph_summary(call_graph: DependencyResolver, file_to_funcs: dict[Path,
|
|||
if not total_functions:
|
||||
return
|
||||
|
||||
if is_subagent_mode():
|
||||
return
|
||||
|
||||
# Build the mapping expected by the dependency resolver
|
||||
file_items = file_to_funcs.items()
|
||||
mapping = {file_path: {func.qualified_name for func in funcs} for file_path, funcs in file_items}
|
||||
|
|
@ -359,3 +405,92 @@ def call_graph_summary(call_graph: DependencyResolver, file_to_funcs: dict[Path,
|
|||
return
|
||||
|
||||
console.print(Panel(summary, title="Call Graph Summary", border_style="cyan"))
|
||||
|
||||
|
||||
def subagent_log_optimization_result(
|
||||
function_name: str,
|
||||
file_path: Path,
|
||||
perf_improvement_line: str,
|
||||
original_runtime_ns: int,
|
||||
best_runtime_ns: int,
|
||||
raw_explanation: str,
|
||||
original_code: dict[Path, str],
|
||||
new_code: dict[Path, str],
|
||||
review: str,
|
||||
test_results: TestResults,
|
||||
) -> None:
|
||||
import sys
|
||||
from xml.sax.saxutils import escape
|
||||
|
||||
from codeflash.code_utils.code_utils import unified_diff_strings
|
||||
from codeflash.code_utils.time_utils import humanize_runtime
|
||||
from codeflash.models.test_type import TestType
|
||||
|
||||
diff_parts = []
|
||||
for path in original_code:
|
||||
old = original_code.get(path, "")
|
||||
new = new_code.get(path, "")
|
||||
if old != new:
|
||||
diff = unified_diff_strings(old, new, fromfile=str(path), tofile=str(path))
|
||||
if diff:
|
||||
diff_parts.append(diff)
|
||||
|
||||
diff_str = "\n".join(diff_parts)
|
||||
|
||||
original_runtime = humanize_runtime(original_runtime_ns)
|
||||
optimized_runtime = humanize_runtime(best_runtime_ns)
|
||||
|
||||
report = test_results.get_test_pass_fail_report_by_type()
|
||||
verification_rows = []
|
||||
for test_type in TestType:
|
||||
if test_type is TestType.INIT_STATE_TEST:
|
||||
continue
|
||||
name = test_type.to_name()
|
||||
if not name:
|
||||
continue
|
||||
passed = report[test_type]["passed"]
|
||||
failed = report[test_type]["failed"]
|
||||
if passed == 0 and failed == 0:
|
||||
status = "None Found"
|
||||
elif failed > 0:
|
||||
status = f"{failed} Failed, {passed} Passed"
|
||||
else:
|
||||
status = f"{passed} Passed"
|
||||
verification_rows.append(f' <test type="{escape(name)}" status="{escape(status)}"/>')
|
||||
|
||||
xml = [
|
||||
"<codeflash-optimization>",
|
||||
f" <function>{escape(function_name)}</function>",
|
||||
f" <file>{escape(str(file_path))}</file>",
|
||||
f" <performance>{escape(perf_improvement_line)}</performance>",
|
||||
f" <original-runtime>{escape(original_runtime)}</original-runtime>",
|
||||
f" <optimized-runtime>{escape(optimized_runtime)}</optimized-runtime>",
|
||||
]
|
||||
if review:
|
||||
xml.append(f" <reviewer-confidence>{escape(review)}</reviewer-confidence>")
|
||||
xml.append(f" <explanation>{escape(raw_explanation.strip())}</explanation>")
|
||||
xml.append(" <verification>")
|
||||
xml.extend(verification_rows)
|
||||
xml.append(" </verification>")
|
||||
if diff_str:
|
||||
xml.append(f" <diff>{escape(diff_str)}</diff>")
|
||||
for path in new_code:
|
||||
if new_code[path] != original_code.get(path, ""):
|
||||
xml.append(f' <optimized-code file="{escape(str(path))}">{escape(new_code[path])}</optimized-code>')
|
||||
xml.append(" <action>")
|
||||
xml.append(" First, present a short summary: function name, performance improvement, and explanation.")
|
||||
xml.append(" If you are running inside an IDE (VS Code, Cursor, Windsurf, etc.):")
|
||||
xml.append(" For each optimized-code element, use the Write tool to write the optimized code to the file.")
|
||||
xml.append(" This will surface the diff in the IDE's native diff view for the user to accept or reject.")
|
||||
xml.append(" If you are running in a terminal without an IDE connection:")
|
||||
xml.append(" Present this optimization to the user using AskUserQuestion.")
|
||||
xml.append(" Show the function name, performance improvement, and explanation as the question text.")
|
||||
xml.append(" Provide two options: 'Apply' (write the optimized code to the file) and 'Reject' (do nothing).")
|
||||
xml.append(
|
||||
" Put the full diff in the 'Apply' option's markdown preview so the user can review the exact changes."
|
||||
)
|
||||
xml.append(" If the user chooses 'Apply', write the content from optimized-code to the corresponding file.")
|
||||
xml.append(" </action>")
|
||||
xml.append("</codeflash-optimization>")
|
||||
|
||||
sys.stdout.write("\n".join(xml) + "\n")
|
||||
|
|
|
|||
|
|
@ -5,8 +5,18 @@ BARE_LOGGING_FORMAT = "%(message)s"
|
|||
|
||||
def set_level(level: int, *, echo_setting: bool = True) -> None:
|
||||
import logging
|
||||
import sys
|
||||
import time
|
||||
|
||||
from codeflash.lsp.helpers import is_subagent_mode
|
||||
|
||||
if is_subagent_mode():
|
||||
logging.basicConfig(
|
||||
level=level, handlers=[logging.StreamHandler(sys.stderr)], format="%(levelname)s: %(message)s", force=True
|
||||
)
|
||||
logging.getLogger().setLevel(level)
|
||||
return
|
||||
|
||||
from rich.highlighter import NullHighlighter
|
||||
from rich.logging import RichHandler
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue