diff options
| author | Abhi <[email protected]> | 2025-07-27 02:00:26 -0400 |
|---|---|---|
| committer | GitHub <[email protected]> | 2025-07-27 06:00:26 +0000 |
| commit | 576cebc9282cfbe57d45321105d72cc61597ce9b (patch) | |
| tree | 374dd97245761fe5c40ee87a9b4d5674a26344cf /packages/cli/src/ui/hooks | |
| parent | 9e61b3510c0cd7f333f40f68e87d981aff19aab1 (diff) | |
feat: Add Shell Command Execution to Custom Commands (#4917)
Diffstat (limited to 'packages/cli/src/ui/hooks')
| -rw-r--r-- | packages/cli/src/ui/hooks/slashCommandProcessor.test.ts | 212 | ||||
| -rw-r--r-- | packages/cli/src/ui/hooks/slashCommandProcessor.ts | 235 |
2 files changed, 367 insertions, 80 deletions
diff --git a/packages/cli/src/ui/hooks/slashCommandProcessor.test.ts b/packages/cli/src/ui/hooks/slashCommandProcessor.test.ts index ac9b79ec..5b367cd4 100644 --- a/packages/cli/src/ui/hooks/slashCommandProcessor.test.ts +++ b/packages/cli/src/ui/hooks/slashCommandProcessor.test.ts @@ -42,8 +42,13 @@ vi.mock('../contexts/SessionContext.js', () => ({ import { act, renderHook, waitFor } from '@testing-library/react'; import { vi, describe, it, expect, beforeEach, type Mock } from 'vitest'; import { useSlashCommandProcessor } from './slashCommandProcessor.js'; -import { CommandKind, SlashCommand } from '../commands/types.js'; -import { Config } from '@google/gemini-cli-core'; +import { + CommandContext, + CommandKind, + ConfirmShellCommandsActionReturn, + SlashCommand, +} from '../commands/types.js'; +import { Config, ToolConfirmationOutcome } from '@google/gemini-cli-core'; import { LoadedSettings } from '../../config/settings.js'; import { MessageType } from '../types.js'; import { BuiltinCommandLoader } from '../../services/BuiltinCommandLoader.js'; @@ -90,6 +95,7 @@ describe('useSlashCommandProcessor', () => { builtinCommands: SlashCommand[] = [], fileCommands: SlashCommand[] = [], mcpCommands: SlashCommand[] = [], + setIsProcessing = vi.fn(), ) => { mockBuiltinLoadCommands.mockResolvedValue(Object.freeze(builtinCommands)); mockFileLoadCommands.mockResolvedValue(Object.freeze(fileCommands)); @@ -112,6 +118,7 @@ describe('useSlashCommandProcessor', () => { mockSetQuittingMessages, vi.fn(), // openPrivacyNotice vi.fn(), // toggleVimEnabled + setIsProcessing, ), ); @@ -275,6 +282,32 @@ describe('useSlashCommandProcessor', () => { 'with args', ); }); + + it('should set isProcessing to true during execution and false afterwards', async () => { + const mockSetIsProcessing = vi.fn(); + const command = createTestCommand({ + name: 'long-running', + action: () => new Promise((resolve) => setTimeout(resolve, 50)), + }); + + const result = setupProcessorHook([command], [], [], mockSetIsProcessing); + await waitFor(() => expect(result.current.slashCommands).toHaveLength(1)); + + const executionPromise = act(async () => { + await result.current.handleSlashCommand('/long-running'); + }); + + // It should be true immediately after starting + expect(mockSetIsProcessing).toHaveBeenCalledWith(true); + // It should not have been called with false yet + expect(mockSetIsProcessing).not.toHaveBeenCalledWith(false); + + await executionPromise; + + // After the promise resolves, it should be called with false + expect(mockSetIsProcessing).toHaveBeenCalledWith(false); + expect(mockSetIsProcessing).toHaveBeenCalledTimes(2); + }); }); describe('Action Result Handling', () => { @@ -417,6 +450,176 @@ describe('useSlashCommandProcessor', () => { }); }); + describe('Shell Command Confirmation Flow', () => { + // Use a generic vi.fn() for the action. We will change its behavior in each test. + const mockCommandAction = vi.fn(); + + const shellCommand = createTestCommand({ + name: 'shellcmd', + action: mockCommandAction, + }); + + beforeEach(() => { + // Reset the mock before each test + mockCommandAction.mockClear(); + + // Default behavior: request confirmation + mockCommandAction.mockResolvedValue({ + type: 'confirm_shell_commands', + commandsToConfirm: ['rm -rf /'], + originalInvocation: { raw: '/shellcmd' }, + } as ConfirmShellCommandsActionReturn); + }); + + it('should set confirmation request when action returns confirm_shell_commands', async () => { + const result = setupProcessorHook([shellCommand]); + await waitFor(() => expect(result.current.slashCommands).toHaveLength(1)); + + // This is intentionally not awaited, because the promise it returns + // will not resolve until the user responds to the confirmation. + act(() => { + result.current.handleSlashCommand('/shellcmd'); + }); + + // We now wait for the state to be updated with the request. + await waitFor(() => { + expect(result.current.shellConfirmationRequest).not.toBeNull(); + }); + + expect(result.current.shellConfirmationRequest?.commands).toEqual([ + 'rm -rf /', + ]); + }); + + it('should do nothing if user cancels confirmation', async () => { + const result = setupProcessorHook([shellCommand]); + await waitFor(() => expect(result.current.slashCommands).toHaveLength(1)); + + act(() => { + result.current.handleSlashCommand('/shellcmd'); + }); + + // Wait for the confirmation dialog to be set + await waitFor(() => { + expect(result.current.shellConfirmationRequest).not.toBeNull(); + }); + + const onConfirm = result.current.shellConfirmationRequest?.onConfirm; + expect(onConfirm).toBeDefined(); + + // Change the mock action's behavior for a potential second run. + // If the test is flawed, this will be called, and we can detect it. + mockCommandAction.mockResolvedValue({ + type: 'message', + messageType: 'info', + content: 'This should not be called', + }); + + await act(async () => { + onConfirm!(ToolConfirmationOutcome.Cancel, []); // Pass empty array for safety + }); + + expect(result.current.shellConfirmationRequest).toBeNull(); + // Verify the action was only called the initial time. + expect(mockCommandAction).toHaveBeenCalledTimes(1); + }); + + it('should re-run command with one-time allowlist on "Proceed Once"', async () => { + const result = setupProcessorHook([shellCommand]); + await waitFor(() => expect(result.current.slashCommands).toHaveLength(1)); + + act(() => { + result.current.handleSlashCommand('/shellcmd'); + }); + await waitFor(() => { + expect(result.current.shellConfirmationRequest).not.toBeNull(); + }); + + const onConfirm = result.current.shellConfirmationRequest?.onConfirm; + + // **Change the mock's behavior for the SECOND run.** + // This is the key to testing the outcome. + mockCommandAction.mockResolvedValue({ + type: 'message', + messageType: 'info', + content: 'Success!', + }); + + await act(async () => { + onConfirm!(ToolConfirmationOutcome.ProceedOnce, ['rm -rf /']); + }); + + expect(result.current.shellConfirmationRequest).toBeNull(); + + // The action should have been called twice (initial + re-run). + await waitFor(() => { + expect(mockCommandAction).toHaveBeenCalledTimes(2); + }); + + // We can inspect the context of the second call to ensure the one-time list was used. + const secondCallContext = mockCommandAction.mock + .calls[1][0] as CommandContext; + expect( + secondCallContext.session.sessionShellAllowlist.has('rm -rf /'), + ).toBe(true); + + // Verify the final success message was added. + expect(mockAddItem).toHaveBeenCalledWith( + { type: MessageType.INFO, text: 'Success!' }, + expect.any(Number), + ); + + // Verify the session-wide allowlist was NOT permanently updated. + // Re-render the hook by calling a no-op command to get the latest context. + await act(async () => { + result.current.handleSlashCommand('/no-op'); + }); + const finalContext = result.current.commandContext; + expect(finalContext.session.sessionShellAllowlist.size).toBe(0); + }); + + it('should re-run command and update session allowlist on "Proceed Always"', async () => { + const result = setupProcessorHook([shellCommand]); + await waitFor(() => expect(result.current.slashCommands).toHaveLength(1)); + + act(() => { + result.current.handleSlashCommand('/shellcmd'); + }); + await waitFor(() => { + expect(result.current.shellConfirmationRequest).not.toBeNull(); + }); + + const onConfirm = result.current.shellConfirmationRequest?.onConfirm; + mockCommandAction.mockResolvedValue({ + type: 'message', + messageType: 'info', + content: 'Success!', + }); + + await act(async () => { + onConfirm!(ToolConfirmationOutcome.ProceedAlways, ['rm -rf /']); + }); + + expect(result.current.shellConfirmationRequest).toBeNull(); + await waitFor(() => { + expect(mockCommandAction).toHaveBeenCalledTimes(2); + }); + + expect(mockAddItem).toHaveBeenCalledWith( + { type: MessageType.INFO, text: 'Success!' }, + expect.any(Number), + ); + + // Check that the session-wide allowlist WAS updated. + await waitFor(() => { + const finalContext = result.current.commandContext; + expect(finalContext.session.sessionShellAllowlist.has('rm -rf /')).toBe( + true, + ); + }); + }); + }); + describe('Command Parsing and Matching', () => { it('should be case-sensitive', async () => { const command = createTestCommand({ name: 'test' }); @@ -583,7 +786,7 @@ describe('useSlashCommandProcessor', () => { }); describe('Lifecycle', () => { - it('should abort command loading when the hook unmounts', async () => { + it('should abort command loading when the hook unmounts', () => { const abortSpy = vi.spyOn(AbortController.prototype, 'abort'); const { unmount } = renderHook(() => useSlashCommandProcessor( @@ -597,10 +800,11 @@ describe('useSlashCommandProcessor', () => { vi.fn(), // onDebugMessage vi.fn(), // openThemeDialog mockOpenAuthDialog, - vi.fn(), // openEditorDialog + vi.fn(), // openEditorDialog, vi.fn(), // toggleCorgiMode mockSetQuittingMessages, vi.fn(), // openPrivacyNotice + vi.fn(), // toggleVimEnabled ), ); diff --git a/packages/cli/src/ui/hooks/slashCommandProcessor.ts b/packages/cli/src/ui/hooks/slashCommandProcessor.ts index 46b49329..be32de11 100644 --- a/packages/cli/src/ui/hooks/slashCommandProcessor.ts +++ b/packages/cli/src/ui/hooks/slashCommandProcessor.ts @@ -9,7 +9,12 @@ import { type PartListUnion } from '@google/genai'; import process from 'node:process'; import { UseHistoryManagerReturn } from './useHistoryManager.js'; import { useStateAndRef } from './useStateAndRef.js'; -import { Config, GitService, Logger } from '@google/gemini-cli-core'; +import { + Config, + GitService, + Logger, + ToolConfirmationOutcome, +} from '@google/gemini-cli-core'; import { useSessionStats } from '../contexts/SessionContext.js'; import { Message, @@ -44,9 +49,21 @@ export const useSlashCommandProcessor = ( setQuittingMessages: (message: HistoryItem[]) => void, openPrivacyNotice: () => void, toggleVimEnabled: () => Promise<boolean>, + setIsProcessing: (isProcessing: boolean) => void, ) => { const session = useSessionStats(); const [commands, setCommands] = useState<readonly SlashCommand[]>([]); + const [shellConfirmationRequest, setShellConfirmationRequest] = + useState<null | { + commands: string[]; + onConfirm: ( + outcome: ToolConfirmationOutcome, + approvedCommands?: string[], + ) => void; + }>(null); + const [sessionShellAllowlist, setSessionShellAllowlist] = useState( + new Set<string>(), + ); const gitService = useMemo(() => { if (!config?.getProjectRoot()) { return; @@ -144,6 +161,7 @@ export const useSlashCommandProcessor = ( }, session: { stats: session.stats, + sessionShellAllowlist, }, }), [ @@ -161,6 +179,7 @@ export const useSlashCommandProcessor = ( setPendingCompressionItem, toggleCorgiMode, toggleVimEnabled, + sessionShellAllowlist, ], ); @@ -189,69 +208,87 @@ export const useSlashCommandProcessor = ( const handleSlashCommand = useCallback( async ( rawQuery: PartListUnion, + oneTimeShellAllowlist?: Set<string>, ): Promise<SlashCommandProcessorResult | false> => { - if (typeof rawQuery !== 'string') { - return false; - } + setIsProcessing(true); + try { + if (typeof rawQuery !== 'string') { + return false; + } - const trimmed = rawQuery.trim(); - if (!trimmed.startsWith('/') && !trimmed.startsWith('?')) { - return false; - } + const trimmed = rawQuery.trim(); + if (!trimmed.startsWith('/') && !trimmed.startsWith('?')) { + return false; + } - const userMessageTimestamp = Date.now(); - addItem({ type: MessageType.USER, text: trimmed }, userMessageTimestamp); + const userMessageTimestamp = Date.now(); + addItem( + { type: MessageType.USER, text: trimmed }, + userMessageTimestamp, + ); - const parts = trimmed.substring(1).trim().split(/\s+/); - const commandPath = parts.filter((p) => p); // The parts of the command, e.g., ['memory', 'add'] + const parts = trimmed.substring(1).trim().split(/\s+/); + const commandPath = parts.filter((p) => p); // The parts of the command, e.g., ['memory', 'add'] - let currentCommands = commands; - let commandToExecute: SlashCommand | undefined; - let pathIndex = 0; + let currentCommands = commands; + let commandToExecute: SlashCommand | undefined; + let pathIndex = 0; - for (const part of commandPath) { - // TODO: For better performance and architectural clarity, this two-pass - // search could be replaced. A more optimal approach would be to - // pre-compute a single lookup map in `CommandService.ts` that resolves - // all name and alias conflicts during the initial loading phase. The - // processor would then perform a single, fast lookup on that map. + for (const part of commandPath) { + // TODO: For better performance and architectural clarity, this two-pass + // search could be replaced. A more optimal approach would be to + // pre-compute a single lookup map in `CommandService.ts` that resolves + // all name and alias conflicts during the initial loading phase. The + // processor would then perform a single, fast lookup on that map. - // First pass: check for an exact match on the primary command name. - let foundCommand = currentCommands.find((cmd) => cmd.name === part); + // First pass: check for an exact match on the primary command name. + let foundCommand = currentCommands.find((cmd) => cmd.name === part); - // Second pass: if no primary name matches, check for an alias. - if (!foundCommand) { - foundCommand = currentCommands.find((cmd) => - cmd.altNames?.includes(part), - ); - } + // Second pass: if no primary name matches, check for an alias. + if (!foundCommand) { + foundCommand = currentCommands.find((cmd) => + cmd.altNames?.includes(part), + ); + } - if (foundCommand) { - commandToExecute = foundCommand; - pathIndex++; - if (foundCommand.subCommands) { - currentCommands = foundCommand.subCommands; + if (foundCommand) { + commandToExecute = foundCommand; + pathIndex++; + if (foundCommand.subCommands) { + currentCommands = foundCommand.subCommands; + } else { + break; + } } else { break; } - } else { - break; } - } - if (commandToExecute) { - const args = parts.slice(pathIndex).join(' '); + if (commandToExecute) { + const args = parts.slice(pathIndex).join(' '); + + if (commandToExecute.action) { + const fullCommandContext: CommandContext = { + ...commandContext, + invocation: { + raw: trimmed, + name: commandToExecute.name, + args, + }, + }; + + // If a one-time list is provided for a "Proceed" action, temporarily + // augment the session allowlist for this single execution. + if (oneTimeShellAllowlist && oneTimeShellAllowlist.size > 0) { + fullCommandContext.session = { + ...fullCommandContext.session, + sessionShellAllowlist: new Set([ + ...fullCommandContext.session.sessionShellAllowlist, + ...oneTimeShellAllowlist, + ]), + }; + } - if (commandToExecute.action) { - const fullCommandContext: CommandContext = { - ...commandContext, - invocation: { - raw: trimmed, - name: commandToExecute.name, - args, - }, - }; - try { const result = await commandToExecute.action( fullCommandContext, args, @@ -323,6 +360,46 @@ export const useSlashCommandProcessor = ( type: 'submit_prompt', content: result.content, }; + case 'confirm_shell_commands': { + const { outcome, approvedCommands } = await new Promise<{ + outcome: ToolConfirmationOutcome; + approvedCommands?: string[]; + }>((resolve) => { + setShellConfirmationRequest({ + commands: result.commandsToConfirm, + onConfirm: ( + resolvedOutcome, + resolvedApprovedCommands, + ) => { + setShellConfirmationRequest(null); // Close the dialog + resolve({ + outcome: resolvedOutcome, + approvedCommands: resolvedApprovedCommands, + }); + }, + }); + }); + + if ( + outcome === ToolConfirmationOutcome.Cancel || + !approvedCommands || + approvedCommands.length === 0 + ) { + return { type: 'handled' }; + } + + if (outcome === ToolConfirmationOutcome.ProceedAlways) { + setSessionShellAllowlist( + (prev) => new Set([...prev, ...approvedCommands]), + ); + } + + return await handleSlashCommand( + result.originalInvocation.raw, + // Pass the approved commands as a one-time grant for this execution. + new Set(approvedCommands), + ); + } default: { const unhandled: never = result; throw new Error( @@ -331,37 +408,39 @@ export const useSlashCommandProcessor = ( } } } - } catch (e) { - addItem( - { - type: MessageType.ERROR, - text: e instanceof Error ? e.message : String(e), - }, - Date.now(), - ); + + return { type: 'handled' }; + } else if (commandToExecute.subCommands) { + const helpText = `Command '/${commandToExecute.name}' requires a subcommand. Available:\n${commandToExecute.subCommands + .map((sc) => ` - ${sc.name}: ${sc.description || ''}`) + .join('\n')}`; + addMessage({ + type: MessageType.INFO, + content: helpText, + timestamp: new Date(), + }); return { type: 'handled' }; } - - return { type: 'handled' }; - } else if (commandToExecute.subCommands) { - const helpText = `Command '/${commandToExecute.name}' requires a subcommand. Available:\n${commandToExecute.subCommands - .map((sc) => ` - ${sc.name}: ${sc.description || ''}`) - .join('\n')}`; - addMessage({ - type: MessageType.INFO, - content: helpText, - timestamp: new Date(), - }); - return { type: 'handled' }; } - } - addMessage({ - type: MessageType.ERROR, - content: `Unknown command: ${trimmed}`, - timestamp: new Date(), - }); - return { type: 'handled' }; + addMessage({ + type: MessageType.ERROR, + content: `Unknown command: ${trimmed}`, + timestamp: new Date(), + }); + return { type: 'handled' }; + } catch (e) { + addItem( + { + type: MessageType.ERROR, + text: e instanceof Error ? e.message : String(e), + }, + Date.now(), + ); + return { type: 'handled' }; + } finally { + setIsProcessing(false); + } }, [ config, @@ -375,6 +454,9 @@ export const useSlashCommandProcessor = ( openPrivacyNotice, openEditorDialog, setQuittingMessages, + setShellConfirmationRequest, + setSessionShellAllowlist, + setIsProcessing, ], ); @@ -383,5 +465,6 @@ export const useSlashCommandProcessor = ( slashCommands: commands, pendingHistoryItems, commandContext, + shellConfirmationRequest, }; }; |
