diff options
Diffstat (limited to 'packages/cli/src/ui/hooks/useGeminiStream.ts')
| -rw-r--r-- | packages/cli/src/ui/hooks/useGeminiStream.ts | 161 |
1 files changed, 128 insertions, 33 deletions
diff --git a/packages/cli/src/ui/hooks/useGeminiStream.ts b/packages/cli/src/ui/hooks/useGeminiStream.ts index 585554ee..62851019 100644 --- a/packages/cli/src/ui/hooks/useGeminiStream.ts +++ b/packages/cli/src/ui/hooks/useGeminiStream.ts @@ -17,8 +17,18 @@ import { Config, ToolCallConfirmationDetails, ToolCallResponseInfo, + ServerToolCallConfirmationDetails, + ToolConfirmationOutcome, + ToolResultDisplay, + ToolEditConfirmationDetails, + ToolExecuteConfirmationDetails, } from '@gemini-code/server'; -import type { Chat, PartListUnion, FunctionDeclaration } from '@google/genai'; +import { + type Chat, + type PartListUnion, + type FunctionDeclaration, + type Part, +} from '@google/genai'; import { HistoryItem, IndividualToolCallDisplay, @@ -286,36 +296,24 @@ export const useGeminiStream = ( }), ); } else if (event.type === ServerGeminiEventType.ToolCallResponse) { - updateFunctionResponseUI(event.value); + const status = event.value.error + ? ToolCallStatus.Error + : ToolCallStatus.Success; + updateFunctionResponseUI(event.value, status); } else if ( event.type === ServerGeminiEventType.ToolCallConfirmation ) { - setHistory((prevHistory) => - prevHistory.map((item) => { - if ( - item.id === currentToolGroupId && - item.type === 'tool_group' - ) { - return { - ...item, - tools: item.tools.map((tool) => - tool.callId === event.value.request.callId - ? { - ...tool, - status: ToolCallStatus.Confirming, - confirmationDetails: event.value.details, - } - : tool, - ), - }; - } - return item; - }), + const confirmationDetails = wireConfirmationSubmission(event.value); + updateConfirmingFunctionStatusUI( + event.value.request.callId, + confirmationDetails, ); setStreamingState(StreamingState.WaitingForConfirmation); return; } } + + setStreamingState(StreamingState.Idle); } catch (error: unknown) { if (!isNodeError(error) || error.name !== 'AbortError') { console.error('Error processing stream or executing tool:', error); @@ -328,16 +326,40 @@ export const useGeminiStream = ( getNextMessageId(userMessageTimestamp), ); } + setStreamingState(StreamingState.Idle); } finally { abortControllerRef.current = null; - // Only set to Idle if not waiting for confirmation. - // Passthrough commands handle their own Idle transition. - if (streamingState !== StreamingState.WaitingForConfirmation) { - setStreamingState(StreamingState.Idle); - } } - function updateFunctionResponseUI(toolResponse: ToolCallResponseInfo) { + function updateConfirmingFunctionStatusUI( + callId: string, + confirmationDetails: ToolCallConfirmationDetails | undefined, + ) { + setHistory((prevHistory) => + prevHistory.map((item) => { + if (item.id === currentToolGroupId && item.type === 'tool_group') { + return { + ...item, + tools: item.tools.map((tool) => + tool.callId === callId + ? { + ...tool, + status: ToolCallStatus.Confirming, + confirmationDetails, + } + : tool, + ), + }; + } + return item; + }), + ); + } + + function updateFunctionResponseUI( + toolResponse: ToolCallResponseInfo, + status: ToolCallStatus, + ) { setHistory((prevHistory) => prevHistory.map((item) => { if (item.id === currentToolGroupId && item.type === 'tool_group') { @@ -347,10 +369,7 @@ export const useGeminiStream = ( if (tool.callId === toolResponse.callId) { return { ...tool, - // TODO: Do we surface the error here? - status: toolResponse.error - ? ToolCallStatus.Error - : ToolCallStatus.Success, + status, resultDisplay: toolResponse.resultDisplay, }; } else { @@ -363,6 +382,82 @@ export const useGeminiStream = ( }), ); } + + function wireConfirmationSubmission( + confirmationDetails: ServerToolCallConfirmationDetails, + ): ToolCallConfirmationDetails { + const originalConfirmationDetails = confirmationDetails.details; + const request = confirmationDetails.request; + const resubmittingConfirm = async ( + outcome: ToolConfirmationOutcome, + ) => { + originalConfirmationDetails.onConfirm(outcome); + + // Reset streaming state since confirmation has been chosen. + setStreamingState(StreamingState.Idle); + + if (outcome === ToolConfirmationOutcome.Cancel) { + let resultDisplay: ToolResultDisplay | undefined; + if ('fileDiff' in originalConfirmationDetails) { + resultDisplay = { + fileDiff: ( + originalConfirmationDetails as ToolEditConfirmationDetails + ).fileDiff, + }; + } else { + resultDisplay = `~~${(originalConfirmationDetails as ToolExecuteConfirmationDetails).command}~~`; + } + const functionResponse: Part = { + functionResponse: { + id: request.callId, + name: request.name, + response: { error: 'User rejected function call.' }, + }, + }; + + const responseInfo: ToolCallResponseInfo = { + callId: request.callId, + responsePart: functionResponse, + resultDisplay, + error: undefined, + }; + + updateFunctionResponseUI(responseInfo, ToolCallStatus.Error); + + await submitQuery(functionResponse); + } else { + const tool = toolRegistry.getTool(request.name); + if (!tool) { + throw new Error( + `Tool "${request.name}" not found or is not registered.`, + ); + } + const result = await tool.execute(request.args); + const functionResponse: Part = { + functionResponse: { + name: request.name, + id: request.callId, + response: { output: result.llmContent }, + }, + }; + + const responseInfo: ToolCallResponseInfo = { + callId: request.callId, + responsePart: functionResponse, + resultDisplay: result.returnDisplay, + error: undefined, + }; + updateFunctionResponseUI(responseInfo, ToolCallStatus.Success); + + await submitQuery(functionResponse); + } + }; + + return { + ...originalConfirmationDetails, + onConfirm: resubmittingConfirm, + }; + } }, // Dependencies need careful review - including updateGeminiMessage [ |
