diff options
Diffstat (limited to 'packages/cli/src/ui/hooks')
| -rw-r--r-- | packages/cli/src/ui/hooks/useGeminiStream.ts | 222 |
1 files changed, 70 insertions, 152 deletions
diff --git a/packages/cli/src/ui/hooks/useGeminiStream.ts b/packages/cli/src/ui/hooks/useGeminiStream.ts index 21a9f508..585554ee 100644 --- a/packages/cli/src/ui/hooks/useGeminiStream.ts +++ b/packages/cli/src/ui/hooks/useGeminiStream.ts @@ -15,17 +15,16 @@ import { isNodeError, ToolResult, Config, + ToolCallConfirmationDetails, + ToolCallResponseInfo, } from '@gemini-code/server'; import type { Chat, PartListUnion, FunctionDeclaration } from '@google/genai'; -// Import CLI types import { HistoryItem, IndividualToolCallDisplay, ToolCallStatus, } from '../types.js'; -import { Tool } from '../../tools/tools.js'; // CLI Tool definition import { StreamingState } from '../../core/gemini-stream.js'; -// Import CLI tool registry import { toolRegistry } from '../../tools/tool-registry.js'; const addHistoryItem = ( @@ -112,7 +111,7 @@ export const useGeminiStream = ( // This just clears the *UI* history, not the model history. // TODO: add a slash command for that. setDebugMessage('Clearing terminal.'); - setHistory((prevHistory) => []); + setHistory((_) => []); return; } else if (config.getPassthroughCommands().includes(maybeCommand)) { // Execute and capture output @@ -188,14 +187,7 @@ export const useGeminiStream = ( const signal = abortControllerRef.current.signal; // Get ServerTool descriptions for the server call - const serverTools: ServerTool[] = toolRegistry - .getAllTools() - .map((cliTool: Tool) => ({ - name: cliTool.name, - schema: cliTool.schema, - execute: (args: Record<string, unknown>) => - cliTool.execute(args as ToolArgs), // Pass execution - })); + const serverTools: ServerTool[] = toolRegistry.getAllTools(); const stream = client.sendMessageStream( chat, @@ -257,11 +249,18 @@ export const useGeminiStream = ( ); } + let description: string; + try { + description = cliTool.getDescription(args); + } catch (e) { + description = `Error: Unable to get description: ${getErrorMessage(e)}`; + } + // Create the UI display object matching IndividualToolCallDisplay const toolCallDisplay: IndividualToolCallDisplay = { callId, name, - description: cliTool.getDescription(args as ToolArgs), + description, status: ToolCallStatus.Pending, resultDisplay: undefined, confirmationDetails: undefined, @@ -286,143 +285,35 @@ export const useGeminiStream = ( return item; }), ); - - // --- Tool Execution & Confirmation Logic --- - const confirmationDetails = await cliTool.shouldConfirmExecute( - args as ToolArgs, + } else if (event.type === ServerGeminiEventType.ToolCallResponse) { + updateFunctionResponseUI(event.value); + } else if ( + event.type === ServerGeminiEventType.ToolCallConfirmation + ) { + setHistory((prevHistory) => + prevHistory.map((item) => { + if ( + item.id === currentToolGroupId && + item.type === 'tool_group' + ) { + return { + ...item, + tools: item.tools.map((tool) => + tool.callId === event.value.request.callId + ? { + ...tool, + status: ToolCallStatus.Confirming, + confirmationDetails: event.value.details, + } + : tool, + ), + }; + } + return item; + }), ); - - if (confirmationDetails) { - setHistory((prevHistory) => - prevHistory.map((item) => { - if ( - item.id === currentToolGroupId && - item.type === 'tool_group' - ) { - return { - ...item, - tools: item.tools.map((tool) => - tool.callId === callId - ? { - ...tool, - status: ToolCallStatus.Confirming, - confirmationDetails, - } - : tool, - ), - }; - } - return item; - }), - ); - setStreamingState(StreamingState.WaitingForConfirmation); - return; - } - - try { - setHistory((prevHistory) => - prevHistory.map((item) => { - if ( - item.id === currentToolGroupId && - item.type === 'tool_group' - ) { - return { - ...item, - tools: item.tools.map((tool) => - tool.callId === callId - ? { - ...tool, - status: - tool.status === ToolCallStatus.Error - ? ToolCallStatus.Error - : ToolCallStatus.Invoked, - } - : tool, - ), - }; - } - return item; - }), - ); - - const result: ToolResult = await cliTool.execute( - args as ToolArgs, - ); - const resultPart = { - functionResponse: { - name, - id: callId, - response: { output: result.llmContent }, - }, - }; - - setHistory((prevHistory) => - prevHistory.map((item) => { - if ( - item.id === currentToolGroupId && - item.type === 'tool_group' - ) { - return { - ...item, - tools: item.tools.map((tool) => - tool.callId === callId - ? { - ...tool, - status: - tool.status === ToolCallStatus.Error - ? ToolCallStatus.Error - : ToolCallStatus.Success, - resultDisplay: result.returnDisplay, - } - : tool, - ), - }; - } - return item; - }), - ); - - // Execute the function and continue the stream - await submitQuery(resultPart); - return; - } catch (execError: unknown) { - const error = new Error( - `Tool execution failed: ${execError instanceof Error ? execError.message : String(execError)}`, - ); - const errorPart = { - functionResponse: { - name, - id: callId, - response: { - error: `Tool execution failed: ${error.message}`, - }, - }, - }; - setHistory((prevHistory) => - prevHistory.map((item) => { - if ( - item.id === currentToolGroupId && - item.type === 'tool_group' - ) { - return { - ...item, - tools: item.tools.map((tool) => - tool.callId === callId - ? { - ...tool, - status: ToolCallStatus.Error, - resultDisplay: `Error: ${error.message}`, - } - : tool, - ), - }; - } - return item; - }), - ); - await submitQuery(errorPart); - return; - } + setStreamingState(StreamingState.WaitingForConfirmation); + return; } } } catch (error: unknown) { @@ -445,6 +336,33 @@ export const useGeminiStream = ( setStreamingState(StreamingState.Idle); } } + + function updateFunctionResponseUI(toolResponse: ToolCallResponseInfo) { + setHistory((prevHistory) => + prevHistory.map((item) => { + if (item.id === currentToolGroupId && item.type === 'tool_group') { + return { + ...item, + tools: item.tools.map((tool) => { + if (tool.callId === toolResponse.callId) { + return { + ...tool, + // TODO: Do we surface the error here? + status: toolResponse.error + ? ToolCallStatus.Error + : ToolCallStatus.Success, + resultDisplay: toolResponse.resultDisplay, + }; + } else { + return tool; + } + }), + }; + } + return item; + }), + ); + } }, // Dependencies need careful review - including updateGeminiMessage [ @@ -464,8 +382,8 @@ export const useGeminiStream = ( interface ServerTool { name: string; schema: FunctionDeclaration; + shouldConfirmExecute( + params: Record<string, unknown>, + ): Promise<ToolCallConfirmationDetails | false>; execute(params: Record<string, unknown>): Promise<ToolResult>; } - -// Define a more specific type for tool arguments to replace 'any' -type ToolArgs = Record<string, unknown>; |
