diff options
Diffstat (limited to 'packages/cli/src')
| -rw-r--r-- | packages/cli/src/ui/App.tsx | 17 | ||||
| -rw-r--r-- | packages/cli/src/ui/components/messages/ToolGroupMessage.tsx | 12 | ||||
| -rw-r--r-- | packages/cli/src/ui/hooks/useGeminiStream.ts | 311 |
3 files changed, 97 insertions, 243 deletions
diff --git a/packages/cli/src/ui/App.tsx b/packages/cli/src/ui/App.tsx index 42613530..74c1ea5d 100644 --- a/packages/cli/src/ui/App.tsx +++ b/packages/cli/src/ui/App.tsx @@ -134,7 +134,7 @@ export const App = ({ cliVersion, ); - const { streamingState, submitQuery, initError, pendingHistoryItem } = + const { streamingState, submitQuery, initError, pendingHistoryItems } = useGeminiStream( addItem, refreshStatic, @@ -209,7 +209,7 @@ export const App = ({ }, [terminalHeight, footerHeight]); useEffect(() => { - if (!pendingHistoryItem) { + if (!pendingHistoryItems.length) { return; } @@ -223,7 +223,7 @@ export const App = ({ if (pendingItemDimensions.height > availableTerminalHeight) { setStaticNeedsRefresh(true); } - }, [pendingHistoryItem, availableTerminalHeight, streamingState]); + }, [pendingHistoryItems.length, availableTerminalHeight, streamingState]); useEffect(() => { if (streamingState === StreamingState.Idle && staticNeedsRefresh) { @@ -264,17 +264,18 @@ export const App = ({ > {(item) => item} </Static> - {pendingHistoryItem && ( - <Box ref={pendingHistoryItemRef}> + <Box ref={pendingHistoryItemRef}> + {pendingHistoryItems.map((item, i) => ( <HistoryItemDisplay + key={i} availableTerminalHeight={availableTerminalHeight} // TODO(taehykim): It seems like references to ids aren't necessary in // HistoryItemDisplay. Refactor later. Use a fake id for now. - item={{ ...pendingHistoryItem, id: 0 }} + item={{ ...item, id: 0 }} isPending={true} /> - </Box> - )} + ))} + </Box> {showHelp && <Help commands={slashCommands} />} <Box flexDirection="column" ref={mainControlsRef}> diff --git a/packages/cli/src/ui/components/messages/ToolGroupMessage.tsx b/packages/cli/src/ui/components/messages/ToolGroupMessage.tsx index d0ad1c5f..4b2c7dfe 100644 --- a/packages/cli/src/ui/components/messages/ToolGroupMessage.tsx +++ b/packages/cli/src/ui/components/messages/ToolGroupMessage.tsx @@ -4,7 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import React from 'react'; +import React, { useMemo } from 'react'; import { Box } from 'ink'; import { IndividualToolCallDisplay, ToolCallStatus } from '../../types.js'; import { ToolMessage } from './ToolMessage.js'; @@ -19,7 +19,6 @@ interface ToolGroupMessageProps { // Main component renders the border and maps the tools using ToolMessage export const ToolGroupMessage: React.FC<ToolGroupMessageProps> = ({ - groupId, toolCalls, availableTerminalHeight, }) => { @@ -30,9 +29,13 @@ export const ToolGroupMessage: React.FC<ToolGroupMessageProps> = ({ const staticHeight = /* border */ 2 + /* marginBottom */ 1; + const toolAwaitingApproval = useMemo( + () => toolCalls.find((tc) => tc.status === ToolCallStatus.Confirming), + [toolCalls], + ); + return ( <Box - key={groupId} flexDirection="column" borderStyle="round" /* @@ -48,7 +51,7 @@ export const ToolGroupMessage: React.FC<ToolGroupMessageProps> = ({ marginBottom={1} > {toolCalls.map((tool) => ( - <Box key={groupId + '-' + tool.callId} flexDirection="column"> + <Box key={tool.callId} flexDirection="column"> <ToolMessage key={tool.callId} callId={tool.callId} @@ -60,6 +63,7 @@ export const ToolGroupMessage: React.FC<ToolGroupMessageProps> = ({ availableTerminalHeight={availableTerminalHeight - staticHeight} /> {tool.status === ToolCallStatus.Confirming && + tool.callId === toolAwaitingApproval?.callId && tool.confirmationDetails && ( <ToolConfirmationMessage confirmationDetails={tool.confirmationDetails} diff --git a/packages/cli/src/ui/hooks/useGeminiStream.ts b/packages/cli/src/ui/hooks/useGeminiStream.ts index de7980d5..324a4ffa 100644 --- a/packages/cli/src/ui/hooks/useGeminiStream.ts +++ b/packages/cli/src/ui/hooks/useGeminiStream.ts @@ -4,34 +4,28 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { useState, useRef, useCallback, useEffect } from 'react'; +import { useState, useRef, useCallback, useEffect, useMemo } from 'react'; import { useInput } from 'ink'; import { GeminiClient, GeminiEventType as ServerGeminiEventType, ServerGeminiStreamEvent as GeminiEvent, ServerGeminiContentEvent as ContentEvent, - ServerGeminiToolCallRequestEvent as ToolCallRequestEvent, - ServerGeminiToolCallResponseEvent as ToolCallResponseEvent, - ServerGeminiToolCallConfirmationEvent as ToolCallConfirmationEvent, ServerGeminiErrorEvent as ErrorEvent, getErrorMessage, isNodeError, Config, MessageSenderType, ServerToolCallConfirmationDetails, - ToolCallConfirmationDetails, ToolCallResponseInfo, - ToolConfirmationOutcome, ToolEditConfirmationDetails, ToolExecuteConfirmationDetails, ToolResultDisplay, - partListUnionToString, + ToolCallRequestInfo, } from '@gemini-code/server'; import { type Chat, type PartListUnion, type Part } from '@google/genai'; import { StreamingState, - IndividualToolCallDisplay, ToolCallStatus, HistoryItemWithoutId, HistoryItemToolGroup, @@ -44,6 +38,7 @@ import { findLastSafeSplitPoint } from '../utils/markdownUtilities.js'; import { useStateAndRef } from './useStateAndRef.js'; import { UseHistoryManagerReturn } from './useHistoryManager.js'; import { useLogger } from './useLogger.js'; +import { useToolScheduler, mapToDisplay } from './useToolScheduler.js'; enum StreamProcessingStatus { Completed, @@ -65,7 +60,6 @@ export const useGeminiStream = ( handleSlashCommand: (cmd: PartListUnion) => boolean, shellModeActive: boolean, ) => { - const toolRegistry = config.getToolRegistry(); const [initError, setInitError] = useState<string | null>(null); const abortControllerRef = useRef<AbortController | null>(null); const chatSessionRef = useRef<Chat | null>(null); @@ -74,6 +68,25 @@ export const useGeminiStream = ( const [pendingHistoryItemRef, setPendingHistoryItem] = useStateAndRef<HistoryItemWithoutId | null>(null); const logger = useLogger(); + const [toolCalls, schedule, cancel] = useToolScheduler((tools) => { + if (tools.length) { + addItem(mapToDisplay(tools), Date.now()); + submitQuery( + tools + .filter( + (t) => + t.status === 'error' || + t.status === 'cancelled' || + t.status === 'success', + ) + .map((t) => t.response.responsePart), + ); + } + }, config); + const pendingToolCalls = useMemo( + () => (toolCalls.length ? mapToDisplay(toolCalls) : undefined), + [toolCalls], + ); const onExec = useCallback(async (done: Promise<void>) => { setIsResponding(true); @@ -104,6 +117,7 @@ export const useGeminiStream = ( useInput((_input, key) => { if (streamingState !== StreamingState.Idle && key.escape) { abortControllerRef.current?.abort(); + cancel(); } }); @@ -215,157 +229,48 @@ export const useGeminiStream = ( ); }; - const updateConfirmingFunctionStatusUI = ( - callId: string, - confirmationDetails: ToolCallConfirmationDetails | undefined, - ) => { - setPendingHistoryItem((item) => - item?.type === 'tool_group' - ? { - ...item, - tools: item.tools.map((tool) => - tool.callId === callId - ? { - ...tool, - status: ToolCallStatus.Confirming, - confirmationDetails, - } - : tool, - ), - } - : item, - ); - }; - - const wireConfirmationSubmission = ( - confirmationDetails: ServerToolCallConfirmationDetails, - ): ToolCallConfirmationDetails => { - const originalConfirmationDetails = confirmationDetails.details; - const request = confirmationDetails.request; - const resubmittingConfirm = async (outcome: ToolConfirmationOutcome) => { - originalConfirmationDetails.onConfirm(outcome); - if (pendingHistoryItemRef?.current?.type === 'tool_group') { - setPendingHistoryItem((item) => - item?.type === 'tool_group' - ? { - ...item, - tools: item.tools.map((tool) => - tool.callId === request.callId - ? { - ...tool, - confirmationDetails: undefined, - status: ToolCallStatus.Executing, - } - : tool, - ), - } - : item, - ); - refreshStatic(); - } - - if (outcome === ToolConfirmationOutcome.Cancel) { - declineToolExecution( - 'User rejected function call.', - ToolCallStatus.Error, - request, - originalConfirmationDetails, - ); - } else { - const tool = toolRegistry.getTool(request.name); - if (!tool) { - throw new Error( - `Tool "${request.name}" not found or is not registered.`, - ); - } - try { - abortControllerRef.current = new AbortController(); - const result = await tool.execute( - request.args, - abortControllerRef.current.signal, - ); - if (abortControllerRef.current.signal.aborted) { - declineToolExecution( - partListUnionToString(result.llmContent), - ToolCallStatus.Canceled, - request, - originalConfirmationDetails, - ); - return; - } - - const functionResponse: Part = { - functionResponse: { - name: request.name, - id: request.callId, - response: { output: result.llmContent }, - }, - }; - const responseInfo: ToolCallResponseInfo = { - callId: request.callId, - responsePart: functionResponse, - resultDisplay: result.returnDisplay, - error: undefined, - }; - updateFunctionResponseUI(responseInfo, ToolCallStatus.Success); - if (pendingHistoryItemRef.current) { - addItem(pendingHistoryItemRef.current, Date.now()); - setPendingHistoryItem(null); - } - setIsResponding(false); - await submitQuery(functionResponse); // Recursive call - } finally { - if (streamingState !== StreamingState.WaitingForConfirmation) { - abortControllerRef.current = null; - } - } - } - }; - - // Extracted declineToolExecution to be part of wireConfirmationSubmission's closure - // or could be a standalone helper if more params are passed. - function declineToolExecution( - declineMessage: string, - status: ToolCallStatus, - request: ServerToolCallConfirmationDetails['request'], - originalDetails: ServerToolCallConfirmationDetails['details'], - ) { - let resultDisplay: ToolResultDisplay | undefined; - if ('fileDiff' in originalDetails) { - resultDisplay = { - fileDiff: (originalDetails as ToolEditConfirmationDetails).fileDiff, - fileName: (originalDetails as ToolEditConfirmationDetails).fileName, - }; - } else { - resultDisplay = `~~${(originalDetails as ToolExecuteConfirmationDetails).command}~~`; - } - const functionResponse: Part = { - functionResponse: { - id: request.callId, - name: request.name, - response: { error: declineMessage }, - }, - }; - const responseInfo: ToolCallResponseInfo = { - callId: request.callId, - responsePart: functionResponse, - resultDisplay, - error: new Error(declineMessage), + // Extracted declineToolExecution to be part of wireConfirmationSubmission's closure + // or could be a standalone helper if more params are passed. + // TODO: handle file diff result display stuff + function _declineToolExecution( + declineMessage: string, + status: ToolCallStatus, + request: ServerToolCallConfirmationDetails['request'], + originalDetails: ServerToolCallConfirmationDetails['details'], + ) { + let resultDisplay: ToolResultDisplay | undefined; + if ('fileDiff' in originalDetails) { + resultDisplay = { + fileDiff: (originalDetails as ToolEditConfirmationDetails).fileDiff, + fileName: (originalDetails as ToolEditConfirmationDetails).fileName, }; - const history = chatSessionRef.current?.getHistory(); - if (history) { - history.push({ role: 'model', parts: [functionResponse] }); - } - updateFunctionResponseUI(responseInfo, status); - if (pendingHistoryItemRef.current) { - addItem(pendingHistoryItemRef.current, Date.now()); - setPendingHistoryItem(null); - } - setIsResponding(false); + } else { + resultDisplay = `~~${(originalDetails as ToolExecuteConfirmationDetails).command}~~`; } - - return { ...originalConfirmationDetails, onConfirm: resubmittingConfirm }; - }; + const functionResponse: Part = { + functionResponse: { + id: request.callId, + name: request.name, + response: { error: declineMessage }, + }, + }; + const responseInfo: ToolCallResponseInfo = { + callId: request.callId, + responsePart: functionResponse, + resultDisplay, + error: new Error(declineMessage), + }; + const history = chatSessionRef.current?.getHistory(); + if (history) { + history.push({ role: 'model', parts: [functionResponse] }); + } + updateFunctionResponseUI(responseInfo, status); + if (pendingHistoryItemRef.current) { + addItem(pendingHistoryItemRef.current, Date.now()); + setPendingHistoryItem(null); + } + setIsResponding(false); + } // --- Stream Event Handlers --- const handleContentEvent = ( @@ -419,62 +324,6 @@ export const useGeminiStream = ( return newGeminiMessageBuffer; }; - const handleToolCallRequestEvent = ( - eventValue: ToolCallRequestEvent['value'], - userMessageTimestamp: number, - ) => { - const { callId, name, args } = eventValue; - const cliTool = toolRegistry.getTool(name); - if (!cliTool) { - console.error(`CLI Tool "${name}" not found!`); - return; // Skip this event if tool is not found - } - if (pendingHistoryItemRef.current?.type !== 'tool_group') { - if (pendingHistoryItemRef.current) { - addItem(pendingHistoryItemRef.current, userMessageTimestamp); - } - setPendingHistoryItem({ type: 'tool_group', tools: [] }); - } - let description: string; - try { - description = cliTool.getDescription(args); - } catch (e) { - description = `Error: Unable to get description: ${getErrorMessage(e)}`; - } - const toolCallDisplay: IndividualToolCallDisplay = { - callId, - name: cliTool.displayName, - description, - status: ToolCallStatus.Pending, - resultDisplay: undefined, - confirmationDetails: undefined, - }; - setPendingHistoryItem((pending) => - pending?.type === 'tool_group' - ? { ...pending, tools: [...pending.tools, toolCallDisplay] } - : null, - ); - }; - - const handleToolCallResponseEvent = ( - eventValue: ToolCallResponseEvent['value'], - ) => { - const status = eventValue.error - ? ToolCallStatus.Error - : ToolCallStatus.Success; - updateFunctionResponseUI(eventValue, status); - }; - - const handleToolCallConfirmationEvent = ( - eventValue: ToolCallConfirmationEvent['value'], - ) => { - const confirmationDetails = wireConfirmationSubmission(eventValue); - updateConfirmingFunctionStatusUI( - eventValue.request.callId, - confirmationDetails, - ); - }; - const handleUserCancelledEvent = (userMessageTimestamp: number) => { if (pendingHistoryItemRef.current) { if (pendingHistoryItemRef.current.type === 'tool_group') { @@ -500,6 +349,7 @@ export const useGeminiStream = ( userMessageTimestamp, ); setIsResponding(false); + cancel(); }; const handleErrorEvent = ( @@ -521,7 +371,7 @@ export const useGeminiStream = ( userMessageTimestamp: number, ): Promise<StreamProcessingStatus> => { let geminiMessageBuffer = ''; - + const toolCallRequests: ToolCallRequestInfo[] = []; for await (const event of stream) { if (event.type === ServerGeminiEventType.Content) { geminiMessageBuffer = handleContentEvent( @@ -530,12 +380,7 @@ export const useGeminiStream = ( userMessageTimestamp, ); } else if (event.type === ServerGeminiEventType.ToolCallRequest) { - handleToolCallRequestEvent(event.value, userMessageTimestamp); - } else if (event.type === ServerGeminiEventType.ToolCallResponse) { - handleToolCallResponseEvent(event.value); - } else if (event.type === ServerGeminiEventType.ToolCallConfirmation) { - handleToolCallConfirmationEvent(event.value); - return StreamProcessingStatus.PausedForConfirmation; + toolCallRequests.push(event.value); } else if (event.type === ServerGeminiEventType.UserCancelled) { handleUserCancelledEvent(userMessageTimestamp); return StreamProcessingStatus.UserCancelled; @@ -544,9 +389,18 @@ export const useGeminiStream = ( return StreamProcessingStatus.Error; } } + schedule(toolCallRequests); return StreamProcessingStatus.Completed; }; + const streamingState: StreamingState = isResponding + ? StreamingState.Responding + : pendingToolCalls?.tools.some( + (t) => t.status === ToolCallStatus.Confirming, + ) + ? StreamingState.WaitingForConfirmation + : StreamingState.Idle; + const submitQuery = useCallback( async (query: PartListUnion) => { if (isResponding) return; @@ -625,20 +479,15 @@ export const useGeminiStream = ( ], ); - const streamingState: StreamingState = isResponding - ? StreamingState.Responding - : pendingConfirmations(pendingHistoryItemRef.current) - ? StreamingState.WaitingForConfirmation - : StreamingState.Idle; + const pendingHistoryItems = [ + pendingHistoryItemRef.current, + pendingToolCalls, + ].filter((i) => i !== undefined && i !== null); return { streamingState, submitQuery, initError, - pendingHistoryItem: pendingHistoryItemRef.current, + pendingHistoryItems, }; }; - -const pendingConfirmations = (item: HistoryItemWithoutId | null): boolean => - item?.type === 'tool_group' && - item.tools.some((t) => t.status === ToolCallStatus.Confirming); |
