diff options
| author | Evan Senter <[email protected]> | 2025-04-19 19:45:42 +0100 |
|---|---|---|
| committer | GitHub <[email protected]> | 2025-04-19 19:45:42 +0100 |
| commit | 3fce6cea27d3e6129d6c06e528b62e1b11bf7094 (patch) | |
| tree | 244b8e9ab94f902d65d4bda8739a6538e377ed17 /packages/cli/src/ui/hooks/useGeminiStream.ts | |
| parent | 0c9e1ef61be7db53e6e73b7208b649cd8cbed6c3 (diff) | |
Starting to modularize into separate cli / server packages. (#55)
* Starting to move a lot of code into packages/server
* More of the massive refactor, builds and runs, some issues though.
* Fixing outstanding issue with double messages.
* Fixing a minor UI issue.
* Fixing the build post-merge.
* Running formatting.
* Addressing comments.
Diffstat (limited to 'packages/cli/src/ui/hooks/useGeminiStream.ts')
| -rw-r--r-- | packages/cli/src/ui/hooks/useGeminiStream.ts | 401 |
1 files changed, 314 insertions, 87 deletions
diff --git a/packages/cli/src/ui/hooks/useGeminiStream.ts b/packages/cli/src/ui/hooks/useGeminiStream.ts index 8cbb5f51..56203179 100644 --- a/packages/cli/src/ui/hooks/useGeminiStream.ts +++ b/packages/cli/src/ui/hooks/useGeminiStream.ts @@ -4,20 +4,30 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { exec } from 'child_process'; +import { exec as _exec } from 'child_process'; import { useState, useRef, useCallback, useEffect } from 'react'; import { useInput } from 'ink'; -import { GeminiClient } from '../../core/gemini-client.js'; -import { type Chat, type PartListUnion } from '@google/genai'; -import { HistoryItem } from '../types.js'; +// Import server-side client and types import { - processGeminiStream, - StreamingState, -} from '../../core/gemini-stream.js'; -import { globalConfig } from '../../config/config.js'; -import { getErrorMessage, isNodeError } from '../../utils/errors.js'; + GeminiClient, + GeminiEventType as ServerGeminiEventType, // Rename to avoid conflict + getErrorMessage, + isNodeError, + ToolResult, +} from '@gemini-code/server'; +import type { Chat, PartListUnion, FunctionDeclaration } from '@google/genai'; +// Import CLI types +import { + HistoryItem, + IndividualToolCallDisplay, + ToolCallStatus, +} from '../types.js'; +import { Tool } from '../../tools/tools.js'; // CLI Tool definition +import { StreamingState } from '../../core/gemini-stream.js'; +// Import CLI tool registry +import { toolRegistry } from '../../tools/tool-registry.js'; -const allowlistedCommands = ['ls']; // TODO: make this configurable +const _allowlistedCommands = ['ls']; // Prefix with underscore since it's unused const addHistoryItem = ( setHistory: React.Dispatch<React.SetStateAction<HistoryItem[]>>, @@ -30,32 +40,36 @@ const addHistoryItem = ( ]); }; +// Hook now accepts apiKey and model export const useGeminiStream = ( setHistory: React.Dispatch<React.SetStateAction<HistoryItem[]>>, + apiKey: string, + model: string, ) => { const [streamingState, setStreamingState] = useState<StreamingState>( StreamingState.Idle, ); const [initError, setInitError] = useState<string | null>(null); const abortControllerRef = useRef<AbortController | null>(null); - const currentToolGroupIdRef = useRef<number | null>(null); const chatSessionRef = useRef<Chat | null>(null); const geminiClientRef = useRef<GeminiClient | null>(null); const messageIdCounterRef = useRef(0); + const currentGeminiMessageIdRef = useRef<number | null>(null); - // Initialize Client Effect (remains the same) + // Initialize Client Effect - uses props now useEffect(() => { setInitError(null); if (!geminiClientRef.current) { try { - geminiClientRef.current = new GeminiClient(globalConfig); + geminiClientRef.current = new GeminiClient(apiKey, model); } catch (error: unknown) { setInitError( `Failed to initialize client: ${getErrorMessage(error) || 'Unknown error'}`, ); } } - }, []); + // Dependency array includes apiKey and model now + }, [apiKey, model]); // Input Handling Effect (remains the same) useInput((input, key) => { @@ -70,17 +84,25 @@ export const useGeminiStream = ( return baseTimestamp + messageIdCounterRef.current; }, []); - // Submit Query Callback (updated to call processGeminiStream) + // Helper function to update Gemini message content + const updateGeminiMessage = useCallback( + (messageId: number, newContent: string) => { + setHistory((prevHistory) => + prevHistory.map((item) => + item.id === messageId && item.type === 'gemini' + ? { ...item, text: newContent } + : item, + ), + ); + }, + [setHistory], + ); + + // Improved submit query function const submitQuery = useCallback( async (query: PartListUnion) => { - if (streamingState === StreamingState.Responding) { - // No-op if already going. - return; - } - - if (typeof query === 'string' && query.toString().trim().length === 0) { - return; - } + if (streamingState === StreamingState.Responding) return; + if (typeof query === 'string' && query.trim().length === 0) return; const userMessageTimestamp = Date.now(); const client = geminiClientRef.current; @@ -90,101 +112,306 @@ export const useGeminiStream = ( } if (!chatSessionRef.current) { - chatSessionRef.current = await client.startChat(); + try { + // Use getFunctionDeclarations for startChat + const toolSchemas = toolRegistry.getFunctionDeclarations(); + chatSessionRef.current = await client.startChat(toolSchemas); + } catch (err: unknown) { + setInitError(`Failed to start chat: ${getErrorMessage(err)}`); + setStreamingState(StreamingState.Idle); + return; + } } - // Reset state setStreamingState(StreamingState.Responding); setInitError(null); - currentToolGroupIdRef.current = null; - messageIdCounterRef.current = 0; + messageIdCounterRef.current = 0; // Reset counter for new submission const chat = chatSessionRef.current; + let currentToolGroupId: number | null = null; + + // For function responses, we don't need to add a user message + if (typeof query === 'string') { + // Only add user message for string queries, not for function responses + addHistoryItem( + setHistory, + { type: 'user', text: query }, + userMessageTimestamp, + ); + } try { - // Add user message - if (typeof query === 'string') { - const trimmedQuery = query.toString(); - addHistoryItem( - setHistory, - { type: 'user', text: trimmedQuery }, - userMessageTimestamp, - ); + abortControllerRef.current = new AbortController(); + const signal = abortControllerRef.current.signal; + + // Get ServerTool descriptions for the server call + const serverTools: ServerTool[] = toolRegistry + .getAllTools() + .map((cliTool: Tool) => ({ + name: cliTool.name, + schema: cliTool.schema, + execute: (args: Record<string, unknown>) => + cliTool.execute(args as ToolArgs), // Pass execution + })); + + const stream = client.sendMessageStream( + chat, + query, + serverTools, + signal, + ); + + // Process the stream events from the server logic + let currentGeminiText = ''; // To accumulate message content + let hasInitialGeminiResponse = false; + + for await (const event of stream) { + if (signal.aborted) break; + + if (event.type === ServerGeminiEventType.Content) { + // For content events, accumulate the text and update an existing message or create a new one + currentGeminiText += event.value; + + if (!hasInitialGeminiResponse) { + // Create a new Gemini message if this is the first content event + hasInitialGeminiResponse = true; + const eventTimestamp = getNextMessageId(userMessageTimestamp); + currentGeminiMessageIdRef.current = eventTimestamp; - const maybeCommand = trimmedQuery.split(/\s+/)[0]; - if (allowlistedCommands.includes(maybeCommand)) { - exec(trimmedQuery, (error, stdout) => { - const timestamp = getNextMessageId(userMessageTimestamp); - // TODO: handle stderr, error addHistoryItem( setHistory, - { type: 'info', text: stdout }, - timestamp, + { type: 'gemini', text: currentGeminiText }, + eventTimestamp, ); - }); - return; - } - } else if ( - // HACK to detect errored function responses. - typeof query === 'object' && - query !== null && - !Array.isArray(query) && // Ensure it's a single Part object - 'functionResponse' in query && // Check if it's a function response Part - query.functionResponse?.response && // Check if response object exists - 'error' in query.functionResponse.response // Check specifically for the 'error' key - ) { - const history = chat.getHistory(); - history.push({ role: 'user', parts: [query] }); - return; - } + } else if (currentGeminiMessageIdRef.current !== null) { + // Update the existing message with accumulated content + updateGeminiMessage( + currentGeminiMessageIdRef.current, + currentGeminiText, + ); + } + } else if (event.type === ServerGeminiEventType.ToolCallRequest) { + // Reset the Gemini message tracking for the next response + currentGeminiText = ''; + hasInitialGeminiResponse = false; + currentGeminiMessageIdRef.current = null; - // Prepare for streaming - abortControllerRef.current = new AbortController(); - const signal = abortControllerRef.current.signal; + const { callId, name, args } = event.value; + + const cliTool = toolRegistry.getTool(name); // Get the full CLI tool + if (!cliTool) { + console.error(`CLI Tool "${name}" not found!`); + continue; + } - // --- Delegate to Stream Processor --- + if (currentToolGroupId === null) { + currentToolGroupId = getNextMessageId(userMessageTimestamp); + // Add explicit cast to Omit<HistoryItem, 'id'> + addHistoryItem( + setHistory, + { type: 'tool_group', tools: [] } as Omit<HistoryItem, 'id'>, + currentToolGroupId, + ); + } - const stream = client.sendMessageStream(chat, query, signal); + // Create the UI display object matching IndividualToolCallDisplay + const toolCallDisplay: IndividualToolCallDisplay = { + callId, + name, + description: cliTool.getDescription(args as ToolArgs), + status: ToolCallStatus.Pending, + resultDisplay: undefined, + confirmationDetails: undefined, + }; - const addHistoryItemFromStream = ( - itemData: Omit<HistoryItem, 'id'>, - id: number, - ) => { - addHistoryItem(setHistory, itemData, id); - }; - const getStreamMessageId = () => getNextMessageId(userMessageTimestamp); + // Add pending tool call to the UI history group + setHistory((prevHistory) => + prevHistory.map((item) => { + if ( + item.id === currentToolGroupId && + item.type === 'tool_group' + ) { + // Ensure item.tools exists and is an array before spreading + const currentTools = Array.isArray(item.tools) + ? item.tools + : []; + return { + ...item, + tools: [...currentTools, toolCallDisplay], // Add the complete display object + }; + } + return item; + }), + ); - // Call the renamed processor function - await processGeminiStream({ - stream, - signal, - setHistory, - submitQuery, - getNextMessageId: getStreamMessageId, - addHistoryItem: addHistoryItemFromStream, - currentToolGroupIdRef, - }); + // --- Tool Execution & Confirmation Logic --- + const confirmationDetails = await cliTool.shouldConfirmExecute( + args as ToolArgs, + ); + + if (confirmationDetails) { + setHistory((prevHistory) => + prevHistory.map((item) => { + if ( + item.id === currentToolGroupId && + item.type === 'tool_group' + ) { + return { + ...item, + tools: item.tools.map((tool) => + tool.callId === callId + ? { + ...tool, + status: ToolCallStatus.Confirming, + confirmationDetails, + } + : tool, + ), + }; + } + return item; + }), + ); + setStreamingState(StreamingState.WaitingForConfirmation); + return; + } + + try { + setHistory((prevHistory) => + prevHistory.map((item) => { + if ( + item.id === currentToolGroupId && + item.type === 'tool_group' + ) { + return { + ...item, + tools: item.tools.map((tool) => + tool.callId === callId + ? { ...tool, status: ToolCallStatus.Invoked } + : tool, + ), + }; + } + return item; + }), + ); + + const result: ToolResult = await cliTool.execute( + args as ToolArgs, + ); + const resultPart = { + functionResponse: { + name, + id: callId, + response: { output: result.llmContent }, + }, + }; + + setHistory((prevHistory) => + prevHistory.map((item) => { + if ( + item.id === currentToolGroupId && + item.type === 'tool_group' + ) { + return { + ...item, + tools: item.tools.map((tool) => + tool.callId === callId + ? { + ...tool, + status: ToolCallStatus.Success, + resultDisplay: result.returnDisplay, + } + : tool, + ), + }; + } + return item; + }), + ); + + // Execute the function and continue the stream + await submitQuery(resultPart); + return; + } catch (execError: unknown) { + const error = new Error( + `Tool execution failed: ${execError instanceof Error ? execError.message : String(execError)}`, + ); + const errorPart = { + functionResponse: { + name, + id: callId, + response: { + error: `Tool execution failed: ${error.message}`, + }, + }, + }; + setHistory((prevHistory) => + prevHistory.map((item) => { + if ( + item.id === currentToolGroupId && + item.type === 'tool_group' + ) { + return { + ...item, + tools: item.tools.map((tool) => + tool.callId === callId + ? { + ...tool, + status: ToolCallStatus.Error, + resultDisplay: `Error: ${error.message}`, + } + : tool, + ), + }; + } + return item; + }), + ); + await submitQuery(errorPart); + return; + } + } + } } catch (error: unknown) { - // (Error handling for stream initiation remains the same) - console.error('Error initiating stream:', error); if (!isNodeError(error) || error.name !== 'AbortError') { - // Use historyUpdater's function potentially? Or keep addHistoryItem here? - // Keeping addHistoryItem here for direct errors from this scope. + console.error('Error processing stream or executing tool:', error); addHistoryItem( setHistory, { type: 'error', - text: `[Error starting stream: ${getErrorMessage(error)}]`, + text: `[Error: ${getErrorMessage(error)}]`, }, getNextMessageId(userMessageTimestamp), ); } } finally { abortControllerRef.current = null; - setStreamingState(StreamingState.Idle); + // Only set to Idle if not waiting for confirmation + if (streamingState !== StreamingState.WaitingForConfirmation) { + setStreamingState(StreamingState.Idle); + } } }, - [setStreamingState, setHistory, initError, getNextMessageId], + // Dependencies need careful review - including updateGeminiMessage + [ + streamingState, + setHistory, + apiKey, + model, + getNextMessageId, + updateGeminiMessage, + ], ); return { streamingState, submitQuery, initError }; }; + +// Define ServerTool interface here if not importing from server (circular dep issue?) +interface ServerTool { + name: string; + schema: FunctionDeclaration; + execute(params: Record<string, unknown>): Promise<ToolResult>; +} + +// Define a more specific type for tool arguments to replace 'any' +type ToolArgs = Record<string, unknown>; |
