diff options
Diffstat (limited to 'packages/cli/src/ui/hooks/useGeminiStream.ts')
| -rw-r--r-- | packages/cli/src/ui/hooks/useGeminiStream.ts | 367 |
1 files changed, 138 insertions, 229 deletions
diff --git a/packages/cli/src/ui/hooks/useGeminiStream.ts b/packages/cli/src/ui/hooks/useGeminiStream.ts index 75114f77..b7ed771e 100644 --- a/packages/cli/src/ui/hooks/useGeminiStream.ts +++ b/packages/cli/src/ui/hooks/useGeminiStream.ts @@ -8,7 +8,7 @@ import { useState, useRef, useCallback, useEffect } from 'react'; import { useInput } from 'ink'; import { GeminiClient, - GeminiEventType as ServerGeminiEventType, // Rename to avoid conflict + GeminiEventType as ServerGeminiEventType, getErrorMessage, isNodeError, Config, @@ -32,21 +32,16 @@ import { useSlashCommandProcessor } from './slashCommandProcessor.js'; import { useShellCommandProcessor } from './shellCommandProcessor.js'; import { handleAtCommand } from './atCommandProcessor.js'; import { findSafeSplitPoint } from '../utils/markdownUtilities.js'; +import { UseHistoryManagerReturn } from './useHistoryManager.js'; -const addHistoryItem = ( - setHistory: React.Dispatch<React.SetStateAction<HistoryItem[]>>, - itemData: Omit<HistoryItem, 'id'>, - id: number, -) => { - setHistory((prevHistory) => [ - ...prevHistory, - { ...itemData, id } as HistoryItem, - ]); -}; - -// Hook now accepts apiKey and model +/** + * Hook to manage the Gemini stream, handle user input, process commands, + * and interact with the Gemini API and history manager. + */ export const useGeminiStream = ( - setHistory: React.Dispatch<React.SetStateAction<HistoryItem[]>>, + addItem: UseHistoryManagerReturn['addItem'], + updateItem: UseHistoryManagerReturn['updateItem'], + clearItems: UseHistoryManagerReturn['clearItems'], refreshStatic: () => void, setShowHelp: React.Dispatch<React.SetStateAction<boolean>>, config: Config, @@ -61,99 +56,56 @@ export const useGeminiStream = ( const abortControllerRef = useRef<AbortController | null>(null); const chatSessionRef = useRef<Chat | null>(null); const geminiClientRef = useRef<GeminiClient | null>(null); - const messageIdCounterRef = useRef(0); const currentGeminiMessageIdRef = useRef<number | null>(null); - // ID Generation Callback - const getNextMessageId = useCallback((baseTimestamp: number): number => { - // Increment *before* adding to ensure uniqueness against the base timestamp - messageIdCounterRef.current += 1; - return baseTimestamp + messageIdCounterRef.current; - }, []); - - // Instantiate command processors const { handleSlashCommand, slashCommands } = useSlashCommandProcessor( - setHistory, + addItem, + clearItems, refreshStatic, setShowHelp, setDebugMessage, - getNextMessageId, openThemeDialog, ); const { handleShellCommand } = useShellCommandProcessor( - setHistory, + addItem, setStreamingState, setDebugMessage, - getNextMessageId, config, ); - // Initialize Client Effect - uses props now useEffect(() => { setInitError(null); if (!geminiClientRef.current) { try { geminiClientRef.current = new GeminiClient(config); } catch (error: unknown) { - setInitError( - `Failed to initialize client: ${getErrorMessage(error) || 'Unknown error'}`, - ); + const errorMsg = `Failed to initialize client: ${getErrorMessage(error) || 'Unknown error'}`; + setInitError(errorMsg); + addItem({ type: 'error', text: errorMsg }, Date.now()); } } - }, [config]); + }, [config, addItem]); - // Input Handling Effect (remains the same) - useInput((input, key) => { + useInput((_input, key) => { if (streamingState === StreamingState.Responding && key.escape) { abortControllerRef.current?.abort(); } }); - // Helper function to update Gemini message content const updateGeminiMessage = useCallback( (messageId: number, newContent: string) => { - setHistory((prevHistory) => - prevHistory.map((item) => - item.id === messageId && item.type === 'gemini' - ? { ...item, text: newContent } - : item, - ), - ); + updateItem(messageId, { text: newContent }); }, - [setHistory], + [updateItem], ); - // Helper function to update Gemini message content - const updateAndAddGeminiMessageContent = useCallback( - ( - messageId: number, - previousContent: string, - nextId: number, - nextContent: string, - ) => { - setHistory((prevHistory) => { - const beforeNextHistory = prevHistory.map((item) => - item.id === messageId ? { ...item, text: previousContent } : item, - ); - - return [ - ...beforeNextHistory, - { id: nextId, type: 'gemini_content', text: nextContent }, - ]; - }); - }, - [setHistory], - ); - - // Improved submit query function const submitQuery = useCallback( async (query: PartListUnion) => { if (streamingState === StreamingState.Responding) return; if (typeof query === 'string' && query.trim().length === 0) return; const userMessageTimestamp = Date.now(); - messageIdCounterRef.current = 0; // Reset counter for this new submission let queryToSendToGemini: PartListUnion | null = null; setShowHelp(false); @@ -162,50 +114,33 @@ export const useGeminiStream = ( const trimmedQuery = query.trim(); setDebugMessage(`User query: '${trimmedQuery}'`); - // 1. Check for Slash Commands (/) - if (handleSlashCommand(trimmedQuery)) { - return; - } - - // 2. Check for Shell Commands (! or $) - if (handleShellCommand(trimmedQuery)) { - return; - } + // Handle UI-only commands first + if (handleSlashCommand(trimmedQuery)) return; + if (handleShellCommand(trimmedQuery)) return; - // 3. Check for @ Commands using the utility function + // Handle @-commands (which might involve tool calls) if (isAtCommand(trimmedQuery)) { const atCommandResult = await handleAtCommand({ query: trimmedQuery, config, - setHistory, + addItem, + updateItem, setDebugMessage, - getNextMessageId, - userMessageTimestamp, + messageId: userMessageTimestamp, }); - - if (!atCommandResult.shouldProceed) { - return; // @ command handled it (e.g., error) or decided not to proceed - } + if (!atCommandResult.shouldProceed) return; queryToSendToGemini = atCommandResult.processedQuery; - // User message and tool UI were added by handleAtCommand } else { - // 4. It's a normal query for Gemini - addHistoryItem( - setHistory, - { type: 'user', text: trimmedQuery }, - userMessageTimestamp, - ); + // Normal query for Gemini + addItem({ type: 'user', text: trimmedQuery }, userMessageTimestamp); queryToSendToGemini = trimmedQuery; } } else { - // 5. It's a function response (PartListUnion that isn't a string) - // Tool call/response UI handles history. Always proceed. + // It's a function response (PartListUnion that isn't a string) queryToSendToGemini = query; } - // --- Proceed to Gemini API call --- if (queryToSendToGemini === null) { - // Should only happen if @ command failed and returned null query setDebugMessage( 'Query processing resulted in null, not sending to Gemini.', ); @@ -214,7 +149,9 @@ export const useGeminiStream = ( const client = geminiClientRef.current; if (!client) { - setInitError('Gemini client is not available.'); + const errorMsg = 'Gemini client is not available.'; + setInitError(errorMsg); + addItem({ type: 'error', text: errorMsg }, Date.now()); return; } @@ -222,7 +159,9 @@ export const useGeminiStream = ( try { chatSessionRef.current = await client.startChat(); } catch (err: unknown) { - setInitError(`Failed to start chat: ${getErrorMessage(err)}`); + const errorMsg = `Failed to start chat: ${getErrorMessage(err)}`; + setInitError(errorMsg); + addItem({ type: 'error', text: errorMsg }, Date.now()); setStreamingState(StreamingState.Idle); return; } @@ -231,51 +170,39 @@ export const useGeminiStream = ( setStreamingState(StreamingState.Responding); setInitError(null); const chat = chatSessionRef.current; - let currentToolGroupId: number | null = null; + let currentToolGroupMessageId: number | null = null; try { abortControllerRef.current = new AbortController(); const signal = abortControllerRef.current.signal; - // Use the determined query for the Gemini call const stream = client.sendMessageStream( chat, queryToSendToGemini, signal, ); - // Process the stream events from the server logic - let currentGeminiText = ''; // To accumulate message content + let currentGeminiText = ''; let hasInitialGeminiResponse = false; for await (const event of stream) { if (signal.aborted) break; if (event.type === ServerGeminiEventType.Content) { - // For content events, accumulate the text and update an existing message or create a new one currentGeminiText += event.value; - - // Reset group because we're now adding a user message to the history. If we didn't reset the - // group here then any subsequent tool calls would get grouped before this message resulting in - // a misordering of history. - currentToolGroupId = null; + currentToolGroupMessageId = null; // Reset group on new text content if (!hasInitialGeminiResponse) { - // Create a new Gemini message if this is the first content event hasInitialGeminiResponse = true; - const eventTimestamp = getNextMessageId(userMessageTimestamp); - currentGeminiMessageIdRef.current = eventTimestamp; - - addHistoryItem( - setHistory, + const eventId = addItem( { type: 'gemini', text: currentGeminiText }, - eventTimestamp, + userMessageTimestamp, ); + currentGeminiMessageIdRef.current = eventId; } else if (currentGeminiMessageIdRef.current !== null) { + // Split large messages for better rendering performance const splitPoint = findSafeSplitPoint(currentGeminiText); - if (splitPoint === currentGeminiText.length) { - // Update the existing message with accumulated content updateGeminiMessage( currentGeminiMessageIdRef.current, currentGeminiText, @@ -291,40 +218,33 @@ export const useGeminiStream = ( // broken up so that there are more "statically" rendered. const originalMessageRef = currentGeminiMessageIdRef.current; const beforeText = currentGeminiText.substring(0, splitPoint); - - currentGeminiMessageIdRef.current = - getNextMessageId(userMessageTimestamp); const afterText = currentGeminiText.substring(splitPoint); - currentGeminiText = afterText; - updateAndAddGeminiMessageContent( - originalMessageRef, - beforeText, - currentGeminiMessageIdRef.current, - afterText, + currentGeminiText = afterText; // Continue accumulating from split point + updateItem(originalMessageRef, { text: beforeText }); + const nextId = addItem( + { type: 'gemini_content', text: afterText }, + userMessageTimestamp, ); + currentGeminiMessageIdRef.current = nextId; } } } else if (event.type === ServerGeminiEventType.ToolCallRequest) { - // Reset the Gemini message tracking for the next response currentGeminiText = ''; hasInitialGeminiResponse = false; currentGeminiMessageIdRef.current = null; const { callId, name, args } = event.value; - - const cliTool = toolRegistry.getTool(name); // Get the full CLI tool + const cliTool = toolRegistry.getTool(name); if (!cliTool) { console.error(`CLI Tool "${name}" not found!`); continue; } - if (currentToolGroupId === null) { - currentToolGroupId = getNextMessageId(userMessageTimestamp); - // Add explicit cast to Omit<HistoryItem, 'id'> - addHistoryItem( - setHistory, + // Create a new tool group if needed + if (currentToolGroupMessageId === null) { + currentToolGroupMessageId = addItem( { type: 'tool_group', tools: [] } as Omit<HistoryItem, 'id'>, - currentToolGroupId, + userMessageTimestamp, ); } @@ -335,7 +255,6 @@ export const useGeminiStream = ( description = `Error: Unable to get description: ${getErrorMessage(e)}`; } - // Create the UI display object matching IndividualToolCallDisplay const toolCallDisplay: IndividualToolCallDisplay = { callId, name: cliTool.displayName, @@ -345,25 +264,27 @@ export const useGeminiStream = ( confirmationDetails: undefined, }; - // Add pending tool call to the UI history group - setHistory((prevHistory) => - prevHistory.map((item) => { - if ( - item.id === currentToolGroupId && - item.type === 'tool_group' - ) { - // Ensure item.tools exists and is an array before spreading - const currentTools = Array.isArray(item.tools) - ? item.tools - : []; + // Add the pending tool call to the current group + if (currentToolGroupMessageId !== null) { + updateItem( + currentToolGroupMessageId, + ( + currentItem: HistoryItem, + ): Partial<Omit<HistoryItem, 'id'>> => { + if (currentItem?.type !== 'tool_group') { + console.error( + `Attempted to update non-tool-group item ${currentItem?.id} as tool group.`, + ); + return currentItem as Partial<Omit<HistoryItem, 'id'>>; + } + const currentTools = currentItem.tools; return { - ...item, - tools: [...currentTools, toolCallDisplay], // Add the complete display object - }; - } - return item; - }), - ); + ...currentItem, + tools: [...currentTools, toolCallDisplay], + } as Partial<Omit<HistoryItem, 'id'>>; + }, + ); + } } else if (event.type === ServerGeminiEventType.ToolCallResponse) { const status = event.value.error ? ToolCallStatus.Error @@ -378,21 +299,20 @@ export const useGeminiStream = ( confirmationDetails, ); setStreamingState(StreamingState.WaitingForConfirmation); - return; + return; // Wait for user confirmation } - } + } // End stream loop setStreamingState(StreamingState.Idle); } catch (error: unknown) { if (!isNodeError(error) || error.name !== 'AbortError') { console.error('Error processing stream or executing tool:', error); - addHistoryItem( - setHistory, + addItem( { type: 'error', - text: `[Error: ${getErrorMessage(error)}]`, + text: `[Stream Error: ${getErrorMessage(error)}]`, }, - getNextMessageId(userMessageTimestamp), + userMessageTimestamp, ); } setStreamingState(StreamingState.Idle); @@ -400,28 +320,35 @@ export const useGeminiStream = ( abortControllerRef.current = null; } + // --- Helper functions for updating tool UI --- + function updateConfirmingFunctionStatusUI( callId: string, confirmationDetails: ToolCallConfirmationDetails | undefined, ) { - setHistory((prevHistory) => - prevHistory.map((item) => { - if (item.id === currentToolGroupId && item.type === 'tool_group') { - return { - ...item, - tools: item.tools.map((tool) => - tool.callId === callId - ? { - ...tool, - status: ToolCallStatus.Confirming, - confirmationDetails, - } - : tool, - ), - }; + if (currentToolGroupMessageId === null) return; + updateItem( + currentToolGroupMessageId, + (currentItem: HistoryItem): Partial<Omit<HistoryItem, 'id'>> => { + if (currentItem?.type !== 'tool_group') { + console.error( + `Attempted to update non-tool-group item ${currentItem?.id} status.`, + ); + return currentItem as Partial<Omit<HistoryItem, 'id'>>; } - return item; - }), + return { + ...currentItem, + tools: (currentItem.tools || []).map((tool) => + tool.callId === callId + ? { + ...tool, + status: ToolCallStatus.Confirming, + confirmationDetails, + } + : tool, + ), + } as Partial<Omit<HistoryItem, 'id'>>; + }, ); } @@ -429,29 +356,35 @@ export const useGeminiStream = ( toolResponse: ToolCallResponseInfo, status: ToolCallStatus, ) { - setHistory((prevHistory) => - prevHistory.map((item) => { - if (item.id === currentToolGroupId && item.type === 'tool_group') { - return { - ...item, - tools: item.tools.map((tool) => { - if (tool.callId === toolResponse.callId) { - return { - ...tool, - status, - resultDisplay: toolResponse.resultDisplay, - }; - } else { - return tool; - } - }), - }; + if (currentToolGroupMessageId === null) return; + updateItem( + currentToolGroupMessageId, + (currentItem: HistoryItem): Partial<Omit<HistoryItem, 'id'>> => { + if (currentItem?.type !== 'tool_group') { + console.error( + `Attempted to update non-tool-group item ${currentItem?.id} response.`, + ); + return currentItem as Partial<Omit<HistoryItem, 'id'>>; } - return item; - }), + return { + ...currentItem, + tools: (currentItem.tools || []).map((tool) => { + if (tool.callId === toolResponse.callId) { + return { + ...tool, + status, + resultDisplay: toolResponse.resultDisplay, + }; + } else { + return tool; + } + }), + } as Partial<Omit<HistoryItem, 'id'>>; + }, ); } + // Wires the server-side confirmation callback to UI updates and state changes function wireConfirmationSubmission( confirmationDetails: ServerToolCallConfirmationDetails, ): ToolCallConfirmationDetails { @@ -460,6 +393,7 @@ export const useGeminiStream = ( const resubmittingConfirm = async ( outcome: ToolConfirmationOutcome, ) => { + // Call the original server-side handler first originalConfirmationDetails.onConfirm(outcome); if (outcome === ToolConfirmationOutcome.Cancel) { @@ -480,41 +414,18 @@ export const useGeminiStream = ( response: { error: 'User rejected function call.' }, }, }; - const responseInfo: ToolCallResponseInfo = { callId: request.callId, responsePart: functionResponse, resultDisplay, - error: undefined, + error: new Error('User rejected function call.'), }; - + // Update UI to show cancellation/error updateFunctionResponseUI(responseInfo, ToolCallStatus.Error); setStreamingState(StreamingState.Idle); } else { - const tool = toolRegistry.getTool(request.name); - if (!tool) { - throw new Error( - `Tool "${request.name}" not found or is not registered.`, - ); - } - const result = await tool.execute(request.args); - const functionResponse: Part = { - functionResponse: { - name: request.name, - id: request.callId, - response: { output: result.llmContent }, - }, - }; - - const responseInfo: ToolCallResponseInfo = { - callId: request.callId, - responsePart: functionResponse, - resultDisplay: result.returnDisplay, - error: undefined, - }; - updateFunctionResponseUI(responseInfo, ToolCallStatus.Success); - setStreamingState(StreamingState.Idle); - await submitQuery(functionResponse); + // If accepted, set state back to Responding to wait for server execution/response + setStreamingState(StreamingState.Responding); } }; @@ -524,21 +435,19 @@ export const useGeminiStream = ( }; } }, - // Dependencies need careful review [ streamingState, - setHistory, config, - getNextMessageId, updateGeminiMessage, handleSlashCommand, handleShellCommand, - // handleAtCommand is implicitly included via its direct call - setDebugMessage, // Added dependency for handleAtCommand & passthrough - setStreamingState, // Added dependency for handlePassthroughCommand - updateAndAddGeminiMessageContent, + setDebugMessage, + setStreamingState, + addItem, + updateItem, setShowHelp, toolRegistry, + setInitError, ], ); |
