diff options
Diffstat (limited to 'packages/core/src')
| -rw-r--r-- | packages/core/src/core/coreToolScheduler.test.ts | 105 | ||||
| -rw-r--r-- | packages/core/src/core/coreToolScheduler.ts | 46 | ||||
| -rw-r--r-- | packages/core/src/tools/shell.ts | 7 |
3 files changed, 122 insertions, 36 deletions
diff --git a/packages/core/src/core/coreToolScheduler.test.ts b/packages/core/src/core/coreToolScheduler.test.ts index be42bb24..1e8b2b2a 100644 --- a/packages/core/src/core/coreToolScheduler.test.ts +++ b/packages/core/src/core/coreToolScheduler.test.ts @@ -4,9 +4,110 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { describe, it, expect } from 'vitest'; -import { convertToFunctionResponse } from './coreToolScheduler.js'; +/* eslint-disable @typescript-eslint/no-explicit-any */ +import { describe, it, expect, vi } from 'vitest'; +import { + CoreToolScheduler, + ToolCall, + ValidatingToolCall, +} from './coreToolScheduler.js'; +import { + BaseTool, + ToolCallConfirmationDetails, + ToolConfirmationOutcome, + ToolResult, +} from '../index.js'; import { Part, PartListUnion } from '@google/genai'; +import { convertToFunctionResponse } from './coreToolScheduler.js'; + +class MockTool extends BaseTool<Record<string, unknown>, ToolResult> { + shouldConfirm = false; + executeFn = vi.fn(); + + constructor(name = 'mockTool') { + super(name, name, 'A mock tool', {}); + } + + async shouldConfirmExecute( + _params: Record<string, unknown>, + _abortSignal: AbortSignal, + ): Promise<ToolCallConfirmationDetails | false> { + if (this.shouldConfirm) { + return { + type: 'exec', + title: 'Confirm Mock Tool', + command: 'do_thing', + rootCommand: 'do_thing', + onConfirm: async () => {}, + }; + } + return false; + } + + async execute( + params: Record<string, unknown>, + _abortSignal: AbortSignal, + ): Promise<ToolResult> { + this.executeFn(params); + return { llmContent: 'Tool executed', returnDisplay: 'Tool executed' }; + } +} + +describe('CoreToolScheduler', () => { + it('should cancel a tool call if the signal is aborted before confirmation', async () => { + const mockTool = new MockTool(); + mockTool.shouldConfirm = true; + const toolRegistry = { + getTool: () => mockTool, + getFunctionDeclarations: () => [], + tools: new Map(), + discovery: {} as any, + config: {} as any, + registerTool: () => {}, + getToolByName: () => mockTool, + getToolByDisplayName: () => mockTool, + getTools: () => [], + discoverTools: async () => {}, + getAllTools: () => [], + getToolsByServer: () => [], + }; + + const onAllToolCallsComplete = vi.fn(); + const onToolCallsUpdate = vi.fn(); + + const scheduler = new CoreToolScheduler({ + toolRegistry: Promise.resolve(toolRegistry as any), + onAllToolCallsComplete, + onToolCallsUpdate, + }); + + const abortController = new AbortController(); + const request = { callId: '1', name: 'mockTool', args: {} }; + + abortController.abort(); + await scheduler.schedule([request], abortController.signal); + + const _waitingCall = onToolCallsUpdate.mock + .calls[1][0][0] as ValidatingToolCall; + const confirmationDetails = await mockTool.shouldConfirmExecute( + {}, + abortController.signal, + ); + if (confirmationDetails) { + await scheduler.handleConfirmationResponse( + '1', + confirmationDetails.onConfirm, + ToolConfirmationOutcome.ProceedOnce, + abortController.signal, + ); + } + + expect(onAllToolCallsComplete).toHaveBeenCalled(); + const completedCalls = onAllToolCallsComplete.mock + .calls[0][0] as ToolCall[]; + expect(completedCalls[0].status).toBe('cancelled'); + }); +}); describe('convertToFunctionResponse', () => { const toolName = 'testTool'; diff --git a/packages/core/src/core/coreToolScheduler.ts b/packages/core/src/core/coreToolScheduler.ts index 5278ae76..679d86aa 100644 --- a/packages/core/src/core/coreToolScheduler.ts +++ b/packages/core/src/core/coreToolScheduler.ts @@ -208,7 +208,6 @@ interface CoreToolSchedulerOptions { export class CoreToolScheduler { private toolRegistry: Promise<ToolRegistry>; private toolCalls: ToolCall[] = []; - private abortController: AbortController; private outputUpdateHandler?: OutputUpdateHandler; private onAllToolCallsComplete?: AllToolCallsCompleteHandler; private onToolCallsUpdate?: ToolCallsUpdateHandler; @@ -220,7 +219,6 @@ export class CoreToolScheduler { this.onAllToolCallsComplete = options.onAllToolCallsComplete; this.onToolCallsUpdate = options.onToolCallsUpdate; this.approvalMode = options.approvalMode ?? ApprovalMode.DEFAULT; - this.abortController = new AbortController(); } private setStatusInternal( @@ -379,6 +377,7 @@ export class CoreToolScheduler { async schedule( request: ToolCallRequestInfo | ToolCallRequestInfo[], + signal: AbortSignal, ): Promise<void> { if (this.isRunning()) { throw new Error( @@ -426,7 +425,7 @@ export class CoreToolScheduler { } else { const confirmationDetails = await toolInstance.shouldConfirmExecute( reqInfo.args, - this.abortController.signal, + signal, ); if (confirmationDetails) { @@ -438,6 +437,7 @@ export class CoreToolScheduler { reqInfo.callId, originalOnConfirm, outcome, + signal, ), }; this.setStatusInternal( @@ -460,7 +460,7 @@ export class CoreToolScheduler { ); } } - this.attemptExecutionOfScheduledCalls(); + this.attemptExecutionOfScheduledCalls(signal); this.checkAndNotifyCompletion(); } @@ -468,6 +468,7 @@ export class CoreToolScheduler { callId: string, originalOnConfirm: (outcome: ToolConfirmationOutcome) => Promise<void>, outcome: ToolConfirmationOutcome, + signal: AbortSignal, ): Promise<void> { const toolCall = this.toolCalls.find( (c) => c.request.callId === callId && c.status === 'awaiting_approval', @@ -477,7 +478,7 @@ export class CoreToolScheduler { await originalOnConfirm(outcome); } - if (outcome === ToolConfirmationOutcome.Cancel) { + if (outcome === ToolConfirmationOutcome.Cancel || signal.aborted) { this.setStatusInternal( callId, 'cancelled', @@ -497,7 +498,7 @@ export class CoreToolScheduler { const modifyResults = await editTool.onModify( waitingToolCall.request.args as unknown as EditToolParams, - this.abortController.signal, + signal, outcome, ); @@ -513,10 +514,10 @@ export class CoreToolScheduler { } else { this.setStatusInternal(callId, 'scheduled'); } - this.attemptExecutionOfScheduledCalls(); + this.attemptExecutionOfScheduledCalls(signal); } - private attemptExecutionOfScheduledCalls(): void { + private attemptExecutionOfScheduledCalls(signal: AbortSignal): void { const allCallsFinalOrScheduled = this.toolCalls.every( (call) => call.status === 'scheduled' || @@ -553,17 +554,13 @@ export class CoreToolScheduler { : undefined; scheduledCall.tool - .execute( - scheduledCall.request.args, - this.abortController.signal, - liveOutputCallback, - ) + .execute(scheduledCall.request.args, signal, liveOutputCallback) .then((toolResult: ToolResult) => { - if (this.abortController.signal.aborted) { + if (signal.aborted) { this.setStatusInternal( callId, 'cancelled', - this.abortController.signal.reason || 'Execution aborted.', + 'User cancelled tool execution.', ); return; } @@ -613,29 +610,10 @@ export class CoreToolScheduler { if (this.onAllToolCallsComplete) { this.onAllToolCallsComplete(completedCalls); } - this.abortController = new AbortController(); this.notifyToolCallsUpdate(); } } - cancelAll(reason: string = 'User initiated cancellation.'): void { - if (!this.abortController.signal.aborted) { - this.abortController.abort(reason); - } - this.abortController = new AbortController(); - - const callsToCancel = [...this.toolCalls]; - callsToCancel.forEach((call) => { - if ( - call.status !== 'error' && - call.status !== 'success' && - call.status !== 'cancelled' - ) { - this.setStatusInternal(call.request.callId, 'cancelled', reason); - } - }); - } - private notifyToolCallsUpdate(): void { if (this.onToolCallsUpdate) { this.onToolCallsUpdate([...this.toolCalls]); diff --git a/packages/core/src/tools/shell.ts b/packages/core/src/tools/shell.ts index 9ced00a4..caef67b9 100644 --- a/packages/core/src/tools/shell.ts +++ b/packages/core/src/tools/shell.ts @@ -162,6 +162,13 @@ export class ShellTool extends BaseTool<ShellToolParams, ToolResult> { }; } + if (abortSignal.aborted) { + return { + llmContent: 'Command was cancelled by user before it could start.', + returnDisplay: 'Command cancelled by user.', + }; + } + // wrap command to append subprocess pids (via pgrep) to temporary file const tempFileName = `shell_pgrep_${crypto.randomBytes(6).toString('hex')}.tmp`; const tempFilePath = path.join(os.tmpdir(), tempFileName); |
