diff options
Diffstat (limited to 'packages/cli/src/ui/hooks/useReactToolScheduler.ts')
| -rw-r--r-- | packages/cli/src/ui/hooks/useReactToolScheduler.ts | 301 |
1 files changed, 301 insertions, 0 deletions
diff --git a/packages/cli/src/ui/hooks/useReactToolScheduler.ts b/packages/cli/src/ui/hooks/useReactToolScheduler.ts new file mode 100644 index 00000000..12333d92 --- /dev/null +++ b/packages/cli/src/ui/hooks/useReactToolScheduler.ts @@ -0,0 +1,301 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + Config, + ToolCallRequestInfo, + ExecutingToolCall, + ScheduledToolCall, + ValidatingToolCall, + WaitingToolCall, + CompletedToolCall, + CancelledToolCall, + CoreToolScheduler, + OutputUpdateHandler, + AllToolCallsCompleteHandler, + ToolCallsUpdateHandler, + Tool, + ToolCall, + Status as CoreStatus, +} from '@gemini-code/core'; +import { useCallback, useEffect, useState, useRef } from 'react'; +import { + HistoryItemToolGroup, + IndividualToolCallDisplay, + ToolCallStatus, + HistoryItemWithoutId, +} from '../types.js'; + +export type ScheduleFn = ( + request: ToolCallRequestInfo | ToolCallRequestInfo[], +) => void; +export type CancelFn = (reason?: string) => void; +export type MarkToolsAsSubmittedFn = (callIds: string[]) => void; + +export type TrackedScheduledToolCall = ScheduledToolCall & { + responseSubmittedToGemini?: boolean; +}; +export type TrackedValidatingToolCall = ValidatingToolCall & { + responseSubmittedToGemini?: boolean; +}; +export type TrackedWaitingToolCall = WaitingToolCall & { + responseSubmittedToGemini?: boolean; +}; +export type TrackedExecutingToolCall = ExecutingToolCall & { + responseSubmittedToGemini?: boolean; +}; +export type TrackedCompletedToolCall = CompletedToolCall & { + responseSubmittedToGemini?: boolean; +}; +export type TrackedCancelledToolCall = CancelledToolCall & { + responseSubmittedToGemini?: boolean; +}; + +export type TrackedToolCall = + | TrackedScheduledToolCall + | TrackedValidatingToolCall + | TrackedWaitingToolCall + | TrackedExecutingToolCall + | TrackedCompletedToolCall + | TrackedCancelledToolCall; + +export function useReactToolScheduler( + onComplete: (tools: CompletedToolCall[]) => void, + config: Config, + setPendingHistoryItem: React.Dispatch< + React.SetStateAction<HistoryItemWithoutId | null> + >, +): [TrackedToolCall[], ScheduleFn, CancelFn, MarkToolsAsSubmittedFn] { + const [toolCallsForDisplay, setToolCallsForDisplay] = useState< + TrackedToolCall[] + >([]); + const schedulerRef = useRef<CoreToolScheduler | null>(null); + + useEffect(() => { + const outputUpdateHandler: OutputUpdateHandler = ( + toolCallId, + outputChunk, + ) => { + setPendingHistoryItem((prevItem) => { + if (prevItem?.type === 'tool_group') { + return { + ...prevItem, + tools: prevItem.tools.map((toolDisplay) => + toolDisplay.callId === toolCallId && + toolDisplay.status === ToolCallStatus.Executing + ? { ...toolDisplay, resultDisplay: outputChunk } + : toolDisplay, + ), + }; + } + return prevItem; + }); + + setToolCallsForDisplay((prevCalls) => + prevCalls.map((tc) => { + if (tc.request.callId === toolCallId && tc.status === 'executing') { + const executingTc = tc as TrackedExecutingToolCall; + return { ...executingTc, liveOutput: outputChunk }; + } + return tc; + }), + ); + }; + + const allToolCallsCompleteHandler: AllToolCallsCompleteHandler = ( + completedToolCalls, + ) => { + onComplete(completedToolCalls); + }; + + const toolCallsUpdateHandler: ToolCallsUpdateHandler = ( + updatedCoreToolCalls: ToolCall[], + ) => { + setToolCallsForDisplay((prevTrackedCalls) => + updatedCoreToolCalls.map((coreTc) => { + const existingTrackedCall = prevTrackedCalls.find( + (ptc) => ptc.request.callId === coreTc.request.callId, + ); + const newTrackedCall: TrackedToolCall = { + ...coreTc, + responseSubmittedToGemini: + existingTrackedCall?.responseSubmittedToGemini ?? false, + } as TrackedToolCall; + return newTrackedCall; + }), + ); + }; + + schedulerRef.current = new CoreToolScheduler({ + toolRegistry: config.getToolRegistry(), + outputUpdateHandler, + onAllToolCallsComplete: allToolCallsCompleteHandler, + onToolCallsUpdate: toolCallsUpdateHandler, + }); + }, [config, onComplete, setPendingHistoryItem]); + + const schedule: ScheduleFn = useCallback( + async (request: ToolCallRequestInfo | ToolCallRequestInfo[]) => { + schedulerRef.current?.schedule(request); + }, + [], + ); + + const cancel: CancelFn = useCallback((reason: string = 'unspecified') => { + schedulerRef.current?.cancelAll(reason); + }, []); + + const markToolsAsSubmitted: MarkToolsAsSubmittedFn = useCallback( + (callIdsToMark: string[]) => { + setToolCallsForDisplay((prevCalls) => + prevCalls.map((tc) => + callIdsToMark.includes(tc.request.callId) + ? { ...tc, responseSubmittedToGemini: true } + : tc, + ), + ); + }, + [], + ); + + return [toolCallsForDisplay, schedule, cancel, markToolsAsSubmitted]; +} + +/** + * Maps a CoreToolScheduler status to the UI's ToolCallStatus enum. + */ +function mapCoreStatusToDisplayStatus(coreStatus: CoreStatus): ToolCallStatus { + switch (coreStatus) { + case 'validating': + return ToolCallStatus.Executing; + case 'awaiting_approval': + return ToolCallStatus.Confirming; + case 'executing': + return ToolCallStatus.Executing; + case 'success': + return ToolCallStatus.Success; + case 'cancelled': + return ToolCallStatus.Canceled; + case 'error': + return ToolCallStatus.Error; + case 'scheduled': + return ToolCallStatus.Pending; + default: { + const exhaustiveCheck: never = coreStatus; + console.warn(`Unknown core status encountered: ${exhaustiveCheck}`); + return ToolCallStatus.Error; + } + } +} + +/** + * Transforms `TrackedToolCall` objects into `HistoryItemToolGroup` objects for UI display. + */ +export function mapToDisplay( + toolOrTools: TrackedToolCall[] | TrackedToolCall, +): HistoryItemToolGroup { + const toolCalls = Array.isArray(toolOrTools) ? toolOrTools : [toolOrTools]; + + const toolDisplays = toolCalls.map( + (trackedCall): IndividualToolCallDisplay => { + let displayName = trackedCall.request.name; + let description = ''; + let renderOutputAsMarkdown = false; + + const currentToolInstance = + 'tool' in trackedCall && trackedCall.tool + ? (trackedCall as { tool: Tool }).tool + : undefined; + + if (currentToolInstance) { + displayName = currentToolInstance.displayName; + description = currentToolInstance.getDescription( + trackedCall.request.args, + ); + renderOutputAsMarkdown = currentToolInstance.isOutputMarkdown; + } + + if (trackedCall.status === 'error') { + description = ''; + } + + const baseDisplayProperties: Omit< + IndividualToolCallDisplay, + 'status' | 'resultDisplay' | 'confirmationDetails' + > = { + callId: trackedCall.request.callId, + name: displayName, + description, + renderOutputAsMarkdown, + }; + + switch (trackedCall.status) { + case 'success': + return { + ...baseDisplayProperties, + status: mapCoreStatusToDisplayStatus(trackedCall.status), + resultDisplay: trackedCall.response.resultDisplay, + confirmationDetails: undefined, + }; + case 'error': + return { + ...baseDisplayProperties, + name: currentToolInstance?.displayName ?? trackedCall.request.name, + status: mapCoreStatusToDisplayStatus(trackedCall.status), + resultDisplay: trackedCall.response.resultDisplay, + confirmationDetails: undefined, + }; + case 'cancelled': + return { + ...baseDisplayProperties, + status: mapCoreStatusToDisplayStatus(trackedCall.status), + resultDisplay: trackedCall.response.resultDisplay, + confirmationDetails: undefined, + }; + case 'awaiting_approval': + return { + ...baseDisplayProperties, + status: mapCoreStatusToDisplayStatus(trackedCall.status), + resultDisplay: undefined, + confirmationDetails: trackedCall.confirmationDetails, + }; + case 'executing': + return { + ...baseDisplayProperties, + status: mapCoreStatusToDisplayStatus(trackedCall.status), + resultDisplay: + (trackedCall as TrackedExecutingToolCall).liveOutput ?? undefined, + confirmationDetails: undefined, + }; + case 'validating': // Fallthrough + case 'scheduled': + return { + ...baseDisplayProperties, + status: mapCoreStatusToDisplayStatus(trackedCall.status), + resultDisplay: undefined, + confirmationDetails: undefined, + }; + default: { + const exhaustiveCheck: never = trackedCall; + return { + callId: (exhaustiveCheck as TrackedToolCall).request.callId, + name: 'Unknown Tool', + description: 'Encountered an unknown tool call state.', + status: ToolCallStatus.Error, + resultDisplay: 'Unknown tool call state', + confirmationDetails: undefined, + renderOutputAsMarkdown: false, + }; + } + } + }, + ); + + return { + type: 'tool_group', + tools: toolDisplays, + }; +} |
