diff options
Diffstat (limited to 'packages/cli/src/ui/hooks')
| -rw-r--r-- | packages/cli/src/ui/hooks/slashCommandProcessor.test.ts | 78 | ||||
| -rw-r--r-- | packages/cli/src/ui/hooks/slashCommandProcessor.ts | 155 | ||||
| -rw-r--r-- | packages/cli/src/ui/hooks/useCompletion.test.ts | 6 | ||||
| -rw-r--r-- | packages/cli/src/ui/hooks/useCompletion.ts | 11 |
4 files changed, 175 insertions, 75 deletions
diff --git a/packages/cli/src/ui/hooks/slashCommandProcessor.test.ts b/packages/cli/src/ui/hooks/slashCommandProcessor.test.ts index 84eeb033..d308af46 100644 --- a/packages/cli/src/ui/hooks/slashCommandProcessor.test.ts +++ b/packages/cli/src/ui/hooks/slashCommandProcessor.test.ts @@ -28,6 +28,13 @@ vi.mock('../../services/FileCommandLoader.js', () => ({ })), })); +const mockMcpLoadCommands = vi.fn(); +vi.mock('../../services/McpPromptLoader.js', () => ({ + McpPromptLoader: vi.fn().mockImplementation(() => ({ + loadCommands: mockMcpLoadCommands, + })), +})); + vi.mock('../contexts/SessionContext.js', () => ({ useSessionStats: vi.fn(() => ({ stats: {} })), })); @@ -41,6 +48,7 @@ import { LoadedSettings } from '../../config/settings.js'; import { MessageType } from '../types.js'; import { BuiltinCommandLoader } from '../../services/BuiltinCommandLoader.js'; import { FileCommandLoader } from '../../services/FileCommandLoader.js'; +import { McpPromptLoader } from '../../services/McpPromptLoader.js'; const createTestCommand = ( overrides: Partial<SlashCommand>, @@ -75,14 +83,17 @@ describe('useSlashCommandProcessor', () => { (vi.mocked(BuiltinCommandLoader) as Mock).mockClear(); mockBuiltinLoadCommands.mockResolvedValue([]); mockFileLoadCommands.mockResolvedValue([]); + mockMcpLoadCommands.mockResolvedValue([]); }); const setupProcessorHook = ( builtinCommands: SlashCommand[] = [], fileCommands: SlashCommand[] = [], + mcpCommands: SlashCommand[] = [], ) => { mockBuiltinLoadCommands.mockResolvedValue(Object.freeze(builtinCommands)); mockFileLoadCommands.mockResolvedValue(Object.freeze(fileCommands)); + mockMcpLoadCommands.mockResolvedValue(Object.freeze(mcpCommands)); const { result } = renderHook(() => useSlashCommandProcessor( @@ -111,6 +122,7 @@ describe('useSlashCommandProcessor', () => { setupProcessorHook(); expect(BuiltinCommandLoader).toHaveBeenCalledWith(mockConfig); expect(FileCommandLoader).toHaveBeenCalledWith(mockConfig); + expect(McpPromptLoader).toHaveBeenCalledWith(mockConfig); }); it('should call loadCommands and populate state after mounting', async () => { @@ -124,6 +136,7 @@ describe('useSlashCommandProcessor', () => { expect(result.current.slashCommands[0]?.name).toBe('test'); expect(mockBuiltinLoadCommands).toHaveBeenCalledTimes(1); expect(mockFileLoadCommands).toHaveBeenCalledTimes(1); + expect(mockMcpLoadCommands).toHaveBeenCalledTimes(1); }); it('should provide an immutable array of commands to consumers', async () => { @@ -369,6 +382,38 @@ describe('useSlashCommandProcessor', () => { expect.any(Number), ); }); + + it('should handle "submit_prompt" action returned from a mcp-based command', async () => { + const mcpCommand = createTestCommand( + { + name: 'mcpcmd', + description: 'A command from mcp', + action: async () => ({ + type: 'submit_prompt', + content: 'The actual prompt from the mcp command.', + }), + }, + CommandKind.MCP_PROMPT, + ); + + const result = setupProcessorHook([], [], [mcpCommand]); + await waitFor(() => expect(result.current.slashCommands).toHaveLength(1)); + + let actionResult; + await act(async () => { + actionResult = await result.current.handleSlashCommand('/mcpcmd'); + }); + + expect(actionResult).toEqual({ + type: 'submit_prompt', + content: 'The actual prompt from the mcp command.', + }); + + expect(mockAddItem).toHaveBeenCalledWith( + { type: MessageType.USER, text: '/mcpcmd' }, + expect.any(Number), + ); + }); }); describe('Command Parsing and Matching', () => { @@ -441,6 +486,39 @@ describe('useSlashCommandProcessor', () => { }); describe('Command Precedence', () => { + it('should override mcp-based commands with file-based commands of the same name', async () => { + const mcpAction = vi.fn(); + const fileAction = vi.fn(); + + const mcpCommand = createTestCommand( + { + name: 'override', + description: 'mcp', + action: mcpAction, + }, + CommandKind.MCP_PROMPT, + ); + const fileCommand = createTestCommand( + { name: 'override', description: 'file', action: fileAction }, + CommandKind.FILE, + ); + + const result = setupProcessorHook([], [fileCommand], [mcpCommand]); + + await waitFor(() => { + // The service should only return one command with the name 'override' + expect(result.current.slashCommands).toHaveLength(1); + }); + + await act(async () => { + await result.current.handleSlashCommand('/override'); + }); + + // Only the file-based command's action should be called. + expect(fileAction).toHaveBeenCalledTimes(1); + expect(mcpAction).not.toHaveBeenCalled(); + }); + it('should prioritize a command with a primary name over a command with a matching alias', async () => { const quitAction = vi.fn(); const exitAction = vi.fn(); diff --git a/packages/cli/src/ui/hooks/slashCommandProcessor.ts b/packages/cli/src/ui/hooks/slashCommandProcessor.ts index fa2b0b12..9e9dc21c 100644 --- a/packages/cli/src/ui/hooks/slashCommandProcessor.ts +++ b/packages/cli/src/ui/hooks/slashCommandProcessor.ts @@ -23,6 +23,7 @@ import { type CommandContext, type SlashCommand } from '../commands/types.js'; import { CommandService } from '../../services/CommandService.js'; import { BuiltinCommandLoader } from '../../services/BuiltinCommandLoader.js'; import { FileCommandLoader } from '../../services/FileCommandLoader.js'; +import { McpPromptLoader } from '../../services/McpPromptLoader.js'; /** * Hook to define and process slash commands (e.g., /help, /clear). @@ -164,6 +165,7 @@ export const useSlashCommandProcessor = ( const controller = new AbortController(); const load = async () => { const loaders = [ + new McpPromptLoader(config), new BuiltinCommandLoader(config), new FileCommandLoader(config), ]; @@ -246,82 +248,95 @@ export const useSlashCommandProcessor = ( args, }, }; - const result = await commandToExecute.action( - fullCommandContext, - args, - ); + try { + const result = await commandToExecute.action( + fullCommandContext, + args, + ); - if (result) { - switch (result.type) { - case 'tool': - return { - type: 'schedule_tool', - toolName: result.toolName, - toolArgs: result.toolArgs, - }; - case 'message': - addItem( - { - type: - result.messageType === 'error' - ? MessageType.ERROR - : MessageType.INFO, - text: result.content, - }, - Date.now(), - ); - return { type: 'handled' }; - case 'dialog': - switch (result.dialog) { - case 'help': - setShowHelp(true); - return { type: 'handled' }; - case 'auth': - openAuthDialog(); - return { type: 'handled' }; - case 'theme': - openThemeDialog(); - return { type: 'handled' }; - case 'editor': - openEditorDialog(); - return { type: 'handled' }; - case 'privacy': - openPrivacyNotice(); - return { type: 'handled' }; - default: { - const unhandled: never = result.dialog; - throw new Error( - `Unhandled slash command result: ${unhandled}`, - ); + if (result) { + switch (result.type) { + case 'tool': + return { + type: 'schedule_tool', + toolName: result.toolName, + toolArgs: result.toolArgs, + }; + case 'message': + addItem( + { + type: + result.messageType === 'error' + ? MessageType.ERROR + : MessageType.INFO, + text: result.content, + }, + Date.now(), + ); + return { type: 'handled' }; + case 'dialog': + switch (result.dialog) { + case 'help': + setShowHelp(true); + return { type: 'handled' }; + case 'auth': + openAuthDialog(); + return { type: 'handled' }; + case 'theme': + openThemeDialog(); + return { type: 'handled' }; + case 'editor': + openEditorDialog(); + return { type: 'handled' }; + case 'privacy': + openPrivacyNotice(); + return { type: 'handled' }; + default: { + const unhandled: never = result.dialog; + throw new Error( + `Unhandled slash command result: ${unhandled}`, + ); + } } + case 'load_history': { + await config + ?.getGeminiClient() + ?.setHistory(result.clientHistory); + fullCommandContext.ui.clear(); + result.history.forEach((item, index) => { + fullCommandContext.ui.addItem(item, index); + }); + return { type: 'handled' }; } - case 'load_history': { - await config - ?.getGeminiClient() - ?.setHistory(result.clientHistory); - fullCommandContext.ui.clear(); - result.history.forEach((item, index) => { - fullCommandContext.ui.addItem(item, index); - }); - return { type: 'handled' }; - } - case 'quit': - setQuittingMessages(result.messages); - setTimeout(() => { - process.exit(0); - }, 100); - return { type: 'handled' }; + case 'quit': + setQuittingMessages(result.messages); + setTimeout(() => { + process.exit(0); + }, 100); + return { type: 'handled' }; - case 'submit_prompt': - return { - type: 'submit_prompt', - content: result.content, - }; - default: { - const unhandled: never = result; - throw new Error(`Unhandled slash command result: ${unhandled}`); + case 'submit_prompt': + return { + type: 'submit_prompt', + content: result.content, + }; + default: { + const unhandled: never = result; + throw new Error( + `Unhandled slash command result: ${unhandled}`, + ); + } } } + } catch (e) { + addItem( + { + type: MessageType.ERROR, + text: e instanceof Error ? e.message : String(e), + }, + Date.now(), + ); + return { type: 'handled' }; } return { type: 'handled' }; diff --git a/packages/cli/src/ui/hooks/useCompletion.test.ts b/packages/cli/src/ui/hooks/useCompletion.test.ts index cd525435..da6a7ab3 100644 --- a/packages/cli/src/ui/hooks/useCompletion.test.ts +++ b/packages/cli/src/ui/hooks/useCompletion.test.ts @@ -1100,7 +1100,7 @@ describe('useCompletion', () => { result.current.handleAutocomplete(0); }); - expect(mockBuffer.setText).toHaveBeenCalledWith('/memory'); + expect(mockBuffer.setText).toHaveBeenCalledWith('/memory '); }); it('should append a sub-command when the parent is complete', () => { @@ -1145,7 +1145,7 @@ describe('useCompletion', () => { result.current.handleAutocomplete(1); // index 1 is 'add' }); - expect(mockBuffer.setText).toHaveBeenCalledWith('/memory add'); + expect(mockBuffer.setText).toHaveBeenCalledWith('/memory add '); }); it('should complete a command with an alternative name', () => { @@ -1190,7 +1190,7 @@ describe('useCompletion', () => { result.current.handleAutocomplete(0); }); - expect(mockBuffer.setText).toHaveBeenCalledWith('/help'); + expect(mockBuffer.setText).toHaveBeenCalledWith('/help '); }); it('should complete a file path', async () => { diff --git a/packages/cli/src/ui/hooks/useCompletion.ts b/packages/cli/src/ui/hooks/useCompletion.ts index dc45222d..10724c21 100644 --- a/packages/cli/src/ui/hooks/useCompletion.ts +++ b/packages/cli/src/ui/hooks/useCompletion.ts @@ -638,10 +638,17 @@ export function useCompletion( // Determine the base path of the command. // - If there's a trailing space, the whole command is the base. // - If it's a known parent path, the whole command is the base. + // - If the last part is a complete argument, the whole command is the base. // - Otherwise, the base is everything EXCEPT the last partial part. + const lastPart = parts.length > 0 ? parts[parts.length - 1] : ''; + const isLastPartACompleteArg = + lastPart.startsWith('--') && lastPart.includes('='); + const basePath = - hasTrailingSpace || isParentPath ? parts : parts.slice(0, -1); - const newValue = `/${[...basePath, suggestion].join(' ')}`; + hasTrailingSpace || isParentPath || isLastPartACompleteArg + ? parts + : parts.slice(0, -1); + const newValue = `/${[...basePath, suggestion].join(' ')} `; buffer.setText(newValue); } else { |
