diff options
Diffstat (limited to 'packages/cli/src')
| -rw-r--r-- | packages/cli/src/services/BuiltinCommandLoader.test.ts | 11 | ||||
| -rw-r--r-- | packages/cli/src/services/BuiltinCommandLoader.ts | 2 | ||||
| -rw-r--r-- | packages/cli/src/services/McpPromptLoader.ts | 231 | ||||
| -rw-r--r-- | packages/cli/src/ui/App.test.tsx | 1 | ||||
| -rw-r--r-- | packages/cli/src/ui/commands/mcpCommand.test.ts | 18 | ||||
| -rw-r--r-- | packages/cli/src/ui/commands/mcpCommand.ts | 65 | ||||
| -rw-r--r-- | packages/cli/src/ui/commands/types.ts | 1 | ||||
| -rw-r--r-- | packages/cli/src/ui/hooks/slashCommandProcessor.test.ts | 78 | ||||
| -rw-r--r-- | packages/cli/src/ui/hooks/slashCommandProcessor.ts | 155 | ||||
| -rw-r--r-- | packages/cli/src/ui/hooks/useCompletion.test.ts | 6 | ||||
| -rw-r--r-- | packages/cli/src/ui/hooks/useCompletion.ts | 11 |
11 files changed, 489 insertions, 90 deletions
diff --git a/packages/cli/src/services/BuiltinCommandLoader.test.ts b/packages/cli/src/services/BuiltinCommandLoader.test.ts index 0e64b1ac..cd449dd8 100644 --- a/packages/cli/src/services/BuiltinCommandLoader.test.ts +++ b/packages/cli/src/services/BuiltinCommandLoader.test.ts @@ -40,13 +40,19 @@ vi.mock('../ui/commands/extensionsCommand.js', () => ({ extensionsCommand: {}, })); vi.mock('../ui/commands/helpCommand.js', () => ({ helpCommand: {} })); -vi.mock('../ui/commands/mcpCommand.js', () => ({ mcpCommand: {} })); vi.mock('../ui/commands/memoryCommand.js', () => ({ memoryCommand: {} })); vi.mock('../ui/commands/privacyCommand.js', () => ({ privacyCommand: {} })); vi.mock('../ui/commands/quitCommand.js', () => ({ quitCommand: {} })); vi.mock('../ui/commands/statsCommand.js', () => ({ statsCommand: {} })); vi.mock('../ui/commands/themeCommand.js', () => ({ themeCommand: {} })); vi.mock('../ui/commands/toolsCommand.js', () => ({ toolsCommand: {} })); +vi.mock('../ui/commands/mcpCommand.js', () => ({ + mcpCommand: { + name: 'mcp', + description: 'MCP command', + kind: 'BUILT_IN', + }, +})); describe('BuiltinCommandLoader', () => { let mockConfig: Config; @@ -114,5 +120,8 @@ describe('BuiltinCommandLoader', () => { const ideCmd = commands.find((c) => c.name === 'ide'); expect(ideCmd).toBeDefined(); + + const mcpCmd = commands.find((c) => c.name === 'mcp'); + expect(mcpCmd).toBeDefined(); }); }); diff --git a/packages/cli/src/services/BuiltinCommandLoader.ts b/packages/cli/src/services/BuiltinCommandLoader.ts index 259c6013..58adf5cb 100644 --- a/packages/cli/src/services/BuiltinCommandLoader.ts +++ b/packages/cli/src/services/BuiltinCommandLoader.ts @@ -58,9 +58,9 @@ export class BuiltinCommandLoader implements ICommandLoader { extensionsCommand, helpCommand, ideCommand(this.config), - mcpCommand, memoryCommand, privacyCommand, + mcpCommand, quitCommand, restoreCommand(this.config), statsCommand, diff --git a/packages/cli/src/services/McpPromptLoader.ts b/packages/cli/src/services/McpPromptLoader.ts new file mode 100644 index 00000000..e912fb3e --- /dev/null +++ b/packages/cli/src/services/McpPromptLoader.ts @@ -0,0 +1,231 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + Config, + getErrorMessage, + getMCPServerPrompts, +} from '@google/gemini-cli-core'; +import { + CommandContext, + CommandKind, + SlashCommand, + SlashCommandActionReturn, +} from '../ui/commands/types.js'; +import { ICommandLoader } from './types.js'; +import { PromptArgument } from '@modelcontextprotocol/sdk/types.js'; + +/** + * Discovers and loads executable slash commands from prompts exposed by + * Model-Context-Protocol (MCP) servers. + */ +export class McpPromptLoader implements ICommandLoader { + constructor(private readonly config: Config | null) {} + + /** + * Loads all available prompts from all configured MCP servers and adapts + * them into executable SlashCommand objects. + * + * @param _signal An AbortSignal (unused for this synchronous loader). + * @returns A promise that resolves to an array of loaded SlashCommands. + */ + loadCommands(_signal: AbortSignal): Promise<SlashCommand[]> { + const promptCommands: SlashCommand[] = []; + if (!this.config) { + return Promise.resolve([]); + } + const mcpServers = this.config.getMcpServers() || {}; + for (const serverName in mcpServers) { + const prompts = getMCPServerPrompts(this.config, serverName) || []; + for (const prompt of prompts) { + const commandName = `${prompt.name}`; + const newPromptCommand: SlashCommand = { + name: commandName, + description: prompt.description || `Invoke prompt ${prompt.name}`, + kind: CommandKind.MCP_PROMPT, + subCommands: [ + { + name: 'help', + description: 'Show help for this prompt', + kind: CommandKind.MCP_PROMPT, + action: async (): Promise<SlashCommandActionReturn> => { + if (!prompt.arguments || prompt.arguments.length === 0) { + return { + type: 'message', + messageType: 'info', + content: `Prompt "${prompt.name}" has no arguments.`, + }; + } + + let helpMessage = `Arguments for "${prompt.name}":\n\n`; + if (prompt.arguments && prompt.arguments.length > 0) { + helpMessage += `You can provide arguments by name (e.g., --argName="value") or by position.\n\n`; + helpMessage += `e.g., ${prompt.name} ${prompt.arguments?.map((_) => `"foo"`)} is equivalent to ${prompt.name} ${prompt.arguments?.map((arg) => `--${arg.name}="foo"`)}\n\n`; + } + for (const arg of prompt.arguments) { + helpMessage += ` --${arg.name}\n`; + if (arg.description) { + helpMessage += ` ${arg.description}\n`; + } + helpMessage += ` (required: ${ + arg.required ? 'yes' : 'no' + })\n\n`; + } + return { + type: 'message', + messageType: 'info', + content: helpMessage, + }; + }, + }, + ], + action: async ( + context: CommandContext, + args: string, + ): Promise<SlashCommandActionReturn> => { + if (!this.config) { + return { + type: 'message', + messageType: 'error', + content: 'Config not loaded.', + }; + } + + const promptInputs = this.parseArgs(args, prompt.arguments); + if (promptInputs instanceof Error) { + return { + type: 'message', + messageType: 'error', + content: promptInputs.message, + }; + } + + try { + const mcpServers = this.config.getMcpServers() || {}; + const mcpServerConfig = mcpServers[serverName]; + if (!mcpServerConfig) { + return { + type: 'message', + messageType: 'error', + content: `MCP server config not found for '${serverName}'.`, + }; + } + const result = await prompt.invoke(promptInputs); + + if (result.error) { + return { + type: 'message', + messageType: 'error', + content: `Error invoking prompt: ${result.error}`, + }; + } + + if (!result.messages?.[0]?.content?.text) { + return { + type: 'message', + messageType: 'error', + content: + 'Received an empty or invalid prompt response from the server.', + }; + } + + return { + type: 'submit_prompt', + content: JSON.stringify(result.messages[0].content.text), + }; + } catch (error) { + return { + type: 'message', + messageType: 'error', + content: `Error: ${getErrorMessage(error)}`, + }; + } + }, + completion: async (_: CommandContext, partialArg: string) => { + if (!prompt || !prompt.arguments) { + return []; + } + + const suggestions: string[] = []; + const usedArgNames = new Set( + (partialArg.match(/--([^=]+)/g) || []).map((s) => s.substring(2)), + ); + + for (const arg of prompt.arguments) { + if (!usedArgNames.has(arg.name)) { + suggestions.push(`--${arg.name}=""`); + } + } + + return suggestions; + }, + }; + promptCommands.push(newPromptCommand); + } + } + return Promise.resolve(promptCommands); + } + + private parseArgs( + userArgs: string, + promptArgs: PromptArgument[] | undefined, + ): Record<string, unknown> | Error { + const argValues: { [key: string]: string } = {}; + const promptInputs: Record<string, unknown> = {}; + + // arg parsing: --key="value" or --key=value + const namedArgRegex = /--([^=]+)=(?:"((?:\\.|[^"\\])*)"|([^ ]*))/g; + let match; + const remainingArgs: string[] = []; + let lastIndex = 0; + + while ((match = namedArgRegex.exec(userArgs)) !== null) { + const key = match[1]; + const value = match[2] ?? match[3]; // Quoted or unquoted value + argValues[key] = value; + // Capture text between matches as potential positional args + if (match.index > lastIndex) { + remainingArgs.push(userArgs.substring(lastIndex, match.index).trim()); + } + lastIndex = namedArgRegex.lastIndex; + } + + // Capture any remaining text after the last named arg + if (lastIndex < userArgs.length) { + remainingArgs.push(userArgs.substring(lastIndex).trim()); + } + + const positionalArgs = remainingArgs.join(' ').split(/ +/); + + if (!promptArgs) { + return promptInputs; + } + for (const arg of promptArgs) { + if (argValues[arg.name]) { + promptInputs[arg.name] = argValues[arg.name]; + } + } + + const unfilledArgs = promptArgs.filter( + (arg) => arg.required && !promptInputs[arg.name], + ); + + const missingArgs: string[] = []; + for (let i = 0; i < unfilledArgs.length; i++) { + if (positionalArgs.length > i && positionalArgs[i]) { + promptInputs[unfilledArgs[i].name] = positionalArgs[i]; + } else { + missingArgs.push(unfilledArgs[i].name); + } + } + + if (missingArgs.length > 0) { + const missingArgNames = missingArgs.map((name) => `--${name}`).join(', '); + return new Error(`Missing required argument(s): ${missingArgNames}`); + } + return promptInputs; + } +} diff --git a/packages/cli/src/ui/App.test.tsx b/packages/cli/src/ui/App.test.tsx index 56093562..903f4b66 100644 --- a/packages/cli/src/ui/App.test.tsx +++ b/packages/cli/src/ui/App.test.tsx @@ -125,6 +125,7 @@ vi.mock('@google/gemini-cli-core', async (importOriginal) => { getToolCallCommand: vi.fn(() => opts.toolCallCommand), getMcpServerCommand: vi.fn(() => opts.mcpServerCommand), getMcpServers: vi.fn(() => opts.mcpServers), + getPromptRegistry: vi.fn(), getExtensions: vi.fn(() => []), getBlockedMcpServers: vi.fn(() => []), getUserAgent: vi.fn(() => opts.userAgent || 'test-agent'), diff --git a/packages/cli/src/ui/commands/mcpCommand.test.ts b/packages/cli/src/ui/commands/mcpCommand.test.ts index 2b8753a0..afa71ba5 100644 --- a/packages/cli/src/ui/commands/mcpCommand.test.ts +++ b/packages/cli/src/ui/commands/mcpCommand.test.ts @@ -71,6 +71,7 @@ describe('mcpCommand', () => { getToolRegistry: ReturnType<typeof vi.fn>; getMcpServers: ReturnType<typeof vi.fn>; getBlockedMcpServers: ReturnType<typeof vi.fn>; + getPromptRegistry: ReturnType<typeof vi.fn>; }; beforeEach(() => { @@ -92,6 +93,10 @@ describe('mcpCommand', () => { }), getMcpServers: vi.fn().mockReturnValue({}), getBlockedMcpServers: vi.fn().mockReturnValue([]), + getPromptRegistry: vi.fn().mockResolvedValue({ + getAllPrompts: vi.fn().mockReturnValue([]), + getPromptsByServer: vi.fn().mockReturnValue([]), + }), }; mockContext = createMockCommandContext({ @@ -223,7 +228,7 @@ describe('mcpCommand', () => { // Server 2 - Connected expect(message).toContain( - '🟢 \u001b[1mserver2\u001b[0m - Ready (1 tools)', + '🟢 \u001b[1mserver2\u001b[0m - Ready (1 tool)', ); expect(message).toContain('server2_tool1'); @@ -365,13 +370,13 @@ describe('mcpCommand', () => { if (isMessageAction(result)) { const message = result.content; expect(message).toContain( - '🟢 \u001b[1mserver1\u001b[0m - Ready (1 tools)', + '🟢 \u001b[1mserver1\u001b[0m - Ready (1 tool)', ); expect(message).toContain('\u001b[36mserver1_tool1\u001b[0m'); expect(message).toContain( '🔴 \u001b[1mserver2\u001b[0m - Disconnected (0 tools cached)', ); - expect(message).toContain('No tools available'); + expect(message).toContain('No tools or prompts available'); } }); @@ -421,10 +426,10 @@ describe('mcpCommand', () => { // Check server statuses expect(message).toContain( - '🟢 \u001b[1mserver1\u001b[0m - Ready (1 tools)', + '🟢 \u001b[1mserver1\u001b[0m - Ready (1 tool)', ); expect(message).toContain( - '🔄 \u001b[1mserver2\u001b[0m - Starting... (first startup may take longer) (tools will appear when ready)', + '🔄 \u001b[1mserver2\u001b[0m - Starting... (first startup may take longer) (tools and prompts will appear when ready)', ); } }); @@ -994,6 +999,9 @@ describe('mcpCommand', () => { getBlockedMcpServers: vi.fn().mockReturnValue([]), getToolRegistry: vi.fn().mockResolvedValue(mockToolRegistry), getGeminiClient: vi.fn().mockReturnValue(mockGeminiClient), + getPromptRegistry: vi.fn().mockResolvedValue({ + getPromptsByServer: vi.fn().mockReturnValue([]), + }), }, }, }); diff --git a/packages/cli/src/ui/commands/mcpCommand.ts b/packages/cli/src/ui/commands/mcpCommand.ts index 5467b994..709053b6 100644 --- a/packages/cli/src/ui/commands/mcpCommand.ts +++ b/packages/cli/src/ui/commands/mcpCommand.ts @@ -12,6 +12,7 @@ import { MessageActionReturn, } from './types.js'; import { + DiscoveredMCPPrompt, DiscoveredMCPTool, getMCPDiscoveryState, getMCPServerStatus, @@ -101,6 +102,8 @@ const getMcpStatus = async ( (tool) => tool instanceof DiscoveredMCPTool && tool.serverName === serverName, ) as DiscoveredMCPTool[]; + const promptRegistry = await config.getPromptRegistry(); + const serverPrompts = promptRegistry.getPromptsByServer(serverName) || []; const status = getMCPServerStatus(serverName); @@ -160,9 +163,26 @@ const getMcpStatus = async ( // Add tool count with conditional messaging if (status === MCPServerStatus.CONNECTED) { - message += ` (${serverTools.length} tools)`; + const parts = []; + if (serverTools.length > 0) { + parts.push( + `${serverTools.length} ${serverTools.length === 1 ? 'tool' : 'tools'}`, + ); + } + if (serverPrompts.length > 0) { + parts.push( + `${serverPrompts.length} ${ + serverPrompts.length === 1 ? 'prompt' : 'prompts' + }`, + ); + } + if (parts.length > 0) { + message += ` (${parts.join(', ')})`; + } else { + message += ` (0 tools)`; + } } else if (status === MCPServerStatus.CONNECTING) { - message += ` (tools will appear when ready)`; + message += ` (tools and prompts will appear when ready)`; } else { message += ` (${serverTools.length} tools cached)`; } @@ -186,6 +206,7 @@ const getMcpStatus = async ( message += RESET_COLOR; if (serverTools.length > 0) { + message += ` ${COLOR_CYAN}Tools:${RESET_COLOR}\n`; serverTools.forEach((tool) => { if (showDescriptions && tool.description) { // Format tool name in cyan using simple ANSI cyan color @@ -222,12 +243,41 @@ const getMcpStatus = async ( } } }); - } else { + } + if (serverPrompts.length > 0) { + if (serverTools.length > 0) { + message += '\n'; + } + message += ` ${COLOR_CYAN}Prompts:${RESET_COLOR}\n`; + serverPrompts.forEach((prompt: DiscoveredMCPPrompt) => { + if (showDescriptions && prompt.description) { + message += ` - ${COLOR_CYAN}${prompt.name}${RESET_COLOR}`; + const descLines = prompt.description.trim().split('\n'); + if (descLines) { + message += ':\n'; + for (const descLine of descLines) { + message += ` ${COLOR_GREEN}${descLine}${RESET_COLOR}\n`; + } + } else { + message += '\n'; + } + } else { + message += ` - ${COLOR_CYAN}${prompt.name}${RESET_COLOR}\n`; + } + }); + } + + if (serverTools.length === 0 && serverPrompts.length === 0) { + message += ' No tools or prompts available\n'; + } else if (serverTools.length === 0) { message += ' No tools available'; if (status === MCPServerStatus.DISCONNECTED && needsAuthHint) { message += ` ${COLOR_GREY}(type: "/mcp auth ${serverName}" to authenticate this server)${RESET_COLOR}`; } message += '\n'; + } else if (status === MCPServerStatus.DISCONNECTED && needsAuthHint) { + // This case is for when serverTools.length > 0 + message += ` ${COLOR_GREY}(type: "/mcp auth ${serverName}" to authenticate this server)${RESET_COLOR}\n`; } message += '\n'; } @@ -328,11 +378,10 @@ const authCommand: SlashCommand = { // Import dynamically to avoid circular dependencies const { MCPOAuthProvider } = await import('@google/gemini-cli-core'); - // Create OAuth config for authentication (will be discovered automatically) - const oauthConfig = server.oauth || { - authorizationUrl: '', // Will be discovered automatically - tokenUrl: '', // Will be discovered automatically - }; + let oauthConfig = server.oauth; + if (!oauthConfig) { + oauthConfig = { enabled: false }; + } // Pass the MCP server URL for OAuth discovery const mcpServerUrl = server.httpUrl || server.url; diff --git a/packages/cli/src/ui/commands/types.ts b/packages/cli/src/ui/commands/types.ts index 9a1088fd..1684677c 100644 --- a/packages/cli/src/ui/commands/types.ts +++ b/packages/cli/src/ui/commands/types.ts @@ -128,6 +128,7 @@ export type SlashCommandActionReturn = export enum CommandKind { BUILT_IN = 'built-in', FILE = 'file', + MCP_PROMPT = 'mcp-prompt', } // The standardized contract for any command in the system. diff --git a/packages/cli/src/ui/hooks/slashCommandProcessor.test.ts b/packages/cli/src/ui/hooks/slashCommandProcessor.test.ts index 84eeb033..d308af46 100644 --- a/packages/cli/src/ui/hooks/slashCommandProcessor.test.ts +++ b/packages/cli/src/ui/hooks/slashCommandProcessor.test.ts @@ -28,6 +28,13 @@ vi.mock('../../services/FileCommandLoader.js', () => ({ })), })); +const mockMcpLoadCommands = vi.fn(); +vi.mock('../../services/McpPromptLoader.js', () => ({ + McpPromptLoader: vi.fn().mockImplementation(() => ({ + loadCommands: mockMcpLoadCommands, + })), +})); + vi.mock('../contexts/SessionContext.js', () => ({ useSessionStats: vi.fn(() => ({ stats: {} })), })); @@ -41,6 +48,7 @@ import { LoadedSettings } from '../../config/settings.js'; import { MessageType } from '../types.js'; import { BuiltinCommandLoader } from '../../services/BuiltinCommandLoader.js'; import { FileCommandLoader } from '../../services/FileCommandLoader.js'; +import { McpPromptLoader } from '../../services/McpPromptLoader.js'; const createTestCommand = ( overrides: Partial<SlashCommand>, @@ -75,14 +83,17 @@ describe('useSlashCommandProcessor', () => { (vi.mocked(BuiltinCommandLoader) as Mock).mockClear(); mockBuiltinLoadCommands.mockResolvedValue([]); mockFileLoadCommands.mockResolvedValue([]); + mockMcpLoadCommands.mockResolvedValue([]); }); const setupProcessorHook = ( builtinCommands: SlashCommand[] = [], fileCommands: SlashCommand[] = [], + mcpCommands: SlashCommand[] = [], ) => { mockBuiltinLoadCommands.mockResolvedValue(Object.freeze(builtinCommands)); mockFileLoadCommands.mockResolvedValue(Object.freeze(fileCommands)); + mockMcpLoadCommands.mockResolvedValue(Object.freeze(mcpCommands)); const { result } = renderHook(() => useSlashCommandProcessor( @@ -111,6 +122,7 @@ describe('useSlashCommandProcessor', () => { setupProcessorHook(); expect(BuiltinCommandLoader).toHaveBeenCalledWith(mockConfig); expect(FileCommandLoader).toHaveBeenCalledWith(mockConfig); + expect(McpPromptLoader).toHaveBeenCalledWith(mockConfig); }); it('should call loadCommands and populate state after mounting', async () => { @@ -124,6 +136,7 @@ describe('useSlashCommandProcessor', () => { expect(result.current.slashCommands[0]?.name).toBe('test'); expect(mockBuiltinLoadCommands).toHaveBeenCalledTimes(1); expect(mockFileLoadCommands).toHaveBeenCalledTimes(1); + expect(mockMcpLoadCommands).toHaveBeenCalledTimes(1); }); it('should provide an immutable array of commands to consumers', async () => { @@ -369,6 +382,38 @@ describe('useSlashCommandProcessor', () => { expect.any(Number), ); }); + + it('should handle "submit_prompt" action returned from a mcp-based command', async () => { + const mcpCommand = createTestCommand( + { + name: 'mcpcmd', + description: 'A command from mcp', + action: async () => ({ + type: 'submit_prompt', + content: 'The actual prompt from the mcp command.', + }), + }, + CommandKind.MCP_PROMPT, + ); + + const result = setupProcessorHook([], [], [mcpCommand]); + await waitFor(() => expect(result.current.slashCommands).toHaveLength(1)); + + let actionResult; + await act(async () => { + actionResult = await result.current.handleSlashCommand('/mcpcmd'); + }); + + expect(actionResult).toEqual({ + type: 'submit_prompt', + content: 'The actual prompt from the mcp command.', + }); + + expect(mockAddItem).toHaveBeenCalledWith( + { type: MessageType.USER, text: '/mcpcmd' }, + expect.any(Number), + ); + }); }); describe('Command Parsing and Matching', () => { @@ -441,6 +486,39 @@ describe('useSlashCommandProcessor', () => { }); describe('Command Precedence', () => { + it('should override mcp-based commands with file-based commands of the same name', async () => { + const mcpAction = vi.fn(); + const fileAction = vi.fn(); + + const mcpCommand = createTestCommand( + { + name: 'override', + description: 'mcp', + action: mcpAction, + }, + CommandKind.MCP_PROMPT, + ); + const fileCommand = createTestCommand( + { name: 'override', description: 'file', action: fileAction }, + CommandKind.FILE, + ); + + const result = setupProcessorHook([], [fileCommand], [mcpCommand]); + + await waitFor(() => { + // The service should only return one command with the name 'override' + expect(result.current.slashCommands).toHaveLength(1); + }); + + await act(async () => { + await result.current.handleSlashCommand('/override'); + }); + + // Only the file-based command's action should be called. + expect(fileAction).toHaveBeenCalledTimes(1); + expect(mcpAction).not.toHaveBeenCalled(); + }); + it('should prioritize a command with a primary name over a command with a matching alias', async () => { const quitAction = vi.fn(); const exitAction = vi.fn(); diff --git a/packages/cli/src/ui/hooks/slashCommandProcessor.ts b/packages/cli/src/ui/hooks/slashCommandProcessor.ts index fa2b0b12..9e9dc21c 100644 --- a/packages/cli/src/ui/hooks/slashCommandProcessor.ts +++ b/packages/cli/src/ui/hooks/slashCommandProcessor.ts @@ -23,6 +23,7 @@ import { type CommandContext, type SlashCommand } from '../commands/types.js'; import { CommandService } from '../../services/CommandService.js'; import { BuiltinCommandLoader } from '../../services/BuiltinCommandLoader.js'; import { FileCommandLoader } from '../../services/FileCommandLoader.js'; +import { McpPromptLoader } from '../../services/McpPromptLoader.js'; /** * Hook to define and process slash commands (e.g., /help, /clear). @@ -164,6 +165,7 @@ export const useSlashCommandProcessor = ( const controller = new AbortController(); const load = async () => { const loaders = [ + new McpPromptLoader(config), new BuiltinCommandLoader(config), new FileCommandLoader(config), ]; @@ -246,82 +248,95 @@ export const useSlashCommandProcessor = ( args, }, }; - const result = await commandToExecute.action( - fullCommandContext, - args, - ); + try { + const result = await commandToExecute.action( + fullCommandContext, + args, + ); - if (result) { - switch (result.type) { - case 'tool': - return { - type: 'schedule_tool', - toolName: result.toolName, - toolArgs: result.toolArgs, - }; - case 'message': - addItem( - { - type: - result.messageType === 'error' - ? MessageType.ERROR - : MessageType.INFO, - text: result.content, - }, - Date.now(), - ); - return { type: 'handled' }; - case 'dialog': - switch (result.dialog) { - case 'help': - setShowHelp(true); - return { type: 'handled' }; - case 'auth': - openAuthDialog(); - return { type: 'handled' }; - case 'theme': - openThemeDialog(); - return { type: 'handled' }; - case 'editor': - openEditorDialog(); - return { type: 'handled' }; - case 'privacy': - openPrivacyNotice(); - return { type: 'handled' }; - default: { - const unhandled: never = result.dialog; - throw new Error( - `Unhandled slash command result: ${unhandled}`, - ); + if (result) { + switch (result.type) { + case 'tool': + return { + type: 'schedule_tool', + toolName: result.toolName, + toolArgs: result.toolArgs, + }; + case 'message': + addItem( + { + type: + result.messageType === 'error' + ? MessageType.ERROR + : MessageType.INFO, + text: result.content, + }, + Date.now(), + ); + return { type: 'handled' }; + case 'dialog': + switch (result.dialog) { + case 'help': + setShowHelp(true); + return { type: 'handled' }; + case 'auth': + openAuthDialog(); + return { type: 'handled' }; + case 'theme': + openThemeDialog(); + return { type: 'handled' }; + case 'editor': + openEditorDialog(); + return { type: 'handled' }; + case 'privacy': + openPrivacyNotice(); + return { type: 'handled' }; + default: { + const unhandled: never = result.dialog; + throw new Error( + `Unhandled slash command result: ${unhandled}`, + ); + } } + case 'load_history': { + await config + ?.getGeminiClient() + ?.setHistory(result.clientHistory); + fullCommandContext.ui.clear(); + result.history.forEach((item, index) => { + fullCommandContext.ui.addItem(item, index); + }); + return { type: 'handled' }; } - case 'load_history': { - await config - ?.getGeminiClient() - ?.setHistory(result.clientHistory); - fullCommandContext.ui.clear(); - result.history.forEach((item, index) => { - fullCommandContext.ui.addItem(item, index); - }); - return { type: 'handled' }; - } - case 'quit': - setQuittingMessages(result.messages); - setTimeout(() => { - process.exit(0); - }, 100); - return { type: 'handled' }; + case 'quit': + setQuittingMessages(result.messages); + setTimeout(() => { + process.exit(0); + }, 100); + return { type: 'handled' }; - case 'submit_prompt': - return { - type: 'submit_prompt', - content: result.content, - }; - default: { - const unhandled: never = result; - throw new Error(`Unhandled slash command result: ${unhandled}`); + case 'submit_prompt': + return { + type: 'submit_prompt', + content: result.content, + }; + default: { + const unhandled: never = result; + throw new Error( + `Unhandled slash command result: ${unhandled}`, + ); + } } } + } catch (e) { + addItem( + { + type: MessageType.ERROR, + text: e instanceof Error ? e.message : String(e), + }, + Date.now(), + ); + return { type: 'handled' }; } return { type: 'handled' }; diff --git a/packages/cli/src/ui/hooks/useCompletion.test.ts b/packages/cli/src/ui/hooks/useCompletion.test.ts index cd525435..da6a7ab3 100644 --- a/packages/cli/src/ui/hooks/useCompletion.test.ts +++ b/packages/cli/src/ui/hooks/useCompletion.test.ts @@ -1100,7 +1100,7 @@ describe('useCompletion', () => { result.current.handleAutocomplete(0); }); - expect(mockBuffer.setText).toHaveBeenCalledWith('/memory'); + expect(mockBuffer.setText).toHaveBeenCalledWith('/memory '); }); it('should append a sub-command when the parent is complete', () => { @@ -1145,7 +1145,7 @@ describe('useCompletion', () => { result.current.handleAutocomplete(1); // index 1 is 'add' }); - expect(mockBuffer.setText).toHaveBeenCalledWith('/memory add'); + expect(mockBuffer.setText).toHaveBeenCalledWith('/memory add '); }); it('should complete a command with an alternative name', () => { @@ -1190,7 +1190,7 @@ describe('useCompletion', () => { result.current.handleAutocomplete(0); }); - expect(mockBuffer.setText).toHaveBeenCalledWith('/help'); + expect(mockBuffer.setText).toHaveBeenCalledWith('/help '); }); it('should complete a file path', async () => { diff --git a/packages/cli/src/ui/hooks/useCompletion.ts b/packages/cli/src/ui/hooks/useCompletion.ts index dc45222d..10724c21 100644 --- a/packages/cli/src/ui/hooks/useCompletion.ts +++ b/packages/cli/src/ui/hooks/useCompletion.ts @@ -638,10 +638,17 @@ export function useCompletion( // Determine the base path of the command. // - If there's a trailing space, the whole command is the base. // - If it's a known parent path, the whole command is the base. + // - If the last part is a complete argument, the whole command is the base. // - Otherwise, the base is everything EXCEPT the last partial part. + const lastPart = parts.length > 0 ? parts[parts.length - 1] : ''; + const isLastPartACompleteArg = + lastPart.startsWith('--') && lastPart.includes('='); + const basePath = - hasTrailingSpace || isParentPath ? parts : parts.slice(0, -1); - const newValue = `/${[...basePath, suggestion].join(' ')}`; + hasTrailingSpace || isParentPath || isLastPartACompleteArg + ? parts + : parts.slice(0, -1); + const newValue = `/${[...basePath, suggestion].join(' ')} `; buffer.setText(newValue); } else { |
