diff options
| author | N. Taylor Mullen <[email protected]> | 2025-06-01 14:16:24 -0700 |
|---|---|---|
| committer | GitHub <[email protected]> | 2025-06-01 14:16:24 -0700 |
| commit | f2a8d39f42ae88c1b7a9a5a75854363a53444ca2 (patch) | |
| tree | 181d8eb3f1b1602f985fba4d2522b06c6c4f2eb6 /packages/cli/src/ui/hooks/useGeminiStream.ts | |
| parent | edc12e416d0b9daf24ede50cb18b012cb2b6e18a (diff) | |
refactor: Centralize tool scheduling logic and simplify React hook (#670)
Diffstat (limited to 'packages/cli/src/ui/hooks/useGeminiStream.ts')
| -rw-r--r-- | packages/cli/src/ui/hooks/useGeminiStream.ts | 213 |
1 files changed, 100 insertions, 113 deletions
diff --git a/packages/cli/src/ui/hooks/useGeminiStream.ts b/packages/cli/src/ui/hooks/useGeminiStream.ts index 77be6879..35e5a26a 100644 --- a/packages/cli/src/ui/hooks/useGeminiStream.ts +++ b/packages/cli/src/ui/hooks/useGeminiStream.ts @@ -16,20 +16,15 @@ import { isNodeError, Config, MessageSenderType, - ServerToolCallConfirmationDetails, - ToolCallResponseInfo, - ToolEditConfirmationDetails, - ToolExecuteConfirmationDetails, - ToolResultDisplay, ToolCallRequestInfo, } from '@gemini-code/core'; -import { type PartListUnion, type Part } from '@google/genai'; +import { type PartListUnion } from '@google/genai'; import { StreamingState, - ToolCallStatus, HistoryItemWithoutId, HistoryItemToolGroup, MessageType, + ToolCallStatus, } from '../types.js'; import { isAtCommand } from '../utils/commandUtils.js'; import { useShellCommandProcessor } from './shellCommandProcessor.js'; @@ -38,7 +33,13 @@ import { findLastSafeSplitPoint } from '../utils/markdownUtilities.js'; import { useStateAndRef } from './useStateAndRef.js'; import { UseHistoryManagerReturn } from './useHistoryManager.js'; import { useLogger } from './useLogger.js'; -import { useToolScheduler, mapToDisplay } from './useToolScheduler.js'; +import { + useReactToolScheduler, + mapToDisplay as mapTrackedToolCallsToDisplay, + TrackedToolCall, + TrackedCompletedToolCall, + TrackedCancelledToolCall, +} from './useReactToolScheduler.js'; import { GeminiChat } from '@gemini-code/core/src/core/geminiChat.js'; export function mergePartListUnions(list: PartListUnion[]): PartListUnion { @@ -60,12 +61,11 @@ enum StreamProcessingStatus { } /** - * Hook to manage the Gemini stream, handle user input, process commands, - * and interact with the Gemini API and history manager. + * Manages the Gemini stream, including user input, command processing, + * API interaction, and tool call lifecycle. */ export const useGeminiStream = ( addItem: UseHistoryManagerReturn['addItem'], - refreshStatic: () => void, setShowHelp: React.Dispatch<React.SetStateAction<boolean>>, config: Config, onDebugMessage: (message: string) => void, @@ -82,27 +82,33 @@ export const useGeminiStream = ( const [pendingHistoryItemRef, setPendingHistoryItem] = useStateAndRef<HistoryItemWithoutId | null>(null); const logger = useLogger(); - const [toolCalls, schedule, cancel] = useToolScheduler( - (tools) => { - if (tools.length) { - addItem(mapToDisplay(tools), Date.now()); - const toolResponses = tools - .filter( - (t) => - t.status === 'error' || - t.status === 'cancelled' || - t.status === 'success', - ) - .map((t) => t.response.responseParts); - submitQuery(mergePartListUnions(toolResponses)); + const [ + toolCalls, + scheduleToolCalls, + cancelAllToolCalls, + markToolsAsSubmitted, + ] = useReactToolScheduler( + (completedToolCallsFromScheduler) => { + // This onComplete is called when ALL scheduled tools for a given batch are done. + if (completedToolCallsFromScheduler.length > 0) { + // Add the final state of these tools to the history for display. + // The new useEffect will handle submitting their responses. + addItem( + mapTrackedToolCallsToDisplay( + completedToolCallsFromScheduler as TrackedToolCall[], + ), + Date.now(), + ); } }, config, setPendingHistoryItem, ); - const pendingToolCalls = useMemo( - () => (toolCalls.length ? mapToDisplay(toolCalls) : undefined), + + const pendingToolCallGroupDisplay = useMemo( + () => + toolCalls.length ? mapTrackedToolCallsToDisplay(toolCalls) : undefined, [toolCalls], ); @@ -120,16 +126,16 @@ export const useGeminiStream = ( ); const streamingState = useMemo(() => { - if (toolCalls.some((t) => t.status === 'awaiting_approval')) { + if (toolCalls.some((tc) => tc.status === 'awaiting_approval')) { return StreamingState.WaitingForConfirmation; } if ( isResponding || toolCalls.some( - (t) => - t.status === 'executing' || - t.status === 'scheduled' || - t.status === 'validating', + (tc) => + tc.status === 'executing' || + tc.status === 'scheduled' || + tc.status === 'validating', ) ) { return StreamingState.Responding; @@ -153,7 +159,7 @@ export const useGeminiStream = ( useInput((_input, key) => { if (streamingState !== StreamingState.Idle && key.escape) { abortControllerRef.current?.abort(); - cancel(); + cancelAllToolCalls(); // Also cancel any pending/executing tool calls } }); @@ -194,7 +200,7 @@ export const useGeminiStream = ( name: toolName, args: toolArgs, }; - schedule([toolCallRequest]); // schedule expects an array or single object + scheduleToolCalls([toolCallRequest]); } return { queryToSend: null, shouldProceed: false }; // Handled by scheduling the tool } @@ -246,7 +252,7 @@ export const useGeminiStream = ( handleSlashCommand, logger, shellModeActive, - schedule, + scheduleToolCalls, ], ); @@ -275,73 +281,6 @@ export const useGeminiStream = ( return { client: currentClient, chat: chatSessionRef.current }; }, [addItem]); - // --- UI Helper Functions (used by event handlers) --- - const updateFunctionResponseUI = ( - toolResponse: ToolCallResponseInfo, - status: ToolCallStatus, - ) => { - setPendingHistoryItem((item) => - item?.type === 'tool_group' - ? { - ...item, - tools: item.tools.map((tool) => - tool.callId === toolResponse.callId - ? { - ...tool, - status, - resultDisplay: toolResponse.resultDisplay, - } - : tool, - ), - } - : item, - ); - }; - - // Extracted declineToolExecution to be part of wireConfirmationSubmission's closure - // or could be a standalone helper if more params are passed. - // TODO: handle file diff result display stuff - function _declineToolExecution( - declineMessage: string, - status: ToolCallStatus, - request: ServerToolCallConfirmationDetails['request'], - originalDetails: ServerToolCallConfirmationDetails['details'], - ) { - let resultDisplay: ToolResultDisplay | undefined; - if ('fileDiff' in originalDetails) { - resultDisplay = { - fileDiff: (originalDetails as ToolEditConfirmationDetails).fileDiff, - fileName: (originalDetails as ToolEditConfirmationDetails).fileName, - }; - } else { - resultDisplay = `~~${(originalDetails as ToolExecuteConfirmationDetails).command}~~`; - } - const functionResponse: Part = { - functionResponse: { - id: request.callId, - name: request.name, - response: { error: declineMessage }, - }, - }; - const responseInfo: ToolCallResponseInfo = { - callId: request.callId, - responseParts: functionResponse, - resultDisplay, - error: new Error(declineMessage), - }; - const history = chatSessionRef.current?.getHistory(); - if (history) { - history.push({ role: 'model', parts: [functionResponse] }); - } - updateFunctionResponseUI(responseInfo, status); - - if (pendingHistoryItemRef.current) { - addItem(pendingHistoryItemRef.current, Date.now()); - setPendingHistoryItem(null); - } - setIsResponding(false); - } - // --- Stream Event Handlers --- const handleContentEvent = useCallback( @@ -425,9 +364,9 @@ export const useGeminiStream = ( userMessageTimestamp, ); setIsResponding(false); - cancel(); + cancelAllToolCalls(); }, - [addItem, pendingHistoryItemRef, setPendingHistoryItem, cancel], + [addItem, pendingHistoryItemRef, setPendingHistoryItem, cancelAllToolCalls], ); const handleErrorEvent = useCallback( @@ -462,22 +401,22 @@ export const useGeminiStream = ( toolCallRequests.push(event.value); } else if (event.type === ServerGeminiEventType.UserCancelled) { handleUserCancelledEvent(userMessageTimestamp); - cancel(); return StreamProcessingStatus.UserCancelled; } else if (event.type === ServerGeminiEventType.Error) { handleErrorEvent(event.value, userMessageTimestamp); return StreamProcessingStatus.Error; } } - schedule(toolCallRequests); + if (toolCallRequests.length > 0) { + scheduleToolCalls(toolCallRequests); + } return StreamProcessingStatus.Completed; }, [ handleContentEvent, handleUserCancelledEvent, - cancel, handleErrorEvent, - schedule, + scheduleToolCalls, ], ); @@ -545,21 +484,69 @@ export const useGeminiStream = ( } }, [ + streamingState, setShowHelp, - addItem, - setInitError, - ensureChatSession, prepareQueryForGemini, + ensureChatSession, processGeminiStreamEvents, - setPendingHistoryItem, pendingHistoryItemRef, - streamingState, + addItem, + setPendingHistoryItem, + setInitError, ], ); + /** + * Automatically submits responses for completed tool calls. + * This effect runs when `toolCalls` or `isResponding` changes. + * It ensures that tool responses are sent back to Gemini only when + * all processing for a given set of tools is finished and Gemini + * is not already generating a response. + */ + useEffect(() => { + if (isResponding) { + return; + } + + const completedAndReadyToSubmitTools = toolCalls.filter( + ( + tc: TrackedToolCall, + ): tc is TrackedCompletedToolCall | TrackedCancelledToolCall => { + const isTerminalState = + tc.status === 'success' || + tc.status === 'error' || + tc.status === 'cancelled'; + + if (isTerminalState) { + const completedOrCancelledCall = tc as + | TrackedCompletedToolCall + | TrackedCancelledToolCall; + return ( + !completedOrCancelledCall.responseSubmittedToGemini && + completedOrCancelledCall.response?.responseParts !== undefined + ); + } + return false; + }, + ); + + if (completedAndReadyToSubmitTools.length > 0) { + const responsesToSend: PartListUnion[] = + completedAndReadyToSubmitTools.map( + (toolCall) => toolCall.response.responseParts, + ); + const callIdsToMarkAsSubmitted = completedAndReadyToSubmitTools.map( + (toolCall) => toolCall.request.callId, + ); + + markToolsAsSubmitted(callIdsToMarkAsSubmitted); + submitQuery(mergePartListUnions(responsesToSend)); + } + }, [toolCalls, isResponding, submitQuery, markToolsAsSubmitted, addItem]); + const pendingHistoryItems = [ pendingHistoryItemRef.current, - pendingToolCalls, + pendingToolCallGroupDisplay, ].filter((i) => i !== undefined && i !== null); return { |
