summaryrefslogtreecommitdiff
path: root/packages/core/src
diff options
context:
space:
mode:
authorArya Gummadi <[email protected]>2025-08-19 18:22:41 -0700
committerGitHub <[email protected]>2025-08-20 01:22:41 +0000
commit2a71c10b8a5e43336edf77d140e9d206521c55d1 (patch)
tree68f1e5d3fe2bd2ce36f85c5bee42150e5265da6b /packages/core/src
parentd587c6f1042824de7a1ae94eb1ea9c049cfc34c9 (diff)
feat: auto-approve compatible pending tools when allow always is selected (#6519)
Diffstat (limited to 'packages/core/src')
-rw-r--r--packages/core/src/core/coreToolScheduler.test.ts199
-rw-r--r--packages/core/src/core/coreToolScheduler.ts35
2 files changed, 234 insertions, 0 deletions
diff --git a/packages/core/src/core/coreToolScheduler.test.ts b/packages/core/src/core/coreToolScheduler.test.ts
index 6eb0e5b9..1c400d52 100644
--- a/packages/core/src/core/coreToolScheduler.test.ts
+++ b/packages/core/src/core/coreToolScheduler.test.ts
@@ -8,6 +8,7 @@ import { describe, it, expect, vi } from 'vitest';
import {
CoreToolScheduler,
ToolCall,
+ WaitingToolCall,
convertToFunctionResponse,
} from './coreToolScheduler.js';
import {
@@ -26,6 +27,77 @@ import {
import { Part, PartListUnion } from '@google/genai';
import { MockModifiableTool, MockTool } from '../test-utils/tools.js';
+class TestApprovalTool extends BaseDeclarativeTool<{ id: string }, ToolResult> {
+ static readonly Name = 'testApprovalTool';
+
+ constructor(private config: Config) {
+ super(
+ TestApprovalTool.Name,
+ 'TestApprovalTool',
+ 'A tool for testing approval logic',
+ Kind.Edit,
+ {
+ properties: { id: { type: 'string' } },
+ required: ['id'],
+ type: 'object',
+ },
+ );
+ }
+
+ protected createInvocation(params: {
+ id: string;
+ }): ToolInvocation<{ id: string }, ToolResult> {
+ return new TestApprovalInvocation(this.config, params);
+ }
+}
+
+class TestApprovalInvocation extends BaseToolInvocation<
+ { id: string },
+ ToolResult
+> {
+ constructor(
+ private config: Config,
+ params: { id: string },
+ ) {
+ super(params);
+ }
+
+ getDescription(): string {
+ return `Test tool ${this.params.id}`;
+ }
+
+ override async shouldConfirmExecute(): Promise<
+ ToolCallConfirmationDetails | false
+ > {
+ // Need confirmation unless approval mode is AUTO_EDIT
+ if (this.config.getApprovalMode() === ApprovalMode.AUTO_EDIT) {
+ return false;
+ }
+
+ return {
+ type: 'edit',
+ title: `Confirm Test Tool ${this.params.id}`,
+ fileName: `test-${this.params.id}.txt`,
+ filePath: `/test-${this.params.id}.txt`,
+ fileDiff: 'Test diff content',
+ originalContent: '',
+ newContent: 'Test content',
+ onConfirm: async (outcome: ToolConfirmationOutcome) => {
+ if (outcome === ToolConfirmationOutcome.ProceedAlways) {
+ this.config.setApprovalMode(ApprovalMode.AUTO_EDIT);
+ }
+ },
+ };
+ }
+
+ async execute(): Promise<ToolResult> {
+ return {
+ llmContent: `Executed test tool ${this.params.id}`,
+ returnDisplay: `Executed test tool ${this.params.id}`,
+ };
+ }
+}
+
describe('CoreToolScheduler', () => {
it('should cancel a tool call if the signal is aborted before confirmation', async () => {
const mockTool = new MockTool();
@@ -759,4 +831,131 @@ describe('CoreToolScheduler request queueing', () => {
// Ensure completion callbacks were called twice.
expect(onAllToolCallsComplete).toHaveBeenCalledTimes(2);
});
+
+ it('should auto-approve remaining tool calls when first tool call is approved with ProceedAlways', async () => {
+ let approvalMode = ApprovalMode.DEFAULT;
+ const mockConfig = {
+ getSessionId: () => 'test-session-id',
+ getUsageStatisticsEnabled: () => true,
+ getDebugMode: () => false,
+ getApprovalMode: () => approvalMode,
+ setApprovalMode: (mode: ApprovalMode) => {
+ approvalMode = mode;
+ },
+ } as unknown as Config;
+
+ const testTool = new TestApprovalTool(mockConfig);
+ const toolRegistry = {
+ getTool: () => testTool,
+ getFunctionDeclarations: () => [],
+ getFunctionDeclarationsFiltered: () => [],
+ registerTool: () => {},
+ discoverAllTools: async () => {},
+ discoverMcpTools: async () => {},
+ discoverToolsForServer: async () => {},
+ removeMcpToolsByServer: () => {},
+ getAllTools: () => [],
+ getToolsByServer: () => [],
+ tools: new Map(),
+ config: mockConfig,
+ mcpClientManager: undefined,
+ getToolByName: () => testTool,
+ getToolByDisplayName: () => testTool,
+ getTools: () => [],
+ discoverTools: async () => {},
+ discovery: {},
+ };
+
+ const onAllToolCallsComplete = vi.fn();
+ const onToolCallsUpdate = vi.fn();
+ const pendingConfirmations: Array<
+ (outcome: ToolConfirmationOutcome) => void
+ > = [];
+
+ const scheduler = new CoreToolScheduler({
+ config: mockConfig,
+ toolRegistry: toolRegistry as unknown as ToolRegistry,
+ onAllToolCallsComplete,
+ onToolCallsUpdate: (toolCalls) => {
+ onToolCallsUpdate(toolCalls);
+ // Capture confirmation handlers for awaiting_approval tools
+ toolCalls.forEach((call) => {
+ if (call.status === 'awaiting_approval') {
+ const waitingCall = call as WaitingToolCall;
+ if (waitingCall.confirmationDetails?.onConfirm) {
+ const originalHandler = pendingConfirmations.find(
+ (h) => h === waitingCall.confirmationDetails.onConfirm,
+ );
+ if (!originalHandler) {
+ pendingConfirmations.push(
+ waitingCall.confirmationDetails.onConfirm,
+ );
+ }
+ }
+ }
+ });
+ },
+ getPreferredEditor: () => 'vscode',
+ onEditorClose: vi.fn(),
+ });
+
+ const abortController = new AbortController();
+
+ // Schedule multiple tools that need confirmation
+ const requests = [
+ {
+ callId: '1',
+ name: 'testApprovalTool',
+ args: { id: 'first' },
+ isClientInitiated: false,
+ prompt_id: 'prompt-1',
+ },
+ {
+ callId: '2',
+ name: 'testApprovalTool',
+ args: { id: 'second' },
+ isClientInitiated: false,
+ prompt_id: 'prompt-2',
+ },
+ {
+ callId: '3',
+ name: 'testApprovalTool',
+ args: { id: 'third' },
+ isClientInitiated: false,
+ prompt_id: 'prompt-3',
+ },
+ ];
+
+ await scheduler.schedule(requests, abortController.signal);
+
+ // Wait for all tools to be awaiting approval
+ await vi.waitFor(() => {
+ const calls = onToolCallsUpdate.mock.calls.at(-1)?.[0] as ToolCall[];
+ expect(calls?.length).toBe(3);
+ expect(calls?.every((call) => call.status === 'awaiting_approval')).toBe(
+ true,
+ );
+ });
+
+ expect(pendingConfirmations.length).toBe(3);
+
+ // Approve the first tool with ProceedAlways
+ const firstConfirmation = pendingConfirmations[0];
+ firstConfirmation(ToolConfirmationOutcome.ProceedAlways);
+
+ // Wait for all tools to be completed
+ await vi.waitFor(() => {
+ expect(onAllToolCallsComplete).toHaveBeenCalled();
+ const completedCalls = onAllToolCallsComplete.mock.calls.at(
+ -1,
+ )?.[0] as ToolCall[];
+ expect(completedCalls?.length).toBe(3);
+ expect(completedCalls?.every((call) => call.status === 'success')).toBe(
+ true,
+ );
+ });
+
+ // Verify approval mode was changed
+ expect(approvalMode).toBe(ApprovalMode.AUTO_EDIT);
+ });
});
diff --git a/packages/core/src/core/coreToolScheduler.ts b/packages/core/src/core/coreToolScheduler.ts
index 7a1698c9..5a2bb85d 100644
--- a/packages/core/src/core/coreToolScheduler.ts
+++ b/packages/core/src/core/coreToolScheduler.ts
@@ -695,6 +695,10 @@ export class CoreToolScheduler {
await originalOnConfirm(outcome);
}
+ if (outcome === ToolConfirmationOutcome.ProceedAlways) {
+ await this.autoApproveCompatiblePendingTools(signal, callId);
+ }
+
this.setToolCallOutcome(callId, outcome);
if (outcome === ToolConfirmationOutcome.Cancel || signal.aborted) {
@@ -928,4 +932,35 @@ export class CoreToolScheduler {
};
});
}
+
+ private async autoApproveCompatiblePendingTools(
+ signal: AbortSignal,
+ triggeringCallId: string,
+ ): Promise<void> {
+ const pendingTools = this.toolCalls.filter(
+ (call) =>
+ call.status === 'awaiting_approval' &&
+ call.request.callId !== triggeringCallId,
+ ) as WaitingToolCall[];
+
+ for (const pendingTool of pendingTools) {
+ try {
+ const stillNeedsConfirmation =
+ await pendingTool.invocation.shouldConfirmExecute(signal);
+
+ if (!stillNeedsConfirmation) {
+ this.setToolCallOutcome(
+ pendingTool.request.callId,
+ ToolConfirmationOutcome.ProceedAlways,
+ );
+ this.setStatusInternal(pendingTool.request.callId, 'scheduled');
+ }
+ } catch (error) {
+ console.error(
+ `Error checking confirmation for tool ${pendingTool.request.callId}:`,
+ error,
+ );
+ }
+ }
+ }
}