diff options
| author | Brandon Keiji <[email protected]> | 2025-05-14 22:14:15 +0000 |
|---|---|---|
| committer | GitHub <[email protected]> | 2025-05-14 15:14:15 -0700 |
| commit | 521708e294fe7eb16683d137e46aa36be4b9ddf5 (patch) | |
| tree | 0bbbd0a3d2cdbce9a00cd5f7bd84ace1f22b39ed /packages/cli/src/ui/hooks | |
| parent | 1245fe488510975b774816138e4597603851415f (diff) | |
refactor: break submitQuery into smaller functions (#350)
Diffstat (limited to 'packages/cli/src/ui/hooks')
| -rw-r--r-- | packages/cli/src/ui/hooks/useGeminiStream.ts | 881 |
1 files changed, 487 insertions, 394 deletions
diff --git a/packages/cli/src/ui/hooks/useGeminiStream.ts b/packages/cli/src/ui/hooks/useGeminiStream.ts index 2b18f0a1..035f3e85 100644 --- a/packages/cli/src/ui/hooks/useGeminiStream.ts +++ b/packages/cli/src/ui/hooks/useGeminiStream.ts @@ -9,6 +9,12 @@ import { useInput } from 'ink'; import { GeminiClient, GeminiEventType as ServerGeminiEventType, + ServerGeminiStreamEvent as GeminiEvent, + ServerGeminiContentEvent as ContentEvent, + ServerGeminiToolCallRequestEvent as ToolCallRequestEvent, + ServerGeminiToolCallResponseEvent as ToolCallResponseEvent, + ServerGeminiToolCallConfirmationEvent as ToolCallConfirmationEvent, + ServerGeminiErrorEvent as ErrorEvent, getErrorMessage, isNodeError, Config, @@ -26,6 +32,7 @@ import { IndividualToolCallDisplay, ToolCallStatus, HistoryItemWithoutId, + HistoryItemToolGroup, MessageType, } from '../types.js'; import { isAtCommand } from '../utils/commandUtils.js'; @@ -35,6 +42,17 @@ import { findLastSafeSplitPoint } from '../utils/markdownUtilities.js'; import { useStateAndRef } from './useStateAndRef.js'; import { UseHistoryManagerReturn } from './useHistoryManager.js'; +enum StreamProcessingStatus { + Completed, + PausedForConfirmation, + UserCancelled, + Error, +} + +/** + * Hook to manage the Gemini stream, handle user input, process commands, + * and interact with the Gemini API and history manager. + */ export const useGeminiStream = ( addItem: UseHistoryManagerReturn['addItem'], _clearItems: UseHistoryManagerReturn['clearItems'], @@ -82,240 +100,490 @@ export const useGeminiStream = ( } }); - const submitQuery = useCallback( - async (query: PartListUnion) => { - if (streamingState === StreamingState.Responding) return; - if (typeof query === 'string' && query.trim().length === 0) return; - - const userMessageTimestamp = Date.now(); - let queryToSendToGemini: PartListUnion | null = null; - - setShowHelp(false); + const prepareQueryForGemini = async ( + query: PartListUnion, + userMessageTimestamp: number, + signal: AbortSignal, + ): Promise<{ queryToSend: PartListUnion | null; shouldProceed: boolean }> => { + if (typeof query === 'string' && query.trim().length === 0) { + return { queryToSend: null, shouldProceed: false }; + } - abortControllerRef.current ??= new AbortController(); - const signal = abortControllerRef.current.signal; + let localQueryToSendToGemini: PartListUnion | null = null; - if (typeof query === 'string') { - const trimmedQuery = query.trim(); - onDebugMessage(`User query: '${trimmedQuery}'`); + if (typeof query === 'string') { + const trimmedQuery = query.trim(); + onDebugMessage(`User query: '${trimmedQuery}'`); - if (handleSlashCommand(trimmedQuery)) return; - if (handleShellCommand(trimmedQuery)) return; + // Handle UI-only commands first + if (handleSlashCommand(trimmedQuery)) { + return { queryToSend: null, shouldProceed: false }; + } + if (handleShellCommand(trimmedQuery)) { + return { queryToSend: null, shouldProceed: false }; + } - if (isAtCommand(trimmedQuery)) { - const atCommandResult = await handleAtCommand({ - query: trimmedQuery, - config, - addItem, - onDebugMessage, - messageId: userMessageTimestamp, - signal, - }); - if (!atCommandResult.shouldProceed) return; - queryToSendToGemini = atCommandResult.processedQuery; - } else { - addItem( - { type: MessageType.USER, text: trimmedQuery }, - userMessageTimestamp, - ); - queryToSendToGemini = trimmedQuery; + // Handle @-commands (which might involve tool calls) + if (isAtCommand(trimmedQuery)) { + const atCommandResult = await handleAtCommand({ + query: trimmedQuery, + config, + addItem, + onDebugMessage, + messageId: userMessageTimestamp, + signal, + }); + if (!atCommandResult.shouldProceed) { + return { queryToSend: null, shouldProceed: false }; } + localQueryToSendToGemini = atCommandResult.processedQuery; } else { - queryToSendToGemini = query; - } - - if (queryToSendToGemini === null) { - onDebugMessage( - 'Query processing resulted in null, not sending to Gemini.', + // Normal query for Gemini + addItem( + { type: MessageType.USER, text: trimmedQuery }, + userMessageTimestamp, ); - return; + localQueryToSendToGemini = trimmedQuery; } + } else { + // It's a function response (PartListUnion that isn't a string) + localQueryToSendToGemini = query; + } - const client = geminiClientRef.current; - if (!client) { - const errorMsg = 'Gemini client is not available.'; + if (localQueryToSendToGemini === null) { + onDebugMessage( + 'Query processing resulted in null, not sending to Gemini.', + ); + return { queryToSend: null, shouldProceed: false }; + } + return { queryToSend: localQueryToSendToGemini, shouldProceed: true }; + }; + + const ensureChatSession = async (): Promise<{ + client: GeminiClient | null; + chat: Chat | null; + }> => { + const currentClient = geminiClientRef.current; + if (!currentClient) { + const errorMsg = 'Gemini client is not available.'; + setInitError(errorMsg); + addItem({ type: MessageType.ERROR, text: errorMsg }, Date.now()); + return { client: null, chat: null }; + } + + if (!chatSessionRef.current) { + try { + chatSessionRef.current = await currentClient.startChat(); + } catch (err: unknown) { + const errorMsg = `Failed to start chat: ${getErrorMessage(err)}`; setInitError(errorMsg); addItem({ type: MessageType.ERROR, text: errorMsg }, Date.now()); - return; + setStreamingState(StreamingState.Idle); + return { client: currentClient, chat: null }; + } + } + return { client: currentClient, chat: chatSessionRef.current }; + }; + + // --- UI Helper Functions (used by event handlers) --- + const updateFunctionResponseUI = ( + toolResponse: ToolCallResponseInfo, + status: ToolCallStatus, + ) => { + setPendingHistoryItem((item) => + item?.type === 'tool_group' + ? { + ...item, + tools: item.tools.map((tool) => { + if (tool.callId === toolResponse.callId) { + return { + ...tool, + status, + resultDisplay: toolResponse.resultDisplay, + }; + } else { + return tool; + } + }), + } + : null, + ); + }; + + const updateConfirmingFunctionStatusUI = ( + callId: string, + confirmationDetails: ToolCallConfirmationDetails | undefined, + ) => { + if (pendingHistoryItemRef.current?.type !== 'tool_group') return; + setPendingHistoryItem((item) => + item?.type === 'tool_group' + ? { + ...item, + tools: item.tools.map((tool) => + tool.callId === callId + ? { + ...tool, + status: ToolCallStatus.Confirming, + confirmationDetails, + } + : tool, + ), + } + : null, + ); + }; + + // This function will be fully refactored in a later step + const wireConfirmationSubmission = ( + confirmationDetails: ServerToolCallConfirmationDetails, + ): ToolCallConfirmationDetails => { + const originalConfirmationDetails = confirmationDetails.details; + const request = confirmationDetails.request; + const resubmittingConfirm = async (outcome: ToolConfirmationOutcome) => { + originalConfirmationDetails.onConfirm(outcome); + if (pendingHistoryItemRef?.current?.type === 'tool_group') { + setPendingHistoryItem((item) => + item?.type === 'tool_group' + ? { + ...item, + tools: item.tools.map((tool) => + tool.callId === request.callId + ? { + ...tool, + confirmationDetails: undefined, + status: ToolCallStatus.Executing, + } + : tool, + ), + } + : item, + ); + refreshStatic(); } - if (!chatSessionRef.current) { + if (outcome === ToolConfirmationOutcome.Cancel) { + declineToolExecution( + 'User rejected function call.', + ToolCallStatus.Error, + request, + originalConfirmationDetails, + ); + } else { + const tool = toolRegistry.getTool(request.name); + if (!tool) { + throw new Error( + `Tool "${request.name}" not found or is not registered.`, + ); + } try { - chatSessionRef.current = await client.startChat(); - } catch (err: unknown) { - const errorMsg = `Failed to start chat: ${getErrorMessage(err)}`; - setInitError(errorMsg); - addItem({ type: MessageType.ERROR, text: errorMsg }, Date.now()); + abortControllerRef.current = new AbortController(); + const result = await tool.execute( + request.args, + abortControllerRef.current.signal, + ); + if (abortControllerRef.current.signal.aborted) { + declineToolExecution( + result.llmContent, + ToolCallStatus.Canceled, + request, + originalConfirmationDetails, + ); + return; + } + const functionResponse: Part = { + functionResponse: { + name: request.name, + id: request.callId, + response: { output: result.llmContent }, + }, + }; + const responseInfo: ToolCallResponseInfo = { + callId: request.callId, + responsePart: functionResponse, + resultDisplay: result.returnDisplay, + error: undefined, + }; + updateFunctionResponseUI(responseInfo, ToolCallStatus.Success); + if (pendingHistoryItemRef.current) { + addItem(pendingHistoryItemRef.current, Date.now()); + setPendingHistoryItem(null); + } setStreamingState(StreamingState.Idle); - return; + await submitQuery(functionResponse); // Recursive call + } finally { + if (streamingState !== StreamingState.WaitingForConfirmation) { + abortControllerRef.current = null; + } } } + }; - setStreamingState(StreamingState.Responding); - setInitError(null); - const chat = chatSessionRef.current; + // Extracted declineToolExecution to be part of wireConfirmationSubmission's closure + // or could be a standalone helper if more params are passed. + function declineToolExecution( + declineMessage: string, + status: ToolCallStatus, + request: ServerToolCallConfirmationDetails['request'], + originalDetails: ServerToolCallConfirmationDetails['details'], + ) { + let resultDisplay: ToolResultDisplay | undefined; + if ('fileDiff' in originalDetails) { + resultDisplay = { + fileDiff: (originalDetails as ToolEditConfirmationDetails).fileDiff, + }; + } else { + resultDisplay = `~~${(originalDetails as ToolExecuteConfirmationDetails).command}~~`; + } + const functionResponse: Part = { + functionResponse: { + id: request.callId, + name: request.name, + response: { error: declineMessage }, + }, + }; + const responseInfo: ToolCallResponseInfo = { + callId: request.callId, + responsePart: functionResponse, + resultDisplay, + error: new Error(declineMessage), + }; + const history = chatSessionRef.current?.getHistory(); + if (history) { + history.push({ role: 'model', parts: [functionResponse] }); + } + updateFunctionResponseUI(responseInfo, status); + if (pendingHistoryItemRef.current) { + addItem(pendingHistoryItemRef.current, Date.now()); + setPendingHistoryItem(null); + } + setStreamingState(StreamingState.Idle); + } - try { - const stream = client.sendMessageStream( - chat, - queryToSendToGemini, - signal, + return { ...originalConfirmationDetails, onConfirm: resubmittingConfirm }; + }; + + // --- Stream Event Handlers --- + const handleContentEvent = ( + eventValue: ContentEvent['value'], + currentGeminiMessageBuffer: string, + userMessageTimestamp: number, + ): string => { + let newGeminiMessageBuffer = currentGeminiMessageBuffer + eventValue; + if ( + pendingHistoryItemRef.current?.type !== 'gemini' && + pendingHistoryItemRef.current?.type !== 'gemini_content' + ) { + if (pendingHistoryItemRef.current) { + addItem(pendingHistoryItemRef.current, userMessageTimestamp); + } + setPendingHistoryItem({ type: 'gemini', text: '' }); + newGeminiMessageBuffer = eventValue; + } + // Split large messages for better rendering performance. Ideally, + // we should maximize the amount of output sent to <Static />. + const splitPoint = findLastSafeSplitPoint(newGeminiMessageBuffer); + if (splitPoint === newGeminiMessageBuffer.length) { + // Update the existing message with accumulated content + setPendingHistoryItem((item) => ({ + type: item?.type as 'gemini' | 'gemini_content', + text: newGeminiMessageBuffer, + })); + } else { + // This indicates that we need to split up this Gemini Message. + // Splitting a message is primarily a performance consideration. There is a + // <Static> component at the root of App.tsx which takes care of rendering + // content statically or dynamically. Everything but the last message is + // treated as static in order to prevent re-rendering an entire message history + // multiple times per-second (as streaming occurs). Prior to this change you'd + // see heavy flickering of the terminal. This ensures that larger messages get + // broken up so that there are more "statically" rendered. + const beforeText = newGeminiMessageBuffer.substring(0, splitPoint); + const afterText = newGeminiMessageBuffer.substring(splitPoint); + addItem( + { + type: pendingHistoryItemRef.current?.type as + | 'gemini' + | 'gemini_content', + text: beforeText, + }, + userMessageTimestamp, + ); + setPendingHistoryItem({ type: 'gemini_content', text: afterText }); + newGeminiMessageBuffer = afterText; + } + return newGeminiMessageBuffer; + }; + + const handleToolCallRequestEvent = ( + eventValue: ToolCallRequestEvent['value'], + userMessageTimestamp: number, + ) => { + const { callId, name, args } = eventValue; + const cliTool = toolRegistry.getTool(name); + if (!cliTool) { + console.error(`CLI Tool "${name}" not found!`); + return; // Skip this event if tool is not found + } + if (pendingHistoryItemRef.current?.type !== 'tool_group') { + if (pendingHistoryItemRef.current) { + addItem(pendingHistoryItemRef.current, userMessageTimestamp); + } + setPendingHistoryItem({ type: 'tool_group', tools: [] }); + } + let description: string; + try { + description = cliTool.getDescription(args); + } catch (e) { + description = `Error: Unable to get description: ${getErrorMessage(e)}`; + } + const toolCallDisplay: IndividualToolCallDisplay = { + callId, + name: cliTool.displayName, + description, + status: ToolCallStatus.Pending, + resultDisplay: undefined, + confirmationDetails: undefined, + }; + setPendingHistoryItem((pending) => + pending?.type === 'tool_group' + ? { ...pending, tools: [...pending.tools, toolCallDisplay] } + : null, + ); + }; + + const handleToolCallResponseEvent = ( + eventValue: ToolCallResponseEvent['value'], + ) => { + const status = eventValue.error + ? ToolCallStatus.Error + : ToolCallStatus.Success; + updateFunctionResponseUI(eventValue, status); + }; + + const handleToolCallConfirmationEvent = ( + eventValue: ToolCallConfirmationEvent['value'], + ) => { + const confirmationDetails = wireConfirmationSubmission(eventValue); + updateConfirmingFunctionStatusUI( + eventValue.request.callId, + confirmationDetails, + ); + setStreamingState(StreamingState.WaitingForConfirmation); + }; + + const handleUserCancelledEvent = (userMessageTimestamp: number) => { + if (pendingHistoryItemRef.current) { + if (pendingHistoryItemRef.current.type === 'tool_group') { + const updatedTools = pendingHistoryItemRef.current.tools.map((tool) => + tool.status === ToolCallStatus.Pending || + tool.status === ToolCallStatus.Confirming || + tool.status === ToolCallStatus.Executing + ? { ...tool, status: ToolCallStatus.Canceled } + : tool, ); + const pendingItem: HistoryItemToolGroup = { + ...pendingHistoryItemRef.current, + tools: updatedTools, + }; + addItem(pendingItem, userMessageTimestamp); + } else { + addItem(pendingHistoryItemRef.current, userMessageTimestamp); + } + setPendingHistoryItem(null); + } + addItem( + { type: MessageType.INFO, text: 'User cancelled the request.' }, + userMessageTimestamp, + ); + setStreamingState(StreamingState.Idle); + }; - let geminiMessageBuffer = ''; + const handleErrorEvent = ( + eventValue: ErrorEvent['value'], + userMessageTimestamp: number, + ) => { + if (pendingHistoryItemRef.current) { + addItem(pendingHistoryItemRef.current, userMessageTimestamp); + setPendingHistoryItem(null); + } + addItem( + { type: MessageType.ERROR, text: `[API Error: ${eventValue.message}]` }, + userMessageTimestamp, + ); + }; - for await (const event of stream) { - if (event.type === ServerGeminiEventType.Content) { - if ( - pendingHistoryItemRef.current?.type !== 'gemini' && - pendingHistoryItemRef.current?.type !== 'gemini_content' - ) { - if (pendingHistoryItemRef.current) { - addItem(pendingHistoryItemRef.current, userMessageTimestamp); - } - setPendingHistoryItem({ - type: 'gemini', - text: '', - }); - geminiMessageBuffer = ''; - } + const processGeminiStreamEvents = async ( + stream: AsyncIterable<GeminiEvent>, + userMessageTimestamp: number, + ): Promise<StreamProcessingStatus> => { + let geminiMessageBuffer = ''; - geminiMessageBuffer += event.value; + for await (const event of stream) { + if (event.type === ServerGeminiEventType.Content) { + geminiMessageBuffer = handleContentEvent( + event.value, + geminiMessageBuffer, + userMessageTimestamp, + ); + } else if (event.type === ServerGeminiEventType.ToolCallRequest) { + handleToolCallRequestEvent(event.value, userMessageTimestamp); + } else if (event.type === ServerGeminiEventType.ToolCallResponse) { + handleToolCallResponseEvent(event.value); + } else if (event.type === ServerGeminiEventType.ToolCallConfirmation) { + handleToolCallConfirmationEvent(event.value); + return StreamProcessingStatus.PausedForConfirmation; // Explicit return as this pauses the stream + } else if (event.type === ServerGeminiEventType.UserCancelled) { + handleUserCancelledEvent(userMessageTimestamp); + return StreamProcessingStatus.UserCancelled; + } else if (event.type === ServerGeminiEventType.Error) { + handleErrorEvent(event.value, userMessageTimestamp); + return StreamProcessingStatus.Error; + } + } + return StreamProcessingStatus.Completed; + }; - // Split large messages for better rendering performance. Ideally, - // we should maximize the amount of output sent to <Static />. - const splitPoint = findLastSafeSplitPoint(geminiMessageBuffer); - if (splitPoint === geminiMessageBuffer.length) { - // Update the existing message with accumulated content - setPendingHistoryItem((item) => ({ - type: item?.type as 'gemini' | 'gemini_content', - text: geminiMessageBuffer, - })); - } else { - // This indicates that we need to split up this Gemini Message. - // Splitting a message is primarily a performance consideration. There is a - // <Static> component at the root of App.tsx which takes care of rendering - // content statically or dynamically. Everything but the last message is - // treated as static in order to prevent re-rendering an entire message history - // multiple times per-second (as streaming occurs). Prior to this change you'd - // see heavy flickering of the terminal. This ensures that larger messages get - // broken up so that there are more "statically" rendered. - const beforeText = geminiMessageBuffer.substring(0, splitPoint); - const afterText = geminiMessageBuffer.substring(splitPoint); - geminiMessageBuffer = afterText; - addItem( - { - type: pendingHistoryItemRef.current?.type as - | 'gemini' - | 'gemini_content', - text: beforeText, - }, - userMessageTimestamp, - ); - setPendingHistoryItem({ - type: 'gemini_content', - text: afterText, - }); - } - } else if (event.type === ServerGeminiEventType.ToolCallRequest) { - const { callId, name, args } = event.value; - const cliTool = toolRegistry.getTool(name); - if (!cliTool) { - console.error(`CLI Tool "${name}" not found!`); - continue; - } + const submitQuery = useCallback( + async (query: PartListUnion) => { + if (streamingState === StreamingState.Responding) return; - if (pendingHistoryItemRef.current?.type !== 'tool_group') { - if (pendingHistoryItemRef.current) { - addItem(pendingHistoryItemRef.current, userMessageTimestamp); - } - setPendingHistoryItem({ - type: 'tool_group', - tools: [], - }); - } + const userMessageTimestamp = Date.now(); + setShowHelp(false); - let description: string; - try { - description = cliTool.getDescription(args); - } catch (e) { - description = `Error: Unable to get description: ${getErrorMessage(e)}`; - } + abortControllerRef.current ??= new AbortController(); + const signal = abortControllerRef.current.signal; - const toolCallDisplay: IndividualToolCallDisplay = { - callId, - name: cliTool.displayName, - description, - status: ToolCallStatus.Pending, - resultDisplay: undefined, - confirmationDetails: undefined, - }; + const { queryToSend, shouldProceed } = await prepareQueryForGemini( + query, + userMessageTimestamp, + signal, + ); - setPendingHistoryItem((pending) => - pending?.type === 'tool_group' - ? { - ...pending, - tools: [...pending.tools, toolCallDisplay], - } - : null, - ); - } else if (event.type === ServerGeminiEventType.ToolCallResponse) { - const status = event.value.error - ? ToolCallStatus.Error - : ToolCallStatus.Success; - updateFunctionResponseUI(event.value, status); - } else if ( - event.type === ServerGeminiEventType.ToolCallConfirmation - ) { - const confirmationDetails = wireConfirmationSubmission(event.value); - updateConfirmingFunctionStatusUI( - event.value.request.callId, - confirmationDetails, - ); - setStreamingState(StreamingState.WaitingForConfirmation); - return; - } else if (event.type === ServerGeminiEventType.UserCancelled) { - if (pendingHistoryItemRef.current) { - if (pendingHistoryItemRef.current.type === 'tool_group') { - const updatedTools = pendingHistoryItemRef.current.tools.map( - (tool) => { - if ( - tool.status === ToolCallStatus.Pending || - tool.status === ToolCallStatus.Confirming || - tool.status === ToolCallStatus.Executing - ) { - return { ...tool, status: ToolCallStatus.Canceled }; - } - return tool; - }, - ); - const pendingHistoryItem = pendingHistoryItemRef.current; - pendingHistoryItem.tools = updatedTools; - addItem(pendingHistoryItem, userMessageTimestamp); - } else { - addItem(pendingHistoryItemRef.current, userMessageTimestamp); - } - setPendingHistoryItem(null); - } - addItem( - { type: MessageType.INFO, text: 'User cancelled the request.' }, - userMessageTimestamp, - ); - setStreamingState(StreamingState.Idle); - return; - } else if (event.type === ServerGeminiEventType.Error) { - if (pendingHistoryItemRef.current) { - addItem(pendingHistoryItemRef.current, userMessageTimestamp); - setPendingHistoryItem(null); - } - addItem( - { - type: MessageType.ERROR, - text: `[API Error: ${event.value.message}]`, - }, - userMessageTimestamp, - ); - } + if (!shouldProceed || queryToSend === null) { + return; + } + + const { client, chat } = await ensureChatSession(); + + if (!client || !chat) { + return; + } + + setStreamingState(StreamingState.Responding); + setInitError(null); + + try { + const stream = client.sendMessageStream(chat, queryToSend, signal); + const processingStatus = await processGeminiStreamEvents( + stream, + userMessageTimestamp, + ); + + if ( + processingStatus === StreamProcessingStatus.PausedForConfirmation || + processingStatus === StreamProcessingStatus.UserCancelled + ) { + return; } if (pendingHistoryItemRef.current) { @@ -323,7 +591,12 @@ export const useGeminiStream = ( setPendingHistoryItem(null); } - setStreamingState(StreamingState.Idle); + if ( + processingStatus === StreamProcessingStatus.Completed || + processingStatus === StreamProcessingStatus.Error + ) { + setStreamingState(StreamingState.Idle); + } } catch (error: unknown) { if (!isNodeError(error) || error.name !== 'AbortError') { addItem( @@ -336,191 +609,12 @@ export const useGeminiStream = ( } setStreamingState(StreamingState.Idle); } finally { - abortControllerRef.current = null; - } - - function updateConfirmingFunctionStatusUI( - callId: string, - confirmationDetails: ToolCallConfirmationDetails | undefined, - ) { - if (pendingHistoryItemRef.current?.type !== 'tool_group') return; - setPendingHistoryItem((item) => - item?.type === 'tool_group' - ? { - ...item, - tools: item.tools.map((tool) => - tool.callId === callId - ? { - ...tool, - status: ToolCallStatus.Confirming, - confirmationDetails, - } - : tool, - ), - } - : null, - ); - } - - function updateFunctionResponseUI( - toolResponse: ToolCallResponseInfo, - status: ToolCallStatus, - ) { - setPendingHistoryItem((item) => - item?.type === 'tool_group' - ? { - ...item, - tools: item.tools.map((tool) => { - if (tool.callId === toolResponse.callId) { - return { - ...tool, - status, - resultDisplay: toolResponse.resultDisplay, - }; - } else { - return tool; - } - }), - } - : null, - ); - } - - function wireConfirmationSubmission( - confirmationDetails: ServerToolCallConfirmationDetails, - ): ToolCallConfirmationDetails { - const originalConfirmationDetails = confirmationDetails.details; - const request = confirmationDetails.request; - const resubmittingConfirm = async ( - outcome: ToolConfirmationOutcome, - ) => { - originalConfirmationDetails.onConfirm(outcome); - - if (pendingHistoryItemRef?.current?.type === 'tool_group') { - setPendingHistoryItem((item) => - item?.type === 'tool_group' - ? { - ...item, - tools: item.tools.map((tool) => - tool.callId === request.callId - ? { - ...tool, - confirmationDetails: undefined, - status: ToolCallStatus.Executing, - } - : tool, - ), - } - : item, - ); - refreshStatic(); - } - - if (outcome === ToolConfirmationOutcome.Cancel) { - declineToolExecution( - 'User rejected function call.', - ToolCallStatus.Error, - ); - } else { - const tool = toolRegistry.getTool(request.name); - if (!tool) { - throw new Error( - `Tool "${request.name}" not found or is not registered.`, - ); - } - - try { - abortControllerRef.current = new AbortController(); - const result = await tool.execute( - request.args, - abortControllerRef.current.signal, - ); - - if (abortControllerRef.current.signal.aborted) { - declineToolExecution( - result.llmContent, - ToolCallStatus.Canceled, - ); - return; - } - - const functionResponse: Part = { - functionResponse: { - name: request.name, - id: request.callId, - response: { output: result.llmContent }, - }, - }; - - const responseInfo: ToolCallResponseInfo = { - callId: request.callId, - responsePart: functionResponse, - resultDisplay: result.returnDisplay, - error: undefined, - }; - updateFunctionResponseUI(responseInfo, ToolCallStatus.Success); - if (pendingHistoryItemRef.current) { - addItem(pendingHistoryItemRef.current, Date.now()); - setPendingHistoryItem(null); - } - setStreamingState(StreamingState.Idle); - await submitQuery(functionResponse); - } finally { - abortControllerRef.current = null; - } - } - - function declineToolExecution( - declineMessage: string, - status: ToolCallStatus, - ) { - let resultDisplay: ToolResultDisplay | undefined; - if ('fileDiff' in originalConfirmationDetails) { - resultDisplay = { - fileDiff: ( - originalConfirmationDetails as ToolEditConfirmationDetails - ).fileDiff, - }; - } else { - resultDisplay = `~~${(originalConfirmationDetails as ToolExecuteConfirmationDetails).command}~~`; - } - const functionResponse: Part = { - functionResponse: { - id: request.callId, - name: request.name, - response: { error: declineMessage }, - }, - }; - const responseInfo: ToolCallResponseInfo = { - callId: request.callId, - responsePart: functionResponse, - resultDisplay, - error: new Error(declineMessage), - }; - - const history = chatSessionRef.current?.getHistory(); - if (history) { - history.push({ - role: 'model', - parts: [functionResponse], - }); - } - - updateFunctionResponseUI(responseInfo, status); - if (pendingHistoryItemRef.current) { - addItem(pendingHistoryItemRef.current, Date.now()); - setPendingHistoryItem(null); - } - setStreamingState(StreamingState.Idle); - } - }; - - return { - ...originalConfirmationDetails, - onConfirm: resubmittingConfirm, - }; + if (streamingState !== StreamingState.WaitingForConfirmation) { + abortControllerRef.current = null; + } } }, + // eslint-disable-next-line react-hooks/exhaustive-deps [ streamingState, setShowHelp, @@ -528,11 +622,10 @@ export const useGeminiStream = ( handleShellCommand, config, addItem, - pendingHistoryItemRef, - setPendingHistoryItem, - toolRegistry, - refreshStatic, onDebugMessage, + refreshStatic, + setInitError, + setStreamingState, ], ); |
