diff options
| author | Abhi <[email protected]> | 2025-07-27 02:00:26 -0400 |
|---|---|---|
| committer | GitHub <[email protected]> | 2025-07-27 06:00:26 +0000 |
| commit | 576cebc9282cfbe57d45321105d72cc61597ce9b (patch) | |
| tree | 374dd97245761fe5c40ee87a9b4d5674a26344cf /packages/cli/src | |
| parent | 9e61b3510c0cd7f333f40f68e87d981aff19aab1 (diff) | |
feat: Add Shell Command Execution to Custom Commands (#4917)
Diffstat (limited to 'packages/cli/src')
12 files changed, 1238 insertions, 94 deletions
diff --git a/packages/cli/src/services/FileCommandLoader.test.ts b/packages/cli/src/services/FileCommandLoader.test.ts index 0e3d781b..e3cbceb2 100644 --- a/packages/cli/src/services/FileCommandLoader.test.ts +++ b/packages/cli/src/services/FileCommandLoader.test.ts @@ -11,12 +11,68 @@ import { getUserCommandsDir, } from '@google/gemini-cli-core'; import mock from 'mock-fs'; -import { assert } from 'vitest'; +import { assert, vi } from 'vitest'; import { createMockCommandContext } from '../test-utils/mockCommandContext.js'; +import { + SHELL_INJECTION_TRIGGER, + SHORTHAND_ARGS_PLACEHOLDER, +} from './prompt-processors/types.js'; +import { + ConfirmationRequiredError, + ShellProcessor, +} from './prompt-processors/shellProcessor.js'; +import { ShorthandArgumentProcessor } from './prompt-processors/argumentProcessor.js'; + +const mockShellProcess = vi.hoisted(() => vi.fn()); +vi.mock('./prompt-processors/shellProcessor.js', () => ({ + ShellProcessor: vi.fn().mockImplementation(() => ({ + process: mockShellProcess, + })), + ConfirmationRequiredError: class extends Error { + constructor( + message: string, + public commandsToConfirm: string[], + ) { + super(message); + this.name = 'ConfirmationRequiredError'; + } + }, +})); + +vi.mock('./prompt-processors/argumentProcessor.js', async (importOriginal) => { + const original = + await importOriginal< + typeof import('./prompt-processors/argumentProcessor.js') + >(); + return { + ShorthandArgumentProcessor: vi + .fn() + .mockImplementation(() => new original.ShorthandArgumentProcessor()), + DefaultArgumentProcessor: vi + .fn() + .mockImplementation(() => new original.DefaultArgumentProcessor()), + }; +}); +vi.mock('@google/gemini-cli-core', async (importOriginal) => { + const original = + await importOriginal<typeof import('@google/gemini-cli-core')>(); + return { + ...original, + isCommandAllowed: vi.fn(), + ShellExecutionService: { + execute: vi.fn(), + }, + }; +}); describe('FileCommandLoader', () => { const signal: AbortSignal = new AbortController().signal; + beforeEach(() => { + vi.clearAllMocks(); + mockShellProcess.mockImplementation((prompt) => Promise.resolve(prompt)); + }); + afterEach(() => { mock.restore(); }); @@ -371,4 +427,180 @@ describe('FileCommandLoader', () => { } }); }); + + describe('Shell Processor Integration', () => { + it('instantiates ShellProcessor if the trigger is present', async () => { + const userCommandsDir = getUserCommandsDir(); + mock({ + [userCommandsDir]: { + 'shell.toml': `prompt = "Run this: ${SHELL_INJECTION_TRIGGER}echo hello}"`, + }, + }); + + const loader = new FileCommandLoader(null as unknown as Config); + await loader.loadCommands(signal); + + expect(ShellProcessor).toHaveBeenCalledWith('shell'); + }); + + it('does not instantiate ShellProcessor if trigger is missing', async () => { + const userCommandsDir = getUserCommandsDir(); + mock({ + [userCommandsDir]: { + 'regular.toml': `prompt = "Just a regular prompt"`, + }, + }); + + const loader = new FileCommandLoader(null as unknown as Config); + await loader.loadCommands(signal); + + expect(ShellProcessor).not.toHaveBeenCalled(); + }); + + it('returns a "submit_prompt" action if shell processing succeeds', async () => { + const userCommandsDir = getUserCommandsDir(); + mock({ + [userCommandsDir]: { + 'shell.toml': `prompt = "Run !{echo 'hello'}"`, + }, + }); + mockShellProcess.mockResolvedValue('Run hello'); + + const loader = new FileCommandLoader(null as unknown as Config); + const commands = await loader.loadCommands(signal); + const command = commands.find((c) => c.name === 'shell'); + expect(command).toBeDefined(); + + const result = await command!.action!( + createMockCommandContext({ + invocation: { raw: '/shell', name: 'shell', args: '' }, + }), + '', + ); + + expect(result?.type).toBe('submit_prompt'); + if (result?.type === 'submit_prompt') { + expect(result.content).toBe('Run hello'); + } + }); + + it('returns a "confirm_shell_commands" action if shell processing requires it', async () => { + const userCommandsDir = getUserCommandsDir(); + const rawInvocation = '/shell rm -rf /'; + mock({ + [userCommandsDir]: { + 'shell.toml': `prompt = "Run !{rm -rf /}"`, + }, + }); + + // Mock the processor to throw the specific error + const error = new ConfirmationRequiredError('Confirmation needed', [ + 'rm -rf /', + ]); + mockShellProcess.mockRejectedValue(error); + + const loader = new FileCommandLoader(null as unknown as Config); + const commands = await loader.loadCommands(signal); + const command = commands.find((c) => c.name === 'shell'); + expect(command).toBeDefined(); + + const result = await command!.action!( + createMockCommandContext({ + invocation: { raw: rawInvocation, name: 'shell', args: 'rm -rf /' }, + }), + 'rm -rf /', + ); + + expect(result?.type).toBe('confirm_shell_commands'); + if (result?.type === 'confirm_shell_commands') { + expect(result.commandsToConfirm).toEqual(['rm -rf /']); + expect(result.originalInvocation.raw).toBe(rawInvocation); + } + }); + + it('re-throws other errors from the processor', async () => { + const userCommandsDir = getUserCommandsDir(); + mock({ + [userCommandsDir]: { + 'shell.toml': `prompt = "Run !{something}"`, + }, + }); + + const genericError = new Error('Something else went wrong'); + mockShellProcess.mockRejectedValue(genericError); + + const loader = new FileCommandLoader(null as unknown as Config); + const commands = await loader.loadCommands(signal); + const command = commands.find((c) => c.name === 'shell'); + expect(command).toBeDefined(); + + await expect( + command!.action!( + createMockCommandContext({ + invocation: { raw: '/shell', name: 'shell', args: '' }, + }), + '', + ), + ).rejects.toThrow('Something else went wrong'); + }); + + it('assembles the processor pipeline in the correct order (Shell -> Argument)', async () => { + const userCommandsDir = getUserCommandsDir(); + mock({ + [userCommandsDir]: { + 'pipeline.toml': ` + prompt = "Shell says: ${SHELL_INJECTION_TRIGGER}echo foo} and user says: ${SHORTHAND_ARGS_PLACEHOLDER}" + `, + }, + }); + + // Mock the process methods to track call order + const argProcessMock = vi + .fn() + .mockImplementation((p) => `${p}-arg-processed`); + + // Redefine the mock for this specific test + mockShellProcess.mockImplementation((p) => + Promise.resolve(`${p}-shell-processed`), + ); + + vi.mocked(ShorthandArgumentProcessor).mockImplementation( + () => + ({ + process: argProcessMock, + }) as unknown as ShorthandArgumentProcessor, + ); + + const loader = new FileCommandLoader(null as unknown as Config); + const commands = await loader.loadCommands(signal); + const command = commands.find((c) => c.name === 'pipeline'); + expect(command).toBeDefined(); + + await command!.action!( + createMockCommandContext({ + invocation: { + raw: '/pipeline bar', + name: 'pipeline', + args: 'bar', + }, + }), + 'bar', + ); + + // Verify that the shell processor was called before the argument processor + expect(mockShellProcess.mock.invocationCallOrder[0]).toBeLessThan( + argProcessMock.mock.invocationCallOrder[0], + ); + + // Also verify the flow of the prompt through the processors + expect(mockShellProcess).toHaveBeenCalledWith( + expect.any(String), + expect.any(Object), + ); + expect(argProcessMock).toHaveBeenCalledWith( + expect.stringContaining('-shell-processed'), // It receives the output of the shell processor + expect.any(Object), + ); + }); + }); }); diff --git a/packages/cli/src/services/FileCommandLoader.ts b/packages/cli/src/services/FileCommandLoader.ts index 23d5af19..c96acead 100644 --- a/packages/cli/src/services/FileCommandLoader.ts +++ b/packages/cli/src/services/FileCommandLoader.ts @@ -19,7 +19,7 @@ import { CommandContext, CommandKind, SlashCommand, - SubmitPromptActionReturn, + SlashCommandActionReturn, } from '../ui/commands/types.js'; import { DefaultArgumentProcessor, @@ -28,7 +28,12 @@ import { import { IPromptProcessor, SHORTHAND_ARGS_PLACEHOLDER, + SHELL_INJECTION_TRIGGER, } from './prompt-processors/types.js'; +import { + ConfirmationRequiredError, + ShellProcessor, +} from './prompt-processors/shellProcessor.js'; /** * Defines the Zod schema for a command definition file. This serves as the @@ -172,6 +177,11 @@ export class FileCommandLoader implements ICommandLoader { const processors: IPromptProcessor[] = []; + // Add the Shell Processor if needed. + if (validDef.prompt.includes(SHELL_INJECTION_TRIGGER)) { + processors.push(new ShellProcessor(commandName)); + } + // The presence of '{{args}}' is the switch that determines the behavior. if (validDef.prompt.includes(SHORTHAND_ARGS_PLACEHOLDER)) { processors.push(new ShorthandArgumentProcessor()); @@ -188,7 +198,7 @@ export class FileCommandLoader implements ICommandLoader { action: async ( context: CommandContext, _args: string, - ): Promise<SubmitPromptActionReturn> => { + ): Promise<SlashCommandActionReturn> => { if (!context.invocation) { console.error( `[FileCommandLoader] Critical error: Command '${commandName}' was executed without invocation context.`, @@ -199,15 +209,31 @@ export class FileCommandLoader implements ICommandLoader { }; } - let processedPrompt = validDef.prompt; - for (const processor of processors) { - processedPrompt = await processor.process(processedPrompt, context); - } + try { + let processedPrompt = validDef.prompt; + for (const processor of processors) { + processedPrompt = await processor.process(processedPrompt, context); + } - return { - type: 'submit_prompt', - content: processedPrompt, - }; + return { + type: 'submit_prompt', + content: processedPrompt, + }; + } catch (e) { + // Check if it's our specific error type + if (e instanceof ConfirmationRequiredError) { + // Halt and request confirmation from the UI layer. + return { + type: 'confirm_shell_commands', + commandsToConfirm: e.commandsToConfirm, + originalInvocation: { + raw: context.invocation.raw, + }, + }; + } + // Re-throw other errors to be handled by the global error handler. + throw e; + } }, }; } diff --git a/packages/cli/src/services/prompt-processors/shellProcessor.test.ts b/packages/cli/src/services/prompt-processors/shellProcessor.test.ts new file mode 100644 index 00000000..a2883923 --- /dev/null +++ b/packages/cli/src/services/prompt-processors/shellProcessor.test.ts @@ -0,0 +1,300 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { vi, describe, it, expect, beforeEach } from 'vitest'; +import { ConfirmationRequiredError, ShellProcessor } from './shellProcessor.js'; +import { createMockCommandContext } from '../../test-utils/mockCommandContext.js'; +import { CommandContext } from '../../ui/commands/types.js'; +import { Config } from '@google/gemini-cli-core'; + +const mockCheckCommandPermissions = vi.hoisted(() => vi.fn()); +const mockShellExecute = vi.hoisted(() => vi.fn()); + +vi.mock('@google/gemini-cli-core', async (importOriginal) => { + const original = await importOriginal<object>(); + return { + ...original, + checkCommandPermissions: mockCheckCommandPermissions, + ShellExecutionService: { + execute: mockShellExecute, + }, + }; +}); + +describe('ShellProcessor', () => { + let context: CommandContext; + let mockConfig: Partial<Config>; + + beforeEach(() => { + vi.clearAllMocks(); + + mockConfig = { + getTargetDir: vi.fn().mockReturnValue('/test/dir'), + }; + + context = createMockCommandContext({ + services: { + config: mockConfig as Config, + }, + session: { + sessionShellAllowlist: new Set(), + }, + }); + + mockShellExecute.mockReturnValue({ + result: Promise.resolve({ + output: 'default shell output', + }), + }); + mockCheckCommandPermissions.mockReturnValue({ + allAllowed: true, + disallowedCommands: [], + }); + }); + + it('should not change the prompt if no shell injections are present', async () => { + const processor = new ShellProcessor('test-command'); + const prompt = 'This is a simple prompt with no injections.'; + const result = await processor.process(prompt, context); + expect(result).toBe(prompt); + expect(mockShellExecute).not.toHaveBeenCalled(); + }); + + it('should process a single valid shell injection if allowed', async () => { + const processor = new ShellProcessor('test-command'); + const prompt = 'The current status is: !{git status}'; + mockCheckCommandPermissions.mockReturnValue({ + allAllowed: true, + disallowedCommands: [], + }); + mockShellExecute.mockReturnValue({ + result: Promise.resolve({ output: 'On branch main' }), + }); + + const result = await processor.process(prompt, context); + + expect(mockCheckCommandPermissions).toHaveBeenCalledWith( + 'git status', + expect.any(Object), + context.session.sessionShellAllowlist, + ); + expect(mockShellExecute).toHaveBeenCalledWith( + 'git status', + expect.any(String), + expect.any(Function), + expect.any(Object), + ); + expect(result).toBe('The current status is: On branch main'); + }); + + it('should process multiple valid shell injections if all are allowed', async () => { + const processor = new ShellProcessor('test-command'); + const prompt = '!{git status} in !{pwd}'; + mockCheckCommandPermissions.mockReturnValue({ + allAllowed: true, + disallowedCommands: [], + }); + + mockShellExecute + .mockReturnValueOnce({ + result: Promise.resolve({ output: 'On branch main' }), + }) + .mockReturnValueOnce({ + result: Promise.resolve({ output: '/usr/home' }), + }); + + const result = await processor.process(prompt, context); + + expect(mockCheckCommandPermissions).toHaveBeenCalledTimes(2); + expect(mockShellExecute).toHaveBeenCalledTimes(2); + expect(result).toBe('On branch main in /usr/home'); + }); + + it('should throw ConfirmationRequiredError if a command is not allowed', async () => { + const processor = new ShellProcessor('test-command'); + const prompt = 'Do something dangerous: !{rm -rf /}'; + mockCheckCommandPermissions.mockReturnValue({ + allAllowed: false, + disallowedCommands: ['rm -rf /'], + }); + + await expect(processor.process(prompt, context)).rejects.toThrow( + ConfirmationRequiredError, + ); + }); + + it('should throw ConfirmationRequiredError with the correct command', async () => { + const processor = new ShellProcessor('test-command'); + const prompt = 'Do something dangerous: !{rm -rf /}'; + mockCheckCommandPermissions.mockReturnValue({ + allAllowed: false, + disallowedCommands: ['rm -rf /'], + }); + + try { + await processor.process(prompt, context); + // Fail if it doesn't throw + expect(true).toBe(false); + } catch (e) { + expect(e).toBeInstanceOf(ConfirmationRequiredError); + if (e instanceof ConfirmationRequiredError) { + expect(e.commandsToConfirm).toEqual(['rm -rf /']); + } + } + + expect(mockShellExecute).not.toHaveBeenCalled(); + }); + + it('should throw ConfirmationRequiredError with multiple commands if multiple are disallowed', async () => { + const processor = new ShellProcessor('test-command'); + const prompt = '!{cmd1} and !{cmd2}'; + mockCheckCommandPermissions.mockImplementation((cmd) => { + if (cmd === 'cmd1') { + return { allAllowed: false, disallowedCommands: ['cmd1'] }; + } + if (cmd === 'cmd2') { + return { allAllowed: false, disallowedCommands: ['cmd2'] }; + } + return { allAllowed: true, disallowedCommands: [] }; + }); + + try { + await processor.process(prompt, context); + // Fail if it doesn't throw + expect(true).toBe(false); + } catch (e) { + expect(e).toBeInstanceOf(ConfirmationRequiredError); + if (e instanceof ConfirmationRequiredError) { + expect(e.commandsToConfirm).toEqual(['cmd1', 'cmd2']); + } + } + }); + + it('should not execute any commands if at least one requires confirmation', async () => { + const processor = new ShellProcessor('test-command'); + const prompt = 'First: !{echo "hello"}, Second: !{rm -rf /}'; + + mockCheckCommandPermissions.mockImplementation((cmd) => { + if (cmd.includes('rm')) { + return { allAllowed: false, disallowedCommands: [cmd] }; + } + return { allAllowed: true, disallowedCommands: [] }; + }); + + await expect(processor.process(prompt, context)).rejects.toThrow( + ConfirmationRequiredError, + ); + + // Ensure no commands were executed because the pipeline was halted. + expect(mockShellExecute).not.toHaveBeenCalled(); + }); + + it('should only request confirmation for disallowed commands in a mixed prompt', async () => { + const processor = new ShellProcessor('test-command'); + const prompt = 'Allowed: !{ls -l}, Disallowed: !{rm -rf /}'; + + mockCheckCommandPermissions.mockImplementation((cmd) => ({ + allAllowed: !cmd.includes('rm'), + disallowedCommands: cmd.includes('rm') ? [cmd] : [], + })); + + try { + await processor.process(prompt, context); + expect.fail('Should have thrown ConfirmationRequiredError'); + } catch (e) { + expect(e).toBeInstanceOf(ConfirmationRequiredError); + if (e instanceof ConfirmationRequiredError) { + expect(e.commandsToConfirm).toEqual(['rm -rf /']); + } + } + }); + + it('should execute all commands if they are on the session allowlist', async () => { + const processor = new ShellProcessor('test-command'); + const prompt = 'Run !{cmd1} and !{cmd2}'; + + // Add commands to the session allowlist + context.session.sessionShellAllowlist = new Set(['cmd1', 'cmd2']); + + // checkCommandPermissions should now pass for these + mockCheckCommandPermissions.mockReturnValue({ + allAllowed: true, + disallowedCommands: [], + }); + + mockShellExecute + .mockReturnValueOnce({ result: Promise.resolve({ output: 'output1' }) }) + .mockReturnValueOnce({ result: Promise.resolve({ output: 'output2' }) }); + + const result = await processor.process(prompt, context); + + expect(mockCheckCommandPermissions).toHaveBeenCalledWith( + 'cmd1', + expect.any(Object), + context.session.sessionShellAllowlist, + ); + expect(mockCheckCommandPermissions).toHaveBeenCalledWith( + 'cmd2', + expect.any(Object), + context.session.sessionShellAllowlist, + ); + expect(mockShellExecute).toHaveBeenCalledTimes(2); + expect(result).toBe('Run output1 and output2'); + }); + + it('should trim whitespace from the command inside the injection', async () => { + const processor = new ShellProcessor('test-command'); + const prompt = 'Files: !{ ls -l }'; + mockCheckCommandPermissions.mockReturnValue({ + allAllowed: true, + disallowedCommands: [], + }); + mockShellExecute.mockReturnValue({ + result: Promise.resolve({ output: 'total 0' }), + }); + + await processor.process(prompt, context); + + expect(mockCheckCommandPermissions).toHaveBeenCalledWith( + 'ls -l', // Verifies that the command was trimmed + expect.any(Object), + context.session.sessionShellAllowlist, + ); + expect(mockShellExecute).toHaveBeenCalledWith( + 'ls -l', + expect.any(String), + expect.any(Function), + expect.any(Object), + ); + }); + + it('should handle an empty command inside the injection gracefully', async () => { + const processor = new ShellProcessor('test-command'); + const prompt = 'This is weird: !{}'; + mockCheckCommandPermissions.mockReturnValue({ + allAllowed: true, + disallowedCommands: [], + }); + mockShellExecute.mockReturnValue({ + result: Promise.resolve({ output: 'empty output' }), + }); + + const result = await processor.process(prompt, context); + + expect(mockCheckCommandPermissions).toHaveBeenCalledWith( + '', + expect.any(Object), + context.session.sessionShellAllowlist, + ); + expect(mockShellExecute).toHaveBeenCalledWith( + '', + expect.any(String), + expect.any(Function), + expect.any(Object), + ); + expect(result).toBe('This is weird: empty output'); + }); +}); diff --git a/packages/cli/src/services/prompt-processors/shellProcessor.ts b/packages/cli/src/services/prompt-processors/shellProcessor.ts new file mode 100644 index 00000000..bf811d66 --- /dev/null +++ b/packages/cli/src/services/prompt-processors/shellProcessor.ts @@ -0,0 +1,106 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + checkCommandPermissions, + ShellExecutionService, +} from '@google/gemini-cli-core'; + +import { CommandContext } from '../../ui/commands/types.js'; +import { IPromptProcessor } from './types.js'; + +export class ConfirmationRequiredError extends Error { + constructor( + message: string, + public commandsToConfirm: string[], + ) { + super(message); + this.name = 'ConfirmationRequiredError'; + } +} + +/** + * Finds all instances of shell command injections (`!{...}`) in a prompt, + * executes them, and replaces the injection site with the command's output. + * + * This processor ensures that only allowlisted commands are executed. If a + * disallowed command is found, it halts execution and reports an error. + */ +export class ShellProcessor implements IPromptProcessor { + /** + * A regular expression to find all instances of `!{...}`. The inner + * capture group extracts the command itself. + */ + private static readonly SHELL_INJECTION_REGEX = /!\{([^}]*)\}/g; + + /** + * @param commandName The name of the custom command being executed, used + * for logging and error messages. + */ + constructor(private readonly commandName: string) {} + + async process(prompt: string, context: CommandContext): Promise<string> { + const { config, sessionShellAllowlist } = { + ...context.services, + ...context.session, + }; + const commandsToExecute: Array<{ fullMatch: string; command: string }> = []; + const commandsToConfirm = new Set<string>(); + + const matches = [...prompt.matchAll(ShellProcessor.SHELL_INJECTION_REGEX)]; + if (matches.length === 0) { + return prompt; // No shell commands, nothing to do. + } + + // Discover all commands and check permissions. + for (const match of matches) { + const command = match[1].trim(); + const { allAllowed, disallowedCommands, blockReason, isHardDenial } = + checkCommandPermissions(command, config!, sessionShellAllowlist); + + if (!allAllowed) { + // If it's a hard denial, this is a non-recoverable security error. + if (isHardDenial) { + throw new Error( + `${this.commandName} cannot be run. ${blockReason || 'A shell command in this custom command is explicitly blocked in your config settings.'}`, + ); + } + + // Add each soft denial disallowed command to the set for confirmation. + disallowedCommands.forEach((uc) => commandsToConfirm.add(uc)); + } + commandsToExecute.push({ fullMatch: match[0], command }); + } + + // If any commands require confirmation, throw a special error to halt the + // pipeline and trigger the UI flow. + if (commandsToConfirm.size > 0) { + throw new ConfirmationRequiredError( + 'Shell command confirmation required', + Array.from(commandsToConfirm), + ); + } + + // Execute all commands (only runs if no confirmation was needed). + let processedPrompt = prompt; + for (const { fullMatch, command } of commandsToExecute) { + const { result } = ShellExecutionService.execute( + command, + config!.getTargetDir(), + () => {}, // No streaming needed. + new AbortController().signal, // For now, we don't support cancellation from here. + ); + + const executionResult = await result; + processedPrompt = processedPrompt.replace( + fullMatch, + executionResult.output, + ); + } + + return processedPrompt; + } +} diff --git a/packages/cli/src/services/prompt-processors/types.ts b/packages/cli/src/services/prompt-processors/types.ts index 2ca61062..2653d2b7 100644 --- a/packages/cli/src/services/prompt-processors/types.ts +++ b/packages/cli/src/services/prompt-processors/types.ts @@ -35,3 +35,8 @@ export interface IPromptProcessor { * The placeholder string for shorthand argument injection in custom commands. */ export const SHORTHAND_ARGS_PLACEHOLDER = '{{args}}'; + +/** + * The trigger string for shell command injection in custom commands. + */ +export const SHELL_INJECTION_TRIGGER = '!{'; diff --git a/packages/cli/src/ui/App.tsx b/packages/cli/src/ui/App.tsx index 7b5aa8d0..87a78ac6 100644 --- a/packages/cli/src/ui/App.tsx +++ b/packages/cli/src/ui/App.tsx @@ -36,6 +36,7 @@ import { ThemeDialog } from './components/ThemeDialog.js'; import { AuthDialog } from './components/AuthDialog.js'; import { AuthInProgress } from './components/AuthInProgress.js'; import { EditorSettingsDialog } from './components/EditorSettingsDialog.js'; +import { ShellConfirmationDialog } from './components/ShellConfirmationDialog.js'; import { Colors } from './colors.js'; import { Help } from './components/Help.js'; import { loadHierarchicalGeminiMemory } from '../config/config.js'; @@ -169,6 +170,7 @@ const App = ({ config, settings, startupWarnings = [], version }: AppProps) => { useState<boolean>(false); const [userTier, setUserTier] = useState<UserTierId | undefined>(undefined); const [openFiles, setOpenFiles] = useState<OpenFiles | undefined>(); + const [isProcessing, setIsProcessing] = useState<boolean>(false); useEffect(() => { const unsubscribe = ideContext.subscribeToOpenFiles(setOpenFiles); @@ -452,6 +454,7 @@ const App = ({ config, settings, startupWarnings = [], version }: AppProps) => { slashCommands, pendingHistoryItems: pendingSlashCommandHistoryItems, commandContext, + shellConfirmationRequest, } = useSlashCommandProcessor( config, settings, @@ -468,6 +471,7 @@ const App = ({ config, settings, startupWarnings = [], version }: AppProps) => { setQuittingMessages, openPrivacyNotice, toggleVimEnabled, + setIsProcessing, ); const { @@ -624,7 +628,8 @@ const App = ({ config, settings, startupWarnings = [], version }: AppProps) => { fetchUserMessages(); }, [history, logger]); - const isInputActive = streamingState === StreamingState.Idle && !initError; + const isInputActive = + streamingState === StreamingState.Idle && !initError && !isProcessing; const handleClearScreen = useCallback(() => { clearItems(); @@ -830,7 +835,9 @@ const App = ({ config, settings, startupWarnings = [], version }: AppProps) => { </Box> )} - {isThemeDialogOpen ? ( + {shellConfirmationRequest ? ( + <ShellConfirmationDialog request={shellConfirmationRequest} /> + ) : isThemeDialogOpen ? ( <Box flexDirection="column"> {themeError && ( <Box marginBottom={1}> diff --git a/packages/cli/src/ui/commands/types.ts b/packages/cli/src/ui/commands/types.ts index 59b0178c..2844177f 100644 --- a/packages/cli/src/ui/commands/types.ts +++ b/packages/cli/src/ui/commands/types.ts @@ -63,6 +63,8 @@ export interface CommandContext { // Session-specific data session: { stats: SessionStatsState; + /** A transient list of shell commands the user has approved for this session. */ + sessionShellAllowlist: Set<string>; }; } @@ -118,13 +120,28 @@ export interface SubmitPromptActionReturn { content: string; } +/** + * The return type for a command action that needs to pause and request + * confirmation for a set of shell commands before proceeding. + */ +export interface ConfirmShellCommandsActionReturn { + type: 'confirm_shell_commands'; + /** The list of shell commands that require user confirmation. */ + commandsToConfirm: string[]; + /** The original invocation context to be re-run after confirmation. */ + originalInvocation: { + raw: string; + }; +} + export type SlashCommandActionReturn = | ToolActionReturn | MessageActionReturn | QuitActionReturn | OpenDialogActionReturn | LoadHistoryActionReturn - | SubmitPromptActionReturn; + | SubmitPromptActionReturn + | ConfirmShellCommandsActionReturn; export enum CommandKind { BUILT_IN = 'built-in', diff --git a/packages/cli/src/ui/components/ShellConfirmationDialog.test.tsx b/packages/cli/src/ui/components/ShellConfirmationDialog.test.tsx new file mode 100644 index 00000000..35783d44 --- /dev/null +++ b/packages/cli/src/ui/components/ShellConfirmationDialog.test.tsx @@ -0,0 +1,45 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { render } from 'ink-testing-library'; +import { describe, it, expect, vi } from 'vitest'; +import { ShellConfirmationDialog } from './ShellConfirmationDialog.js'; + +describe('ShellConfirmationDialog', () => { + const onConfirm = vi.fn(); + + const request = { + commands: ['ls -la', 'echo "hello"'], + onConfirm, + }; + + it('renders correctly', () => { + const { lastFrame } = render(<ShellConfirmationDialog request={request} />); + expect(lastFrame()).toMatchSnapshot(); + }); + + it('calls onConfirm with ProceedOnce when "Yes, allow once" is selected', () => { + const { lastFrame } = render(<ShellConfirmationDialog request={request} />); + const select = lastFrame()!.toString(); + // Simulate selecting the first option + // This is a simplified way to test the selection + expect(select).toContain('Yes, allow once'); + }); + + it('calls onConfirm with ProceedAlways when "Yes, allow always for this session" is selected', () => { + const { lastFrame } = render(<ShellConfirmationDialog request={request} />); + const select = lastFrame()!.toString(); + // Simulate selecting the second option + expect(select).toContain('Yes, allow always for this session'); + }); + + it('calls onConfirm with Cancel when "No (esc)" is selected', () => { + const { lastFrame } = render(<ShellConfirmationDialog request={request} />); + const select = lastFrame()!.toString(); + // Simulate selecting the third option + expect(select).toContain('No (esc)'); + }); +}); diff --git a/packages/cli/src/ui/components/ShellConfirmationDialog.tsx b/packages/cli/src/ui/components/ShellConfirmationDialog.tsx new file mode 100644 index 00000000..ec137a6d --- /dev/null +++ b/packages/cli/src/ui/components/ShellConfirmationDialog.tsx @@ -0,0 +1,98 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { ToolConfirmationOutcome } from '@google/gemini-cli-core'; +import { Box, Text, useInput } from 'ink'; +import React from 'react'; +import { Colors } from '../colors.js'; +import { + RadioButtonSelect, + RadioSelectItem, +} from './shared/RadioButtonSelect.js'; + +export interface ShellConfirmationRequest { + commands: string[]; + onConfirm: ( + outcome: ToolConfirmationOutcome, + approvedCommands?: string[], + ) => void; +} + +export interface ShellConfirmationDialogProps { + request: ShellConfirmationRequest; +} + +export const ShellConfirmationDialog: React.FC< + ShellConfirmationDialogProps +> = ({ request }) => { + const { commands, onConfirm } = request; + + useInput((_, key) => { + if (key.escape) { + onConfirm(ToolConfirmationOutcome.Cancel); + } + }); + + const handleSelect = (item: ToolConfirmationOutcome) => { + if (item === ToolConfirmationOutcome.Cancel) { + onConfirm(item); + } else { + // For both ProceedOnce and ProceedAlways, we approve all the + // commands that were requested. + onConfirm(item, commands); + } + }; + + const options: Array<RadioSelectItem<ToolConfirmationOutcome>> = [ + { + label: 'Yes, allow once', + value: ToolConfirmationOutcome.ProceedOnce, + }, + { + label: 'Yes, allow always for this session', + value: ToolConfirmationOutcome.ProceedAlways, + }, + { + label: 'No (esc)', + value: ToolConfirmationOutcome.Cancel, + }, + ]; + + return ( + <Box + flexDirection="column" + borderStyle="round" + borderColor={Colors.AccentYellow} + padding={1} + width="100%" + marginLeft={1} + > + <Box flexDirection="column" marginBottom={1}> + <Text bold>Shell Command Execution</Text> + <Text>A custom command wants to run the following shell commands:</Text> + <Box + flexDirection="column" + borderStyle="round" + borderColor={Colors.Gray} + paddingX={1} + marginTop={1} + > + {commands.map((cmd) => ( + <Text key={cmd} color={Colors.AccentCyan}> + {cmd} + </Text> + ))} + </Box> + </Box> + + <Box marginBottom={1}> + <Text>Do you want to proceed?</Text> + </Box> + + <RadioButtonSelect items={options} onSelect={handleSelect} isFocused /> + </Box> + ); +}; diff --git a/packages/cli/src/ui/components/__snapshots__/ShellConfirmationDialog.test.tsx.snap b/packages/cli/src/ui/components/__snapshots__/ShellConfirmationDialog.test.tsx.snap new file mode 100644 index 00000000..8c9ceb29 --- /dev/null +++ b/packages/cli/src/ui/components/__snapshots__/ShellConfirmationDialog.test.tsx.snap @@ -0,0 +1,21 @@ +// Vitest Snapshot v1, https://vitest.dev/guide/snapshot.html + +exports[`ShellConfirmationDialog > renders correctly 1`] = ` +" ╭──────────────────────────────────────────────────────────────────────────────────────────────────╮ + │ │ + │ Shell Command Execution │ + │ A custom command wants to run the following shell commands: │ + │ │ + │ ╭──────────────────────────────────────────────────────────────────────────────────────────────╮ │ + │ │ ls -la │ │ + │ │ echo "hello" │ │ + │ ╰──────────────────────────────────────────────────────────────────────────────────────────────╯ │ + │ │ + │ Do you want to proceed? │ + │ │ + │ ● 1. Yes, allow once │ + │ 2. Yes, allow always for this session │ + │ 3. No (esc) │ + │ │ + ╰──────────────────────────────────────────────────────────────────────────────────────────────────╯" +`; diff --git a/packages/cli/src/ui/hooks/slashCommandProcessor.test.ts b/packages/cli/src/ui/hooks/slashCommandProcessor.test.ts index ac9b79ec..5b367cd4 100644 --- a/packages/cli/src/ui/hooks/slashCommandProcessor.test.ts +++ b/packages/cli/src/ui/hooks/slashCommandProcessor.test.ts @@ -42,8 +42,13 @@ vi.mock('../contexts/SessionContext.js', () => ({ import { act, renderHook, waitFor } from '@testing-library/react'; import { vi, describe, it, expect, beforeEach, type Mock } from 'vitest'; import { useSlashCommandProcessor } from './slashCommandProcessor.js'; -import { CommandKind, SlashCommand } from '../commands/types.js'; -import { Config } from '@google/gemini-cli-core'; +import { + CommandContext, + CommandKind, + ConfirmShellCommandsActionReturn, + SlashCommand, +} from '../commands/types.js'; +import { Config, ToolConfirmationOutcome } from '@google/gemini-cli-core'; import { LoadedSettings } from '../../config/settings.js'; import { MessageType } from '../types.js'; import { BuiltinCommandLoader } from '../../services/BuiltinCommandLoader.js'; @@ -90,6 +95,7 @@ describe('useSlashCommandProcessor', () => { builtinCommands: SlashCommand[] = [], fileCommands: SlashCommand[] = [], mcpCommands: SlashCommand[] = [], + setIsProcessing = vi.fn(), ) => { mockBuiltinLoadCommands.mockResolvedValue(Object.freeze(builtinCommands)); mockFileLoadCommands.mockResolvedValue(Object.freeze(fileCommands)); @@ -112,6 +118,7 @@ describe('useSlashCommandProcessor', () => { mockSetQuittingMessages, vi.fn(), // openPrivacyNotice vi.fn(), // toggleVimEnabled + setIsProcessing, ), ); @@ -275,6 +282,32 @@ describe('useSlashCommandProcessor', () => { 'with args', ); }); + + it('should set isProcessing to true during execution and false afterwards', async () => { + const mockSetIsProcessing = vi.fn(); + const command = createTestCommand({ + name: 'long-running', + action: () => new Promise((resolve) => setTimeout(resolve, 50)), + }); + + const result = setupProcessorHook([command], [], [], mockSetIsProcessing); + await waitFor(() => expect(result.current.slashCommands).toHaveLength(1)); + + const executionPromise = act(async () => { + await result.current.handleSlashCommand('/long-running'); + }); + + // It should be true immediately after starting + expect(mockSetIsProcessing).toHaveBeenCalledWith(true); + // It should not have been called with false yet + expect(mockSetIsProcessing).not.toHaveBeenCalledWith(false); + + await executionPromise; + + // After the promise resolves, it should be called with false + expect(mockSetIsProcessing).toHaveBeenCalledWith(false); + expect(mockSetIsProcessing).toHaveBeenCalledTimes(2); + }); }); describe('Action Result Handling', () => { @@ -417,6 +450,176 @@ describe('useSlashCommandProcessor', () => { }); }); + describe('Shell Command Confirmation Flow', () => { + // Use a generic vi.fn() for the action. We will change its behavior in each test. + const mockCommandAction = vi.fn(); + + const shellCommand = createTestCommand({ + name: 'shellcmd', + action: mockCommandAction, + }); + + beforeEach(() => { + // Reset the mock before each test + mockCommandAction.mockClear(); + + // Default behavior: request confirmation + mockCommandAction.mockResolvedValue({ + type: 'confirm_shell_commands', + commandsToConfirm: ['rm -rf /'], + originalInvocation: { raw: '/shellcmd' }, + } as ConfirmShellCommandsActionReturn); + }); + + it('should set confirmation request when action returns confirm_shell_commands', async () => { + const result = setupProcessorHook([shellCommand]); + await waitFor(() => expect(result.current.slashCommands).toHaveLength(1)); + + // This is intentionally not awaited, because the promise it returns + // will not resolve until the user responds to the confirmation. + act(() => { + result.current.handleSlashCommand('/shellcmd'); + }); + + // We now wait for the state to be updated with the request. + await waitFor(() => { + expect(result.current.shellConfirmationRequest).not.toBeNull(); + }); + + expect(result.current.shellConfirmationRequest?.commands).toEqual([ + 'rm -rf /', + ]); + }); + + it('should do nothing if user cancels confirmation', async () => { + const result = setupProcessorHook([shellCommand]); + await waitFor(() => expect(result.current.slashCommands).toHaveLength(1)); + + act(() => { + result.current.handleSlashCommand('/shellcmd'); + }); + + // Wait for the confirmation dialog to be set + await waitFor(() => { + expect(result.current.shellConfirmationRequest).not.toBeNull(); + }); + + const onConfirm = result.current.shellConfirmationRequest?.onConfirm; + expect(onConfirm).toBeDefined(); + + // Change the mock action's behavior for a potential second run. + // If the test is flawed, this will be called, and we can detect it. + mockCommandAction.mockResolvedValue({ + type: 'message', + messageType: 'info', + content: 'This should not be called', + }); + + await act(async () => { + onConfirm!(ToolConfirmationOutcome.Cancel, []); // Pass empty array for safety + }); + + expect(result.current.shellConfirmationRequest).toBeNull(); + // Verify the action was only called the initial time. + expect(mockCommandAction).toHaveBeenCalledTimes(1); + }); + + it('should re-run command with one-time allowlist on "Proceed Once"', async () => { + const result = setupProcessorHook([shellCommand]); + await waitFor(() => expect(result.current.slashCommands).toHaveLength(1)); + + act(() => { + result.current.handleSlashCommand('/shellcmd'); + }); + await waitFor(() => { + expect(result.current.shellConfirmationRequest).not.toBeNull(); + }); + + const onConfirm = result.current.shellConfirmationRequest?.onConfirm; + + // **Change the mock's behavior for the SECOND run.** + // This is the key to testing the outcome. + mockCommandAction.mockResolvedValue({ + type: 'message', + messageType: 'info', + content: 'Success!', + }); + + await act(async () => { + onConfirm!(ToolConfirmationOutcome.ProceedOnce, ['rm -rf /']); + }); + + expect(result.current.shellConfirmationRequest).toBeNull(); + + // The action should have been called twice (initial + re-run). + await waitFor(() => { + expect(mockCommandAction).toHaveBeenCalledTimes(2); + }); + + // We can inspect the context of the second call to ensure the one-time list was used. + const secondCallContext = mockCommandAction.mock + .calls[1][0] as CommandContext; + expect( + secondCallContext.session.sessionShellAllowlist.has('rm -rf /'), + ).toBe(true); + + // Verify the final success message was added. + expect(mockAddItem).toHaveBeenCalledWith( + { type: MessageType.INFO, text: 'Success!' }, + expect.any(Number), + ); + + // Verify the session-wide allowlist was NOT permanently updated. + // Re-render the hook by calling a no-op command to get the latest context. + await act(async () => { + result.current.handleSlashCommand('/no-op'); + }); + const finalContext = result.current.commandContext; + expect(finalContext.session.sessionShellAllowlist.size).toBe(0); + }); + + it('should re-run command and update session allowlist on "Proceed Always"', async () => { + const result = setupProcessorHook([shellCommand]); + await waitFor(() => expect(result.current.slashCommands).toHaveLength(1)); + + act(() => { + result.current.handleSlashCommand('/shellcmd'); + }); + await waitFor(() => { + expect(result.current.shellConfirmationRequest).not.toBeNull(); + }); + + const onConfirm = result.current.shellConfirmationRequest?.onConfirm; + mockCommandAction.mockResolvedValue({ + type: 'message', + messageType: 'info', + content: 'Success!', + }); + + await act(async () => { + onConfirm!(ToolConfirmationOutcome.ProceedAlways, ['rm -rf /']); + }); + + expect(result.current.shellConfirmationRequest).toBeNull(); + await waitFor(() => { + expect(mockCommandAction).toHaveBeenCalledTimes(2); + }); + + expect(mockAddItem).toHaveBeenCalledWith( + { type: MessageType.INFO, text: 'Success!' }, + expect.any(Number), + ); + + // Check that the session-wide allowlist WAS updated. + await waitFor(() => { + const finalContext = result.current.commandContext; + expect(finalContext.session.sessionShellAllowlist.has('rm -rf /')).toBe( + true, + ); + }); + }); + }); + describe('Command Parsing and Matching', () => { it('should be case-sensitive', async () => { const command = createTestCommand({ name: 'test' }); @@ -583,7 +786,7 @@ describe('useSlashCommandProcessor', () => { }); describe('Lifecycle', () => { - it('should abort command loading when the hook unmounts', async () => { + it('should abort command loading when the hook unmounts', () => { const abortSpy = vi.spyOn(AbortController.prototype, 'abort'); const { unmount } = renderHook(() => useSlashCommandProcessor( @@ -597,10 +800,11 @@ describe('useSlashCommandProcessor', () => { vi.fn(), // onDebugMessage vi.fn(), // openThemeDialog mockOpenAuthDialog, - vi.fn(), // openEditorDialog + vi.fn(), // openEditorDialog, vi.fn(), // toggleCorgiMode mockSetQuittingMessages, vi.fn(), // openPrivacyNotice + vi.fn(), // toggleVimEnabled ), ); diff --git a/packages/cli/src/ui/hooks/slashCommandProcessor.ts b/packages/cli/src/ui/hooks/slashCommandProcessor.ts index 46b49329..be32de11 100644 --- a/packages/cli/src/ui/hooks/slashCommandProcessor.ts +++ b/packages/cli/src/ui/hooks/slashCommandProcessor.ts @@ -9,7 +9,12 @@ import { type PartListUnion } from '@google/genai'; import process from 'node:process'; import { UseHistoryManagerReturn } from './useHistoryManager.js'; import { useStateAndRef } from './useStateAndRef.js'; -import { Config, GitService, Logger } from '@google/gemini-cli-core'; +import { + Config, + GitService, + Logger, + ToolConfirmationOutcome, +} from '@google/gemini-cli-core'; import { useSessionStats } from '../contexts/SessionContext.js'; import { Message, @@ -44,9 +49,21 @@ export const useSlashCommandProcessor = ( setQuittingMessages: (message: HistoryItem[]) => void, openPrivacyNotice: () => void, toggleVimEnabled: () => Promise<boolean>, + setIsProcessing: (isProcessing: boolean) => void, ) => { const session = useSessionStats(); const [commands, setCommands] = useState<readonly SlashCommand[]>([]); + const [shellConfirmationRequest, setShellConfirmationRequest] = + useState<null | { + commands: string[]; + onConfirm: ( + outcome: ToolConfirmationOutcome, + approvedCommands?: string[], + ) => void; + }>(null); + const [sessionShellAllowlist, setSessionShellAllowlist] = useState( + new Set<string>(), + ); const gitService = useMemo(() => { if (!config?.getProjectRoot()) { return; @@ -144,6 +161,7 @@ export const useSlashCommandProcessor = ( }, session: { stats: session.stats, + sessionShellAllowlist, }, }), [ @@ -161,6 +179,7 @@ export const useSlashCommandProcessor = ( setPendingCompressionItem, toggleCorgiMode, toggleVimEnabled, + sessionShellAllowlist, ], ); @@ -189,69 +208,87 @@ export const useSlashCommandProcessor = ( const handleSlashCommand = useCallback( async ( rawQuery: PartListUnion, + oneTimeShellAllowlist?: Set<string>, ): Promise<SlashCommandProcessorResult | false> => { - if (typeof rawQuery !== 'string') { - return false; - } + setIsProcessing(true); + try { + if (typeof rawQuery !== 'string') { + return false; + } - const trimmed = rawQuery.trim(); - if (!trimmed.startsWith('/') && !trimmed.startsWith('?')) { - return false; - } + const trimmed = rawQuery.trim(); + if (!trimmed.startsWith('/') && !trimmed.startsWith('?')) { + return false; + } - const userMessageTimestamp = Date.now(); - addItem({ type: MessageType.USER, text: trimmed }, userMessageTimestamp); + const userMessageTimestamp = Date.now(); + addItem( + { type: MessageType.USER, text: trimmed }, + userMessageTimestamp, + ); - const parts = trimmed.substring(1).trim().split(/\s+/); - const commandPath = parts.filter((p) => p); // The parts of the command, e.g., ['memory', 'add'] + const parts = trimmed.substring(1).trim().split(/\s+/); + const commandPath = parts.filter((p) => p); // The parts of the command, e.g., ['memory', 'add'] - let currentCommands = commands; - let commandToExecute: SlashCommand | undefined; - let pathIndex = 0; + let currentCommands = commands; + let commandToExecute: SlashCommand | undefined; + let pathIndex = 0; - for (const part of commandPath) { - // TODO: For better performance and architectural clarity, this two-pass - // search could be replaced. A more optimal approach would be to - // pre-compute a single lookup map in `CommandService.ts` that resolves - // all name and alias conflicts during the initial loading phase. The - // processor would then perform a single, fast lookup on that map. + for (const part of commandPath) { + // TODO: For better performance and architectural clarity, this two-pass + // search could be replaced. A more optimal approach would be to + // pre-compute a single lookup map in `CommandService.ts` that resolves + // all name and alias conflicts during the initial loading phase. The + // processor would then perform a single, fast lookup on that map. - // First pass: check for an exact match on the primary command name. - let foundCommand = currentCommands.find((cmd) => cmd.name === part); + // First pass: check for an exact match on the primary command name. + let foundCommand = currentCommands.find((cmd) => cmd.name === part); - // Second pass: if no primary name matches, check for an alias. - if (!foundCommand) { - foundCommand = currentCommands.find((cmd) => - cmd.altNames?.includes(part), - ); - } + // Second pass: if no primary name matches, check for an alias. + if (!foundCommand) { + foundCommand = currentCommands.find((cmd) => + cmd.altNames?.includes(part), + ); + } - if (foundCommand) { - commandToExecute = foundCommand; - pathIndex++; - if (foundCommand.subCommands) { - currentCommands = foundCommand.subCommands; + if (foundCommand) { + commandToExecute = foundCommand; + pathIndex++; + if (foundCommand.subCommands) { + currentCommands = foundCommand.subCommands; + } else { + break; + } } else { break; } - } else { - break; } - } - if (commandToExecute) { - const args = parts.slice(pathIndex).join(' '); + if (commandToExecute) { + const args = parts.slice(pathIndex).join(' '); + + if (commandToExecute.action) { + const fullCommandContext: CommandContext = { + ...commandContext, + invocation: { + raw: trimmed, + name: commandToExecute.name, + args, + }, + }; + + // If a one-time list is provided for a "Proceed" action, temporarily + // augment the session allowlist for this single execution. + if (oneTimeShellAllowlist && oneTimeShellAllowlist.size > 0) { + fullCommandContext.session = { + ...fullCommandContext.session, + sessionShellAllowlist: new Set([ + ...fullCommandContext.session.sessionShellAllowlist, + ...oneTimeShellAllowlist, + ]), + }; + } - if (commandToExecute.action) { - const fullCommandContext: CommandContext = { - ...commandContext, - invocation: { - raw: trimmed, - name: commandToExecute.name, - args, - }, - }; - try { const result = await commandToExecute.action( fullCommandContext, args, @@ -323,6 +360,46 @@ export const useSlashCommandProcessor = ( type: 'submit_prompt', content: result.content, }; + case 'confirm_shell_commands': { + const { outcome, approvedCommands } = await new Promise<{ + outcome: ToolConfirmationOutcome; + approvedCommands?: string[]; + }>((resolve) => { + setShellConfirmationRequest({ + commands: result.commandsToConfirm, + onConfirm: ( + resolvedOutcome, + resolvedApprovedCommands, + ) => { + setShellConfirmationRequest(null); // Close the dialog + resolve({ + outcome: resolvedOutcome, + approvedCommands: resolvedApprovedCommands, + }); + }, + }); + }); + + if ( + outcome === ToolConfirmationOutcome.Cancel || + !approvedCommands || + approvedCommands.length === 0 + ) { + return { type: 'handled' }; + } + + if (outcome === ToolConfirmationOutcome.ProceedAlways) { + setSessionShellAllowlist( + (prev) => new Set([...prev, ...approvedCommands]), + ); + } + + return await handleSlashCommand( + result.originalInvocation.raw, + // Pass the approved commands as a one-time grant for this execution. + new Set(approvedCommands), + ); + } default: { const unhandled: never = result; throw new Error( @@ -331,37 +408,39 @@ export const useSlashCommandProcessor = ( } } } - } catch (e) { - addItem( - { - type: MessageType.ERROR, - text: e instanceof Error ? e.message : String(e), - }, - Date.now(), - ); + + return { type: 'handled' }; + } else if (commandToExecute.subCommands) { + const helpText = `Command '/${commandToExecute.name}' requires a subcommand. Available:\n${commandToExecute.subCommands + .map((sc) => ` - ${sc.name}: ${sc.description || ''}`) + .join('\n')}`; + addMessage({ + type: MessageType.INFO, + content: helpText, + timestamp: new Date(), + }); return { type: 'handled' }; } - - return { type: 'handled' }; - } else if (commandToExecute.subCommands) { - const helpText = `Command '/${commandToExecute.name}' requires a subcommand. Available:\n${commandToExecute.subCommands - .map((sc) => ` - ${sc.name}: ${sc.description || ''}`) - .join('\n')}`; - addMessage({ - type: MessageType.INFO, - content: helpText, - timestamp: new Date(), - }); - return { type: 'handled' }; } - } - addMessage({ - type: MessageType.ERROR, - content: `Unknown command: ${trimmed}`, - timestamp: new Date(), - }); - return { type: 'handled' }; + addMessage({ + type: MessageType.ERROR, + content: `Unknown command: ${trimmed}`, + timestamp: new Date(), + }); + return { type: 'handled' }; + } catch (e) { + addItem( + { + type: MessageType.ERROR, + text: e instanceof Error ? e.message : String(e), + }, + Date.now(), + ); + return { type: 'handled' }; + } finally { + setIsProcessing(false); + } }, [ config, @@ -375,6 +454,9 @@ export const useSlashCommandProcessor = ( openPrivacyNotice, openEditorDialog, setQuittingMessages, + setShellConfirmationRequest, + setSessionShellAllowlist, + setIsProcessing, ], ); @@ -383,5 +465,6 @@ export const useSlashCommandProcessor = ( slashCommands: commands, pendingHistoryItems, commandContext, + shellConfirmationRequest, }; }; |
