diff options
Diffstat (limited to 'packages/cli/src/ui/hooks/useGeminiStream.ts')
| -rw-r--r-- | packages/cli/src/ui/hooks/useGeminiStream.ts | 311 |
1 files changed, 80 insertions, 231 deletions
diff --git a/packages/cli/src/ui/hooks/useGeminiStream.ts b/packages/cli/src/ui/hooks/useGeminiStream.ts index de7980d5..324a4ffa 100644 --- a/packages/cli/src/ui/hooks/useGeminiStream.ts +++ b/packages/cli/src/ui/hooks/useGeminiStream.ts @@ -4,34 +4,28 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { useState, useRef, useCallback, useEffect } from 'react'; +import { useState, useRef, useCallback, useEffect, useMemo } from 'react'; import { useInput } from 'ink'; import { GeminiClient, GeminiEventType as ServerGeminiEventType, ServerGeminiStreamEvent as GeminiEvent, ServerGeminiContentEvent as ContentEvent, - ServerGeminiToolCallRequestEvent as ToolCallRequestEvent, - ServerGeminiToolCallResponseEvent as ToolCallResponseEvent, - ServerGeminiToolCallConfirmationEvent as ToolCallConfirmationEvent, ServerGeminiErrorEvent as ErrorEvent, getErrorMessage, isNodeError, Config, MessageSenderType, ServerToolCallConfirmationDetails, - ToolCallConfirmationDetails, ToolCallResponseInfo, - ToolConfirmationOutcome, ToolEditConfirmationDetails, ToolExecuteConfirmationDetails, ToolResultDisplay, - partListUnionToString, + ToolCallRequestInfo, } from '@gemini-code/server'; import { type Chat, type PartListUnion, type Part } from '@google/genai'; import { StreamingState, - IndividualToolCallDisplay, ToolCallStatus, HistoryItemWithoutId, HistoryItemToolGroup, @@ -44,6 +38,7 @@ import { findLastSafeSplitPoint } from '../utils/markdownUtilities.js'; import { useStateAndRef } from './useStateAndRef.js'; import { UseHistoryManagerReturn } from './useHistoryManager.js'; import { useLogger } from './useLogger.js'; +import { useToolScheduler, mapToDisplay } from './useToolScheduler.js'; enum StreamProcessingStatus { Completed, @@ -65,7 +60,6 @@ export const useGeminiStream = ( handleSlashCommand: (cmd: PartListUnion) => boolean, shellModeActive: boolean, ) => { - const toolRegistry = config.getToolRegistry(); const [initError, setInitError] = useState<string | null>(null); const abortControllerRef = useRef<AbortController | null>(null); const chatSessionRef = useRef<Chat | null>(null); @@ -74,6 +68,25 @@ export const useGeminiStream = ( const [pendingHistoryItemRef, setPendingHistoryItem] = useStateAndRef<HistoryItemWithoutId | null>(null); const logger = useLogger(); + const [toolCalls, schedule, cancel] = useToolScheduler((tools) => { + if (tools.length) { + addItem(mapToDisplay(tools), Date.now()); + submitQuery( + tools + .filter( + (t) => + t.status === 'error' || + t.status === 'cancelled' || + t.status === 'success', + ) + .map((t) => t.response.responsePart), + ); + } + }, config); + const pendingToolCalls = useMemo( + () => (toolCalls.length ? mapToDisplay(toolCalls) : undefined), + [toolCalls], + ); const onExec = useCallback(async (done: Promise<void>) => { setIsResponding(true); @@ -104,6 +117,7 @@ export const useGeminiStream = ( useInput((_input, key) => { if (streamingState !== StreamingState.Idle && key.escape) { abortControllerRef.current?.abort(); + cancel(); } }); @@ -215,157 +229,48 @@ export const useGeminiStream = ( ); }; - const updateConfirmingFunctionStatusUI = ( - callId: string, - confirmationDetails: ToolCallConfirmationDetails | undefined, - ) => { - setPendingHistoryItem((item) => - item?.type === 'tool_group' - ? { - ...item, - tools: item.tools.map((tool) => - tool.callId === callId - ? { - ...tool, - status: ToolCallStatus.Confirming, - confirmationDetails, - } - : tool, - ), - } - : item, - ); - }; - - const wireConfirmationSubmission = ( - confirmationDetails: ServerToolCallConfirmationDetails, - ): ToolCallConfirmationDetails => { - const originalConfirmationDetails = confirmationDetails.details; - const request = confirmationDetails.request; - const resubmittingConfirm = async (outcome: ToolConfirmationOutcome) => { - originalConfirmationDetails.onConfirm(outcome); - if (pendingHistoryItemRef?.current?.type === 'tool_group') { - setPendingHistoryItem((item) => - item?.type === 'tool_group' - ? { - ...item, - tools: item.tools.map((tool) => - tool.callId === request.callId - ? { - ...tool, - confirmationDetails: undefined, - status: ToolCallStatus.Executing, - } - : tool, - ), - } - : item, - ); - refreshStatic(); - } - - if (outcome === ToolConfirmationOutcome.Cancel) { - declineToolExecution( - 'User rejected function call.', - ToolCallStatus.Error, - request, - originalConfirmationDetails, - ); - } else { - const tool = toolRegistry.getTool(request.name); - if (!tool) { - throw new Error( - `Tool "${request.name}" not found or is not registered.`, - ); - } - try { - abortControllerRef.current = new AbortController(); - const result = await tool.execute( - request.args, - abortControllerRef.current.signal, - ); - if (abortControllerRef.current.signal.aborted) { - declineToolExecution( - partListUnionToString(result.llmContent), - ToolCallStatus.Canceled, - request, - originalConfirmationDetails, - ); - return; - } - - const functionResponse: Part = { - functionResponse: { - name: request.name, - id: request.callId, - response: { output: result.llmContent }, - }, - }; - const responseInfo: ToolCallResponseInfo = { - callId: request.callId, - responsePart: functionResponse, - resultDisplay: result.returnDisplay, - error: undefined, - }; - updateFunctionResponseUI(responseInfo, ToolCallStatus.Success); - if (pendingHistoryItemRef.current) { - addItem(pendingHistoryItemRef.current, Date.now()); - setPendingHistoryItem(null); - } - setIsResponding(false); - await submitQuery(functionResponse); // Recursive call - } finally { - if (streamingState !== StreamingState.WaitingForConfirmation) { - abortControllerRef.current = null; - } - } - } - }; - - // Extracted declineToolExecution to be part of wireConfirmationSubmission's closure - // or could be a standalone helper if more params are passed. - function declineToolExecution( - declineMessage: string, - status: ToolCallStatus, - request: ServerToolCallConfirmationDetails['request'], - originalDetails: ServerToolCallConfirmationDetails['details'], - ) { - let resultDisplay: ToolResultDisplay | undefined; - if ('fileDiff' in originalDetails) { - resultDisplay = { - fileDiff: (originalDetails as ToolEditConfirmationDetails).fileDiff, - fileName: (originalDetails as ToolEditConfirmationDetails).fileName, - }; - } else { - resultDisplay = `~~${(originalDetails as ToolExecuteConfirmationDetails).command}~~`; - } - const functionResponse: Part = { - functionResponse: { - id: request.callId, - name: request.name, - response: { error: declineMessage }, - }, - }; - const responseInfo: ToolCallResponseInfo = { - callId: request.callId, - responsePart: functionResponse, - resultDisplay, - error: new Error(declineMessage), + // Extracted declineToolExecution to be part of wireConfirmationSubmission's closure + // or could be a standalone helper if more params are passed. + // TODO: handle file diff result display stuff + function _declineToolExecution( + declineMessage: string, + status: ToolCallStatus, + request: ServerToolCallConfirmationDetails['request'], + originalDetails: ServerToolCallConfirmationDetails['details'], + ) { + let resultDisplay: ToolResultDisplay | undefined; + if ('fileDiff' in originalDetails) { + resultDisplay = { + fileDiff: (originalDetails as ToolEditConfirmationDetails).fileDiff, + fileName: (originalDetails as ToolEditConfirmationDetails).fileName, }; - const history = chatSessionRef.current?.getHistory(); - if (history) { - history.push({ role: 'model', parts: [functionResponse] }); - } - updateFunctionResponseUI(responseInfo, status); - if (pendingHistoryItemRef.current) { - addItem(pendingHistoryItemRef.current, Date.now()); - setPendingHistoryItem(null); - } - setIsResponding(false); + } else { + resultDisplay = `~~${(originalDetails as ToolExecuteConfirmationDetails).command}~~`; } - - return { ...originalConfirmationDetails, onConfirm: resubmittingConfirm }; - }; + const functionResponse: Part = { + functionResponse: { + id: request.callId, + name: request.name, + response: { error: declineMessage }, + }, + }; + const responseInfo: ToolCallResponseInfo = { + callId: request.callId, + responsePart: functionResponse, + resultDisplay, + error: new Error(declineMessage), + }; + const history = chatSessionRef.current?.getHistory(); + if (history) { + history.push({ role: 'model', parts: [functionResponse] }); + } + updateFunctionResponseUI(responseInfo, status); + if (pendingHistoryItemRef.current) { + addItem(pendingHistoryItemRef.current, Date.now()); + setPendingHistoryItem(null); + } + setIsResponding(false); + } // --- Stream Event Handlers --- const handleContentEvent = ( @@ -419,62 +324,6 @@ export const useGeminiStream = ( return newGeminiMessageBuffer; }; - const handleToolCallRequestEvent = ( - eventValue: ToolCallRequestEvent['value'], - userMessageTimestamp: number, - ) => { - const { callId, name, args } = eventValue; - const cliTool = toolRegistry.getTool(name); - if (!cliTool) { - console.error(`CLI Tool "${name}" not found!`); - return; // Skip this event if tool is not found - } - if (pendingHistoryItemRef.current?.type !== 'tool_group') { - if (pendingHistoryItemRef.current) { - addItem(pendingHistoryItemRef.current, userMessageTimestamp); - } - setPendingHistoryItem({ type: 'tool_group', tools: [] }); - } - let description: string; - try { - description = cliTool.getDescription(args); - } catch (e) { - description = `Error: Unable to get description: ${getErrorMessage(e)}`; - } - const toolCallDisplay: IndividualToolCallDisplay = { - callId, - name: cliTool.displayName, - description, - status: ToolCallStatus.Pending, - resultDisplay: undefined, - confirmationDetails: undefined, - }; - setPendingHistoryItem((pending) => - pending?.type === 'tool_group' - ? { ...pending, tools: [...pending.tools, toolCallDisplay] } - : null, - ); - }; - - const handleToolCallResponseEvent = ( - eventValue: ToolCallResponseEvent['value'], - ) => { - const status = eventValue.error - ? ToolCallStatus.Error - : ToolCallStatus.Success; - updateFunctionResponseUI(eventValue, status); - }; - - const handleToolCallConfirmationEvent = ( - eventValue: ToolCallConfirmationEvent['value'], - ) => { - const confirmationDetails = wireConfirmationSubmission(eventValue); - updateConfirmingFunctionStatusUI( - eventValue.request.callId, - confirmationDetails, - ); - }; - const handleUserCancelledEvent = (userMessageTimestamp: number) => { if (pendingHistoryItemRef.current) { if (pendingHistoryItemRef.current.type === 'tool_group') { @@ -500,6 +349,7 @@ export const useGeminiStream = ( userMessageTimestamp, ); setIsResponding(false); + cancel(); }; const handleErrorEvent = ( @@ -521,7 +371,7 @@ export const useGeminiStream = ( userMessageTimestamp: number, ): Promise<StreamProcessingStatus> => { let geminiMessageBuffer = ''; - + const toolCallRequests: ToolCallRequestInfo[] = []; for await (const event of stream) { if (event.type === ServerGeminiEventType.Content) { geminiMessageBuffer = handleContentEvent( @@ -530,12 +380,7 @@ export const useGeminiStream = ( userMessageTimestamp, ); } else if (event.type === ServerGeminiEventType.ToolCallRequest) { - handleToolCallRequestEvent(event.value, userMessageTimestamp); - } else if (event.type === ServerGeminiEventType.ToolCallResponse) { - handleToolCallResponseEvent(event.value); - } else if (event.type === ServerGeminiEventType.ToolCallConfirmation) { - handleToolCallConfirmationEvent(event.value); - return StreamProcessingStatus.PausedForConfirmation; + toolCallRequests.push(event.value); } else if (event.type === ServerGeminiEventType.UserCancelled) { handleUserCancelledEvent(userMessageTimestamp); return StreamProcessingStatus.UserCancelled; @@ -544,9 +389,18 @@ export const useGeminiStream = ( return StreamProcessingStatus.Error; } } + schedule(toolCallRequests); return StreamProcessingStatus.Completed; }; + const streamingState: StreamingState = isResponding + ? StreamingState.Responding + : pendingToolCalls?.tools.some( + (t) => t.status === ToolCallStatus.Confirming, + ) + ? StreamingState.WaitingForConfirmation + : StreamingState.Idle; + const submitQuery = useCallback( async (query: PartListUnion) => { if (isResponding) return; @@ -625,20 +479,15 @@ export const useGeminiStream = ( ], ); - const streamingState: StreamingState = isResponding - ? StreamingState.Responding - : pendingConfirmations(pendingHistoryItemRef.current) - ? StreamingState.WaitingForConfirmation - : StreamingState.Idle; + const pendingHistoryItems = [ + pendingHistoryItemRef.current, + pendingToolCalls, + ].filter((i) => i !== undefined && i !== null); return { streamingState, submitQuery, initError, - pendingHistoryItem: pendingHistoryItemRef.current, + pendingHistoryItems, }; }; - -const pendingConfirmations = (item: HistoryItemWithoutId | null): boolean => - item?.type === 'tool_group' && - item.tools.some((t) => t.status === ToolCallStatus.Confirming); |
