diff options
Diffstat (limited to 'packages/cli/src/ui/hooks/useGeminiStream.ts')
| -rw-r--r-- | packages/cli/src/ui/hooks/useGeminiStream.ts | 441 |
1 files changed, 239 insertions, 202 deletions
diff --git a/packages/cli/src/ui/hooks/useGeminiStream.ts b/packages/cli/src/ui/hooks/useGeminiStream.ts index 45415f39..9dcb005b 100644 --- a/packages/cli/src/ui/hooks/useGeminiStream.ts +++ b/packages/cli/src/ui/hooks/useGeminiStream.ts @@ -42,7 +42,6 @@ import { useToolScheduler, mapToDisplay } from './useToolScheduler.js'; enum StreamProcessingStatus { Completed, - PausedForConfirmation, UserCancelled, Error, } @@ -103,6 +102,21 @@ export const useGeminiStream = ( config, ); + const streamingState = useMemo(() => { + if (toolCalls.some((t) => t.status === 'awaiting_approval')) { + return StreamingState.WaitingForConfirmation; + } + if ( + isResponding || + toolCalls.some( + (t) => t.status === 'executing' || t.status === 'scheduled', + ) + ) { + return StreamingState.Responding; + } + return StreamingState.Idle; + }, [isResponding, toolCalls]); + useEffect(() => { setInitError(null); if (!geminiClientRef.current) { @@ -123,85 +137,100 @@ export const useGeminiStream = ( } }); - const prepareQueryForGemini = async ( - query: PartListUnion, - userMessageTimestamp: number, - signal: AbortSignal, - ): Promise<{ queryToSend: PartListUnion | null; shouldProceed: boolean }> => { - if (typeof query === 'string' && query.trim().length === 0) { - return { queryToSend: null, shouldProceed: false }; - } + const prepareQueryForGemini = useCallback( + async ( + query: PartListUnion, + userMessageTimestamp: number, + signal: AbortSignal, + ): Promise<{ + queryToSend: PartListUnion | null; + shouldProceed: boolean; + }> => { + if (typeof query === 'string' && query.trim().length === 0) { + return { queryToSend: null, shouldProceed: false }; + } - let localQueryToSendToGemini: PartListUnion | null = null; + let localQueryToSendToGemini: PartListUnion | null = null; - if (typeof query === 'string') { - const trimmedQuery = query.trim(); - onDebugMessage(`User query: '${trimmedQuery}'`); - await logger?.logMessage(MessageSenderType.USER, trimmedQuery); + if (typeof query === 'string') { + const trimmedQuery = query.trim(); + onDebugMessage(`User query: '${trimmedQuery}'`); + await logger?.logMessage(MessageSenderType.USER, trimmedQuery); - // Handle UI-only commands first - const slashCommandResult = handleSlashCommand(trimmedQuery); - if (typeof slashCommandResult === 'boolean' && slashCommandResult) { - // Command was handled, and it doesn't require a tool call from here - return { queryToSend: null, shouldProceed: false }; - } else if ( - typeof slashCommandResult === 'object' && - slashCommandResult.shouldScheduleTool - ) { - // Slash command wants to schedule a tool call (e.g., /memory add) - const { toolName, toolArgs } = slashCommandResult; - if (toolName && toolArgs) { - const toolCallRequest: ToolCallRequestInfo = { - callId: `${toolName}-${Date.now()}-${Math.random().toString(16).slice(2)}`, - name: toolName, - args: toolArgs, - }; - schedule([toolCallRequest]); // schedule expects an array or single object + // Handle UI-only commands first + const slashCommandResult = handleSlashCommand(trimmedQuery); + if (typeof slashCommandResult === 'boolean' && slashCommandResult) { + // Command was handled, and it doesn't require a tool call from here + return { queryToSend: null, shouldProceed: false }; + } else if ( + typeof slashCommandResult === 'object' && + slashCommandResult.shouldScheduleTool + ) { + // Slash command wants to schedule a tool call (e.g., /memory add) + const { toolName, toolArgs } = slashCommandResult; + if (toolName && toolArgs) { + const toolCallRequest: ToolCallRequestInfo = { + callId: `${toolName}-${Date.now()}-${Math.random().toString(16).slice(2)}`, + name: toolName, + args: toolArgs, + }; + schedule([toolCallRequest]); // schedule expects an array or single object + } + return { queryToSend: null, shouldProceed: false }; // Handled by scheduling the tool } - return { queryToSend: null, shouldProceed: false }; // Handled by scheduling the tool - } - if (shellModeActive && handleShellCommand(trimmedQuery)) { - return { queryToSend: null, shouldProceed: false }; - } - - // Handle @-commands (which might involve tool calls) - if (isAtCommand(trimmedQuery)) { - const atCommandResult = await handleAtCommand({ - query: trimmedQuery, - config, - addItem, - onDebugMessage, - messageId: userMessageTimestamp, - signal, - }); - if (!atCommandResult.shouldProceed) { + if (shellModeActive && handleShellCommand(trimmedQuery)) { return { queryToSend: null, shouldProceed: false }; } - localQueryToSendToGemini = atCommandResult.processedQuery; + + // Handle @-commands (which might involve tool calls) + if (isAtCommand(trimmedQuery)) { + const atCommandResult = await handleAtCommand({ + query: trimmedQuery, + config, + addItem, + onDebugMessage, + messageId: userMessageTimestamp, + signal, + }); + if (!atCommandResult.shouldProceed) { + return { queryToSend: null, shouldProceed: false }; + } + localQueryToSendToGemini = atCommandResult.processedQuery; + } else { + // Normal query for Gemini + addItem( + { type: MessageType.USER, text: trimmedQuery }, + userMessageTimestamp, + ); + localQueryToSendToGemini = trimmedQuery; + } } else { - // Normal query for Gemini - addItem( - { type: MessageType.USER, text: trimmedQuery }, - userMessageTimestamp, - ); - localQueryToSendToGemini = trimmedQuery; + // It's a function response (PartListUnion that isn't a string) + localQueryToSendToGemini = query; } - } else { - // It's a function response (PartListUnion that isn't a string) - localQueryToSendToGemini = query; - } - if (localQueryToSendToGemini === null) { - onDebugMessage( - 'Query processing resulted in null, not sending to Gemini.', - ); - return { queryToSend: null, shouldProceed: false }; - } - return { queryToSend: localQueryToSendToGemini, shouldProceed: true }; - }; + if (localQueryToSendToGemini === null) { + onDebugMessage( + 'Query processing resulted in null, not sending to Gemini.', + ); + return { queryToSend: null, shouldProceed: false }; + } + return { queryToSend: localQueryToSendToGemini, shouldProceed: true }; + }, + [ + config, + addItem, + onDebugMessage, + handleShellCommand, + handleSlashCommand, + logger, + shellModeActive, + schedule, + ], + ); - const ensureChatSession = async (): Promise<{ + const ensureChatSession = useCallback(async (): Promise<{ client: GeminiClient | null; chat: Chat | null; }> => { @@ -224,7 +253,7 @@ export const useGeminiStream = ( } } return { client: currentClient, chat: chatSessionRef.current }; - }; + }, [addItem]); // --- UI Helper Functions (used by event handlers) --- const updateFunctionResponseUI = ( @@ -285,6 +314,7 @@ export const useGeminiStream = ( history.push({ role: 'model', parts: [functionResponse] }); } updateFunctionResponseUI(responseInfo, status); + if (pendingHistoryItemRef.current) { addItem(pendingHistoryItemRef.current, Date.now()); setPendingHistoryItem(null); @@ -293,138 +323,151 @@ export const useGeminiStream = ( } // --- Stream Event Handlers --- - const handleContentEvent = ( - eventValue: ContentEvent['value'], - currentGeminiMessageBuffer: string, - userMessageTimestamp: number, - ): string => { - let newGeminiMessageBuffer = currentGeminiMessageBuffer + eventValue; - if ( - pendingHistoryItemRef.current?.type !== 'gemini' && - pendingHistoryItemRef.current?.type !== 'gemini_content' - ) { + + const handleContentEvent = useCallback( + ( + eventValue: ContentEvent['value'], + currentGeminiMessageBuffer: string, + userMessageTimestamp: number, + ): string => { + let newGeminiMessageBuffer = currentGeminiMessageBuffer + eventValue; + if ( + pendingHistoryItemRef.current?.type !== 'gemini' && + pendingHistoryItemRef.current?.type !== 'gemini_content' + ) { + if (pendingHistoryItemRef.current) { + addItem(pendingHistoryItemRef.current, userMessageTimestamp); + } + setPendingHistoryItem({ type: 'gemini', text: '' }); + newGeminiMessageBuffer = eventValue; + } + // Split large messages for better rendering performance. Ideally, + // we should maximize the amount of output sent to <Static />. + const splitPoint = findLastSafeSplitPoint(newGeminiMessageBuffer); + if (splitPoint === newGeminiMessageBuffer.length) { + // Update the existing message with accumulated content + setPendingHistoryItem((item) => ({ + type: item?.type as 'gemini' | 'gemini_content', + text: newGeminiMessageBuffer, + })); + } else { + // This indicates that we need to split up this Gemini Message. + // Splitting a message is primarily a performance consideration. There is a + // <Static> component at the root of App.tsx which takes care of rendering + // content statically or dynamically. Everything but the last message is + // treated as static in order to prevent re-rendering an entire message history + // multiple times per-second (as streaming occurs). Prior to this change you'd + // see heavy flickering of the terminal. This ensures that larger messages get + // broken up so that there are more "statically" rendered. + const beforeText = newGeminiMessageBuffer.substring(0, splitPoint); + const afterText = newGeminiMessageBuffer.substring(splitPoint); + addItem( + { + type: pendingHistoryItemRef.current?.type as + | 'gemini' + | 'gemini_content', + text: beforeText, + }, + userMessageTimestamp, + ); + setPendingHistoryItem({ type: 'gemini_content', text: afterText }); + newGeminiMessageBuffer = afterText; + } + return newGeminiMessageBuffer; + }, + [addItem, pendingHistoryItemRef, setPendingHistoryItem], + ); + + const handleUserCancelledEvent = useCallback( + (userMessageTimestamp: number) => { if (pendingHistoryItemRef.current) { - addItem(pendingHistoryItemRef.current, userMessageTimestamp); + if (pendingHistoryItemRef.current.type === 'tool_group') { + const updatedTools = pendingHistoryItemRef.current.tools.map( + (tool) => + tool.status === ToolCallStatus.Pending || + tool.status === ToolCallStatus.Confirming || + tool.status === ToolCallStatus.Executing + ? { ...tool, status: ToolCallStatus.Canceled } + : tool, + ); + const pendingItem: HistoryItemToolGroup = { + ...pendingHistoryItemRef.current, + tools: updatedTools, + }; + addItem(pendingItem, userMessageTimestamp); + } else { + addItem(pendingHistoryItemRef.current, userMessageTimestamp); + } + setPendingHistoryItem(null); } - setPendingHistoryItem({ type: 'gemini', text: '' }); - newGeminiMessageBuffer = eventValue; - } - // Split large messages for better rendering performance. Ideally, - // we should maximize the amount of output sent to <Static />. - const splitPoint = findLastSafeSplitPoint(newGeminiMessageBuffer); - if (splitPoint === newGeminiMessageBuffer.length) { - // Update the existing message with accumulated content - setPendingHistoryItem((item) => ({ - type: item?.type as 'gemini' | 'gemini_content', - text: newGeminiMessageBuffer, - })); - } else { - // This indicates that we need to split up this Gemini Message. - // Splitting a message is primarily a performance consideration. There is a - // <Static> component at the root of App.tsx which takes care of rendering - // content statically or dynamically. Everything but the last message is - // treated as static in order to prevent re-rendering an entire message history - // multiple times per-second (as streaming occurs). Prior to this change you'd - // see heavy flickering of the terminal. This ensures that larger messages get - // broken up so that there are more "statically" rendered. - const beforeText = newGeminiMessageBuffer.substring(0, splitPoint); - const afterText = newGeminiMessageBuffer.substring(splitPoint); addItem( - { - type: pendingHistoryItemRef.current?.type as - | 'gemini' - | 'gemini_content', - text: beforeText, - }, + { type: MessageType.INFO, text: 'User cancelled the request.' }, userMessageTimestamp, ); - setPendingHistoryItem({ type: 'gemini_content', text: afterText }); - newGeminiMessageBuffer = afterText; - } - return newGeminiMessageBuffer; - }; + setIsResponding(false); + cancel(); + }, + [addItem, pendingHistoryItemRef, setPendingHistoryItem, cancel], + ); - const handleUserCancelledEvent = (userMessageTimestamp: number) => { - if (pendingHistoryItemRef.current) { - if (pendingHistoryItemRef.current.type === 'tool_group') { - const updatedTools = pendingHistoryItemRef.current.tools.map((tool) => - tool.status === ToolCallStatus.Pending || - tool.status === ToolCallStatus.Confirming || - tool.status === ToolCallStatus.Executing - ? { ...tool, status: ToolCallStatus.Canceled } - : tool, - ); - const pendingItem: HistoryItemToolGroup = { - ...pendingHistoryItemRef.current, - tools: updatedTools, - }; - addItem(pendingItem, userMessageTimestamp); - } else { + const handleErrorEvent = useCallback( + (eventValue: ErrorEvent['value'], userMessageTimestamp: number) => { + if (pendingHistoryItemRef.current) { addItem(pendingHistoryItemRef.current, userMessageTimestamp); + setPendingHistoryItem(null); } - setPendingHistoryItem(null); - } - addItem( - { type: MessageType.INFO, text: 'User cancelled the request.' }, - userMessageTimestamp, - ); - setIsResponding(false); - cancel(); - }; - - const handleErrorEvent = ( - eventValue: ErrorEvent['value'], - userMessageTimestamp: number, - ) => { - if (pendingHistoryItemRef.current) { - addItem(pendingHistoryItemRef.current, userMessageTimestamp); - setPendingHistoryItem(null); - } - addItem( - { type: MessageType.ERROR, text: `[API Error: ${eventValue.message}]` }, - userMessageTimestamp, - ); - }; + addItem( + { type: MessageType.ERROR, text: `[API Error: ${eventValue.message}]` }, + userMessageTimestamp, + ); + }, + [addItem, pendingHistoryItemRef, setPendingHistoryItem], + ); - const processGeminiStreamEvents = async ( - stream: AsyncIterable<GeminiEvent>, - userMessageTimestamp: number, - ): Promise<StreamProcessingStatus> => { - let geminiMessageBuffer = ''; - const toolCallRequests: ToolCallRequestInfo[] = []; - for await (const event of stream) { - if (event.type === ServerGeminiEventType.Content) { - geminiMessageBuffer = handleContentEvent( - event.value, - geminiMessageBuffer, - userMessageTimestamp, - ); - } else if (event.type === ServerGeminiEventType.ToolCallRequest) { - toolCallRequests.push(event.value); - } else if (event.type === ServerGeminiEventType.UserCancelled) { - handleUserCancelledEvent(userMessageTimestamp); - cancel(); - return StreamProcessingStatus.UserCancelled; - } else if (event.type === ServerGeminiEventType.Error) { - handleErrorEvent(event.value, userMessageTimestamp); - return StreamProcessingStatus.Error; + const processGeminiStreamEvents = useCallback( + async ( + stream: AsyncIterable<GeminiEvent>, + userMessageTimestamp: number, + ): Promise<StreamProcessingStatus> => { + let geminiMessageBuffer = ''; + const toolCallRequests: ToolCallRequestInfo[] = []; + for await (const event of stream) { + if (event.type === ServerGeminiEventType.Content) { + geminiMessageBuffer = handleContentEvent( + event.value, + geminiMessageBuffer, + userMessageTimestamp, + ); + } else if (event.type === ServerGeminiEventType.ToolCallRequest) { + toolCallRequests.push(event.value); + } else if (event.type === ServerGeminiEventType.UserCancelled) { + handleUserCancelledEvent(userMessageTimestamp); + cancel(); + return StreamProcessingStatus.UserCancelled; + } else if (event.type === ServerGeminiEventType.Error) { + handleErrorEvent(event.value, userMessageTimestamp); + return StreamProcessingStatus.Error; + } } - } - schedule(toolCallRequests); - return StreamProcessingStatus.Completed; - }; - - const streamingState: StreamingState = - isResponding || - toolCalls.some( - (t) => t.status === 'awaiting_approval' || t.status === 'executing', - ) - ? StreamingState.Responding - : StreamingState.Idle; + schedule(toolCallRequests); + return StreamProcessingStatus.Completed; + }, + [ + handleContentEvent, + handleUserCancelledEvent, + cancel, + handleErrorEvent, + schedule, + ], + ); const submitQuery = useCallback( async (query: PartListUnion) => { - if (isResponding) return; + if ( + streamingState === StreamingState.Responding || + streamingState === StreamingState.WaitingForConfirmation + ) + return; const userMessageTimestamp = Date.now(); setShowHelp(false); @@ -458,10 +501,7 @@ export const useGeminiStream = ( userMessageTimestamp, ); - if ( - processingStatus === StreamProcessingStatus.PausedForConfirmation || - processingStatus === StreamProcessingStatus.UserCancelled - ) { + if (processingStatus === StreamProcessingStatus.UserCancelled) { return; } @@ -484,19 +524,16 @@ export const useGeminiStream = ( setIsResponding(false); } }, - // eslint-disable-next-line react-hooks/exhaustive-deps [ - isResponding, setShowHelp, - handleSlashCommand, - shellModeActive, - handleShellCommand, - config, addItem, - onDebugMessage, - refreshStatic, setInitError, - logger, + ensureChatSession, + prepareQueryForGemini, + processGeminiStreamEvents, + setPendingHistoryItem, + pendingHistoryItemRef, + streamingState, ], ); |
