diff options
Diffstat (limited to 'packages/core/src')
| -rw-r--r-- | packages/core/src/core/coreToolScheduler.ts | 520 | ||||
| -rw-r--r-- | packages/core/src/index.ts | 3 | ||||
| -rw-r--r-- | packages/core/src/tools/tools.ts | 2 |
3 files changed, 522 insertions, 3 deletions
diff --git a/packages/core/src/core/coreToolScheduler.ts b/packages/core/src/core/coreToolScheduler.ts new file mode 100644 index 00000000..1278d468 --- /dev/null +++ b/packages/core/src/core/coreToolScheduler.ts @@ -0,0 +1,520 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + ToolCallRequestInfo, + ToolCallResponseInfo, + ToolConfirmationOutcome, + Tool, + ToolCallConfirmationDetails, + ToolResult, + ToolRegistry, +} from '../index.js'; +import { Part, PartUnion, PartListUnion } from '@google/genai'; + +export type ValidatingToolCall = { + status: 'validating'; + request: ToolCallRequestInfo; + tool: Tool; +}; + +export type ScheduledToolCall = { + status: 'scheduled'; + request: ToolCallRequestInfo; + tool: Tool; +}; + +export type ErroredToolCall = { + status: 'error'; + request: ToolCallRequestInfo; + response: ToolCallResponseInfo; +}; + +export type SuccessfulToolCall = { + status: 'success'; + request: ToolCallRequestInfo; + tool: Tool; + response: ToolCallResponseInfo; +}; + +export type ExecutingToolCall = { + status: 'executing'; + request: ToolCallRequestInfo; + tool: Tool; + liveOutput?: string; +}; + +export type CancelledToolCall = { + status: 'cancelled'; + request: ToolCallRequestInfo; + response: ToolCallResponseInfo; + tool: Tool; +}; + +export type WaitingToolCall = { + status: 'awaiting_approval'; + request: ToolCallRequestInfo; + tool: Tool; + confirmationDetails: ToolCallConfirmationDetails; +}; + +export type Status = ToolCall['status']; + +export type ToolCall = + | ValidatingToolCall + | ScheduledToolCall + | ErroredToolCall + | SuccessfulToolCall + | ExecutingToolCall + | CancelledToolCall + | WaitingToolCall; + +export type CompletedToolCall = + | SuccessfulToolCall + | CancelledToolCall + | ErroredToolCall; + +export type ConfirmHandler = ( + toolCall: WaitingToolCall, +) => Promise<ToolConfirmationOutcome>; + +export type OutputUpdateHandler = ( + toolCallId: string, + outputChunk: string, +) => void; + +export type AllToolCallsCompleteHandler = ( + completedToolCalls: CompletedToolCall[], +) => void; + +export type ToolCallsUpdateHandler = (toolCalls: ToolCall[]) => void; + +/** + * Formats tool output for a Gemini FunctionResponse. + */ +export function formatLlmContentForFunctionResponse( + llmContent: PartListUnion, +): { + functionResponseJson: Record<string, string>; + additionalParts: PartUnion[]; +} { + const additionalParts: PartUnion[] = []; + let functionResponseJson: Record<string, string>; + + const contentToProcess = + Array.isArray(llmContent) && llmContent.length === 1 + ? llmContent[0] + : llmContent; + + if (typeof contentToProcess === 'string') { + functionResponseJson = { output: contentToProcess }; + } else if (Array.isArray(contentToProcess)) { + functionResponseJson = { + status: 'Tool execution succeeded.', + }; + additionalParts.push(...contentToProcess); + } else if (contentToProcess.inlineData || contentToProcess.fileData) { + const mimeType = + contentToProcess.inlineData?.mimeType || + contentToProcess.fileData?.mimeType || + 'unknown'; + functionResponseJson = { + status: `Binary content of type ${mimeType} was processed.`, + }; + additionalParts.push(contentToProcess); + } else if (contentToProcess.text !== undefined) { + functionResponseJson = { output: contentToProcess.text }; + } else { + functionResponseJson = { status: 'Tool execution succeeded.' }; + additionalParts.push(contentToProcess); + } + + return { + functionResponseJson, + additionalParts, + }; +} + +const createErrorResponse = ( + request: ToolCallRequestInfo, + error: Error, +): ToolCallResponseInfo => ({ + callId: request.callId, + error, + responseParts: { + functionResponse: { + id: request.callId, + name: request.name, + response: { error: error.message }, + }, + }, + resultDisplay: error.message, +}); + +interface CoreToolSchedulerOptions { + toolRegistry: ToolRegistry; + outputUpdateHandler?: OutputUpdateHandler; + onAllToolCallsComplete?: AllToolCallsCompleteHandler; + onToolCallsUpdate?: ToolCallsUpdateHandler; +} + +export class CoreToolScheduler { + private toolRegistry: ToolRegistry; + private toolCalls: ToolCall[] = []; + private abortController: AbortController; + private outputUpdateHandler?: OutputUpdateHandler; + private onAllToolCallsComplete?: AllToolCallsCompleteHandler; + private onToolCallsUpdate?: ToolCallsUpdateHandler; + + constructor(options: CoreToolSchedulerOptions) { + this.toolRegistry = options.toolRegistry; + this.outputUpdateHandler = options.outputUpdateHandler; + this.onAllToolCallsComplete = options.onAllToolCallsComplete; + this.onToolCallsUpdate = options.onToolCallsUpdate; + this.abortController = new AbortController(); + } + + private setStatusInternal( + targetCallId: string, + status: 'success', + response: ToolCallResponseInfo, + ): void; + private setStatusInternal( + targetCallId: string, + status: 'awaiting_approval', + confirmationDetails: ToolCallConfirmationDetails, + ): void; + private setStatusInternal( + targetCallId: string, + status: 'error', + response: ToolCallResponseInfo, + ): void; + private setStatusInternal( + targetCallId: string, + status: 'cancelled', + reason: string, + ): void; + private setStatusInternal( + targetCallId: string, + status: 'executing' | 'scheduled' | 'validating', + ): void; + private setStatusInternal( + targetCallId: string, + newStatus: Status, + auxiliaryData?: unknown, + ): void { + this.toolCalls = this.toolCalls.map((currentCall) => { + if ( + currentCall.request.callId !== targetCallId || + currentCall.status === 'error' + ) { + return currentCall; + } + + const callWithToolContext = currentCall as ToolCall & { tool: Tool }; + + switch (newStatus) { + case 'success': + return { + ...callWithToolContext, + status: 'success', + response: auxiliaryData as ToolCallResponseInfo, + } as SuccessfulToolCall; + case 'error': + return { + request: currentCall.request, + status: 'error', + response: auxiliaryData as ToolCallResponseInfo, + } as ErroredToolCall; + case 'awaiting_approval': + return { + ...callWithToolContext, + status: 'awaiting_approval', + confirmationDetails: auxiliaryData as ToolCallConfirmationDetails, + } as WaitingToolCall; + case 'scheduled': + return { + ...callWithToolContext, + status: 'scheduled', + } as ScheduledToolCall; + case 'cancelled': + return { + ...callWithToolContext, + status: 'cancelled', + response: { + callId: currentCall.request.callId, + responseParts: { + functionResponse: { + id: currentCall.request.callId, + name: currentCall.request.name, + response: { + error: `[Operation Cancelled] Reason: ${auxiliaryData}`, + }, + }, + }, + resultDisplay: undefined, + error: undefined, + }, + } as CancelledToolCall; + case 'validating': + return { + ...(currentCall as ValidatingToolCall), + status: 'validating', + } as ValidatingToolCall; + case 'executing': + return { + ...callWithToolContext, + status: 'executing', + } as ExecutingToolCall; + default: { + const exhaustiveCheck: never = newStatus; + return exhaustiveCheck; + } + } + }); + this.notifyToolCallsUpdate(); + this.checkAndNotifyCompletion(); + } + + private isRunning(): boolean { + return this.toolCalls.some( + (call) => + call.status === 'executing' || call.status === 'awaiting_approval', + ); + } + + async schedule( + request: ToolCallRequestInfo | ToolCallRequestInfo[], + ): Promise<void> { + if (this.isRunning()) { + throw new Error( + 'Cannot schedule new tool calls while other tool calls are actively running (executing or awaiting approval).', + ); + } + const requestsToProcess = Array.isArray(request) ? request : [request]; + + const newToolCalls: ToolCall[] = requestsToProcess.map( + (reqInfo): ToolCall => { + const toolInstance = this.toolRegistry.getTool(reqInfo.name); + if (!toolInstance) { + return { + status: 'error', + request: reqInfo, + response: createErrorResponse( + reqInfo, + new Error(`Tool "${reqInfo.name}" not found in registry.`), + ), + }; + } + return { status: 'validating', request: reqInfo, tool: toolInstance }; + }, + ); + + this.toolCalls = this.toolCalls.concat(newToolCalls); + this.notifyToolCallsUpdate(); + + for (const toolCall of newToolCalls) { + if (toolCall.status !== 'validating') { + continue; + } + + const { request: reqInfo, tool: toolInstance } = toolCall; + try { + const confirmationDetails = await toolInstance.shouldConfirmExecute( + reqInfo.args, + this.abortController.signal, + ); + + if (confirmationDetails) { + const originalOnConfirm = confirmationDetails.onConfirm; + const wrappedConfirmationDetails: ToolCallConfirmationDetails = { + ...confirmationDetails, + onConfirm: (outcome: ToolConfirmationOutcome) => + this.handleConfirmationResponse( + reqInfo.callId, + originalOnConfirm, + outcome, + ), + }; + this.setStatusInternal( + reqInfo.callId, + 'awaiting_approval', + wrappedConfirmationDetails, + ); + } else { + this.setStatusInternal(reqInfo.callId, 'scheduled'); + } + } catch (error) { + this.setStatusInternal( + reqInfo.callId, + 'error', + createErrorResponse( + reqInfo, + error instanceof Error ? error : new Error(String(error)), + ), + ); + } + } + this.attemptExecutionOfScheduledCalls(); + this.checkAndNotifyCompletion(); + } + + async handleConfirmationResponse( + callId: string, + originalOnConfirm: (outcome: ToolConfirmationOutcome) => Promise<void>, + outcome: ToolConfirmationOutcome, + ): Promise<void> { + const toolCall = this.toolCalls.find( + (c) => c.request.callId === callId && c.status === 'awaiting_approval', + ); + + if (toolCall && toolCall.status === 'awaiting_approval') { + await originalOnConfirm(outcome); + } + + if (outcome === ToolConfirmationOutcome.Cancel) { + this.setStatusInternal( + callId, + 'cancelled', + 'User did not allow tool call', + ); + } else { + this.setStatusInternal(callId, 'scheduled'); + } + this.attemptExecutionOfScheduledCalls(); + } + + private attemptExecutionOfScheduledCalls(): void { + const allCallsFinalOrScheduled = this.toolCalls.every( + (call) => + call.status === 'scheduled' || + call.status === 'cancelled' || + call.status === 'success' || + call.status === 'error', + ); + + if (allCallsFinalOrScheduled) { + const callsToExecute = this.toolCalls.filter( + (call) => call.status === 'scheduled', + ); + + callsToExecute.forEach((toolCall) => { + if (toolCall.status !== 'scheduled') return; + + const scheduledCall = toolCall as ScheduledToolCall; + const { callId, name: toolName } = scheduledCall.request; + this.setStatusInternal(callId, 'executing'); + + const liveOutputCallback = + scheduledCall.tool.canUpdateOutput && this.outputUpdateHandler + ? (outputChunk: string) => { + if (this.outputUpdateHandler) { + this.outputUpdateHandler(callId, outputChunk); + } + this.toolCalls = this.toolCalls.map((tc) => + tc.request.callId === callId && tc.status === 'executing' + ? { ...(tc as ExecutingToolCall), liveOutput: outputChunk } + : tc, + ); + this.notifyToolCallsUpdate(); + } + : undefined; + + scheduledCall.tool + .execute( + scheduledCall.request.args, + this.abortController.signal, + liveOutputCallback, + ) + .then((toolResult: ToolResult) => { + if (this.abortController.signal.aborted) { + this.setStatusInternal( + callId, + 'cancelled', + this.abortController.signal.reason || 'Execution aborted.', + ); + return; + } + + const { functionResponseJson, additionalParts } = + formatLlmContentForFunctionResponse(toolResult.llmContent); + + const functionResponsePart: Part = { + functionResponse: { + name: toolName, + id: callId, + response: functionResponseJson, + }, + }; + + const successResponse: ToolCallResponseInfo = { + callId, + responseParts: [functionResponsePart, ...additionalParts], + resultDisplay: toolResult.returnDisplay, + error: undefined, + }; + this.setStatusInternal(callId, 'success', successResponse); + }) + .catch((executionError: Error) => { + this.setStatusInternal( + callId, + 'error', + createErrorResponse( + scheduledCall.request, + executionError instanceof Error + ? executionError + : new Error(String(executionError)), + ), + ); + }); + }); + } + } + + private checkAndNotifyCompletion(): void { + const allCallsAreTerminal = this.toolCalls.every( + (call) => + call.status === 'success' || + call.status === 'error' || + call.status === 'cancelled', + ); + + if (this.toolCalls.length > 0 && allCallsAreTerminal) { + const completedCalls = [...this.toolCalls] as CompletedToolCall[]; + this.toolCalls = []; + + if (this.onAllToolCallsComplete) { + this.onAllToolCallsComplete(completedCalls); + } + this.abortController = new AbortController(); + this.notifyToolCallsUpdate(); + } + } + + cancelAll(reason: string = 'User initiated cancellation.'): void { + if (!this.abortController.signal.aborted) { + this.abortController.abort(reason); + } + this.abortController = new AbortController(); + + const callsToCancel = [...this.toolCalls]; + callsToCancel.forEach((call) => { + if ( + call.status !== 'error' && + call.status !== 'success' && + call.status !== 'cancelled' + ) { + this.setStatusInternal(call.request.callId, 'cancelled', reason); + } + }); + } + + private notifyToolCallsUpdate(): void { + if (this.onToolCallsUpdate) { + this.onToolCallsUpdate([...this.toolCalls]); + } + } +} diff --git a/packages/core/src/index.ts b/packages/core/src/index.ts index 70426d57..f8c42336 100644 --- a/packages/core/src/index.ts +++ b/packages/core/src/index.ts @@ -13,8 +13,7 @@ export * from './core/logger.js'; export * from './core/prompts.js'; export * from './core/turn.js'; export * from './core/geminiRequest.js'; -// Potentially export types from turn.ts if needed externally -// export { GeminiEventType } from './core/turn.js'; // Example +export * from './core/coreToolScheduler.js'; // Export utilities export * from './utils/paths.js'; diff --git a/packages/core/src/tools/tools.ts b/packages/core/src/tools/tools.ts index a2e7fa06..1b932229 100644 --- a/packages/core/src/tools/tools.ts +++ b/packages/core/src/tools/tools.ts @@ -218,7 +218,7 @@ export interface ToolMcpConfirmationDetails { serverName: string; toolName: string; toolDisplayName: string; - onConfirm: (outcome: ToolConfirmationOutcome) => Promise<void> | void; + onConfirm: (outcome: ToolConfirmationOutcome) => Promise<void>; } export type ToolCallConfirmationDetails = |
