diff options
| author | Arya Gummadi <[email protected]> | 2025-08-19 18:22:41 -0700 |
|---|---|---|
| committer | GitHub <[email protected]> | 2025-08-20 01:22:41 +0000 |
| commit | 2a71c10b8a5e43336edf77d140e9d206521c55d1 (patch) | |
| tree | 68f1e5d3fe2bd2ce36f85c5bee42150e5265da6b | |
| parent | d587c6f1042824de7a1ae94eb1ea9c049cfc34c9 (diff) | |
feat: auto-approve compatible pending tools when allow always is selected (#6519)
| -rw-r--r-- | packages/core/src/core/coreToolScheduler.test.ts | 199 | ||||
| -rw-r--r-- | packages/core/src/core/coreToolScheduler.ts | 35 |
2 files changed, 234 insertions, 0 deletions
diff --git a/packages/core/src/core/coreToolScheduler.test.ts b/packages/core/src/core/coreToolScheduler.test.ts index 6eb0e5b9..1c400d52 100644 --- a/packages/core/src/core/coreToolScheduler.test.ts +++ b/packages/core/src/core/coreToolScheduler.test.ts @@ -8,6 +8,7 @@ import { describe, it, expect, vi } from 'vitest'; import { CoreToolScheduler, ToolCall, + WaitingToolCall, convertToFunctionResponse, } from './coreToolScheduler.js'; import { @@ -26,6 +27,77 @@ import { import { Part, PartListUnion } from '@google/genai'; import { MockModifiableTool, MockTool } from '../test-utils/tools.js'; +class TestApprovalTool extends BaseDeclarativeTool<{ id: string }, ToolResult> { + static readonly Name = 'testApprovalTool'; + + constructor(private config: Config) { + super( + TestApprovalTool.Name, + 'TestApprovalTool', + 'A tool for testing approval logic', + Kind.Edit, + { + properties: { id: { type: 'string' } }, + required: ['id'], + type: 'object', + }, + ); + } + + protected createInvocation(params: { + id: string; + }): ToolInvocation<{ id: string }, ToolResult> { + return new TestApprovalInvocation(this.config, params); + } +} + +class TestApprovalInvocation extends BaseToolInvocation< + { id: string }, + ToolResult +> { + constructor( + private config: Config, + params: { id: string }, + ) { + super(params); + } + + getDescription(): string { + return `Test tool ${this.params.id}`; + } + + override async shouldConfirmExecute(): Promise< + ToolCallConfirmationDetails | false + > { + // Need confirmation unless approval mode is AUTO_EDIT + if (this.config.getApprovalMode() === ApprovalMode.AUTO_EDIT) { + return false; + } + + return { + type: 'edit', + title: `Confirm Test Tool ${this.params.id}`, + fileName: `test-${this.params.id}.txt`, + filePath: `/test-${this.params.id}.txt`, + fileDiff: 'Test diff content', + originalContent: '', + newContent: 'Test content', + onConfirm: async (outcome: ToolConfirmationOutcome) => { + if (outcome === ToolConfirmationOutcome.ProceedAlways) { + this.config.setApprovalMode(ApprovalMode.AUTO_EDIT); + } + }, + }; + } + + async execute(): Promise<ToolResult> { + return { + llmContent: `Executed test tool ${this.params.id}`, + returnDisplay: `Executed test tool ${this.params.id}`, + }; + } +} + describe('CoreToolScheduler', () => { it('should cancel a tool call if the signal is aborted before confirmation', async () => { const mockTool = new MockTool(); @@ -759,4 +831,131 @@ describe('CoreToolScheduler request queueing', () => { // Ensure completion callbacks were called twice. expect(onAllToolCallsComplete).toHaveBeenCalledTimes(2); }); + + it('should auto-approve remaining tool calls when first tool call is approved with ProceedAlways', async () => { + let approvalMode = ApprovalMode.DEFAULT; + const mockConfig = { + getSessionId: () => 'test-session-id', + getUsageStatisticsEnabled: () => true, + getDebugMode: () => false, + getApprovalMode: () => approvalMode, + setApprovalMode: (mode: ApprovalMode) => { + approvalMode = mode; + }, + } as unknown as Config; + + const testTool = new TestApprovalTool(mockConfig); + const toolRegistry = { + getTool: () => testTool, + getFunctionDeclarations: () => [], + getFunctionDeclarationsFiltered: () => [], + registerTool: () => {}, + discoverAllTools: async () => {}, + discoverMcpTools: async () => {}, + discoverToolsForServer: async () => {}, + removeMcpToolsByServer: () => {}, + getAllTools: () => [], + getToolsByServer: () => [], + tools: new Map(), + config: mockConfig, + mcpClientManager: undefined, + getToolByName: () => testTool, + getToolByDisplayName: () => testTool, + getTools: () => [], + discoverTools: async () => {}, + discovery: {}, + }; + + const onAllToolCallsComplete = vi.fn(); + const onToolCallsUpdate = vi.fn(); + const pendingConfirmations: Array< + (outcome: ToolConfirmationOutcome) => void + > = []; + + const scheduler = new CoreToolScheduler({ + config: mockConfig, + toolRegistry: toolRegistry as unknown as ToolRegistry, + onAllToolCallsComplete, + onToolCallsUpdate: (toolCalls) => { + onToolCallsUpdate(toolCalls); + // Capture confirmation handlers for awaiting_approval tools + toolCalls.forEach((call) => { + if (call.status === 'awaiting_approval') { + const waitingCall = call as WaitingToolCall; + if (waitingCall.confirmationDetails?.onConfirm) { + const originalHandler = pendingConfirmations.find( + (h) => h === waitingCall.confirmationDetails.onConfirm, + ); + if (!originalHandler) { + pendingConfirmations.push( + waitingCall.confirmationDetails.onConfirm, + ); + } + } + } + }); + }, + getPreferredEditor: () => 'vscode', + onEditorClose: vi.fn(), + }); + + const abortController = new AbortController(); + + // Schedule multiple tools that need confirmation + const requests = [ + { + callId: '1', + name: 'testApprovalTool', + args: { id: 'first' }, + isClientInitiated: false, + prompt_id: 'prompt-1', + }, + { + callId: '2', + name: 'testApprovalTool', + args: { id: 'second' }, + isClientInitiated: false, + prompt_id: 'prompt-2', + }, + { + callId: '3', + name: 'testApprovalTool', + args: { id: 'third' }, + isClientInitiated: false, + prompt_id: 'prompt-3', + }, + ]; + + await scheduler.schedule(requests, abortController.signal); + + // Wait for all tools to be awaiting approval + await vi.waitFor(() => { + const calls = onToolCallsUpdate.mock.calls.at(-1)?.[0] as ToolCall[]; + expect(calls?.length).toBe(3); + expect(calls?.every((call) => call.status === 'awaiting_approval')).toBe( + true, + ); + }); + + expect(pendingConfirmations.length).toBe(3); + + // Approve the first tool with ProceedAlways + const firstConfirmation = pendingConfirmations[0]; + firstConfirmation(ToolConfirmationOutcome.ProceedAlways); + + // Wait for all tools to be completed + await vi.waitFor(() => { + expect(onAllToolCallsComplete).toHaveBeenCalled(); + const completedCalls = onAllToolCallsComplete.mock.calls.at( + -1, + )?.[0] as ToolCall[]; + expect(completedCalls?.length).toBe(3); + expect(completedCalls?.every((call) => call.status === 'success')).toBe( + true, + ); + }); + + // Verify approval mode was changed + expect(approvalMode).toBe(ApprovalMode.AUTO_EDIT); + }); }); diff --git a/packages/core/src/core/coreToolScheduler.ts b/packages/core/src/core/coreToolScheduler.ts index 7a1698c9..5a2bb85d 100644 --- a/packages/core/src/core/coreToolScheduler.ts +++ b/packages/core/src/core/coreToolScheduler.ts @@ -695,6 +695,10 @@ export class CoreToolScheduler { await originalOnConfirm(outcome); } + if (outcome === ToolConfirmationOutcome.ProceedAlways) { + await this.autoApproveCompatiblePendingTools(signal, callId); + } + this.setToolCallOutcome(callId, outcome); if (outcome === ToolConfirmationOutcome.Cancel || signal.aborted) { @@ -928,4 +932,35 @@ export class CoreToolScheduler { }; }); } + + private async autoApproveCompatiblePendingTools( + signal: AbortSignal, + triggeringCallId: string, + ): Promise<void> { + const pendingTools = this.toolCalls.filter( + (call) => + call.status === 'awaiting_approval' && + call.request.callId !== triggeringCallId, + ) as WaitingToolCall[]; + + for (const pendingTool of pendingTools) { + try { + const stillNeedsConfirmation = + await pendingTool.invocation.shouldConfirmExecute(signal); + + if (!stillNeedsConfirmation) { + this.setToolCallOutcome( + pendingTool.request.callId, + ToolConfirmationOutcome.ProceedAlways, + ); + this.setStatusInternal(pendingTool.request.callId, 'scheduled'); + } + } catch (error) { + console.error( + `Error checking confirmation for tool ${pendingTool.request.callId}:`, + error, + ); + } + } + } } |
