diff options
| author | N. Taylor Mullen <[email protected]> | 2025-06-08 15:42:49 -0700 |
|---|---|---|
| committer | GitHub <[email protected]> | 2025-06-08 22:42:49 +0000 |
| commit | f2ea78d0e4e5d25ab3cc25dc9f1492135630c9be (patch) | |
| tree | cdc80f281095a279c1c1746a5b4c1fbfa008dc20 /packages/cli/src/ui/hooks/useGeminiStream.ts | |
| parent | 7868ef82299ae1da5a09334f67d57eb3b472563a (diff) | |
fix(tool-scheduler): Correctly pipe cancellation signal to tool calls (#852)
Diffstat (limited to 'packages/cli/src/ui/hooks/useGeminiStream.ts')
| -rw-r--r-- | packages/cli/src/ui/hooks/useGeminiStream.ts | 57 |
1 files changed, 29 insertions, 28 deletions
diff --git a/packages/cli/src/ui/hooks/useGeminiStream.ts b/packages/cli/src/ui/hooks/useGeminiStream.ts index 3b3d01e0..2b47ae6f 100644 --- a/packages/cli/src/ui/hooks/useGeminiStream.ts +++ b/packages/cli/src/ui/hooks/useGeminiStream.ts @@ -83,28 +83,24 @@ export const useGeminiStream = ( useStateAndRef<HistoryItemWithoutId | null>(null); const logger = useLogger(); - const [ - toolCalls, - scheduleToolCalls, - cancelAllToolCalls, - markToolsAsSubmitted, - ] = useReactToolScheduler( - (completedToolCallsFromScheduler) => { - // This onComplete is called when ALL scheduled tools for a given batch are done. - if (completedToolCallsFromScheduler.length > 0) { - // Add the final state of these tools to the history for display. - // The new useEffect will handle submitting their responses. - addItem( - mapTrackedToolCallsToDisplay( - completedToolCallsFromScheduler as TrackedToolCall[], - ), - Date.now(), - ); - } - }, - config, - setPendingHistoryItem, - ); + const [toolCalls, scheduleToolCalls, markToolsAsSubmitted] = + useReactToolScheduler( + (completedToolCallsFromScheduler) => { + // This onComplete is called when ALL scheduled tools for a given batch are done. + if (completedToolCallsFromScheduler.length > 0) { + // Add the final state of these tools to the history for display. + // The new useEffect will handle submitting their responses. + addItem( + mapTrackedToolCallsToDisplay( + completedToolCallsFromScheduler as TrackedToolCall[], + ), + Date.now(), + ); + } + }, + config, + setPendingHistoryItem, + ); const pendingToolCallGroupDisplay = useMemo( () => @@ -143,10 +139,15 @@ export const useGeminiStream = ( return StreamingState.Idle; }, [isResponding, toolCalls]); + useEffect(() => { + if (streamingState === StreamingState.Idle) { + abortControllerRef.current = null; + } + }, [streamingState]); + useInput((_input, key) => { if (streamingState !== StreamingState.Idle && key.escape) { abortControllerRef.current?.abort(); - cancelAllToolCalls(); // Also cancel any pending/executing tool calls } }); @@ -191,7 +192,7 @@ export const useGeminiStream = ( name: toolName, args: toolArgs, }; - scheduleToolCalls([toolCallRequest]); + scheduleToolCalls([toolCallRequest], abortSignal); } return { queryToSend: null, shouldProceed: false }; // Handled by scheduling the tool } @@ -330,9 +331,8 @@ export const useGeminiStream = ( userMessageTimestamp, ); setIsResponding(false); - cancelAllToolCalls(); }, - [addItem, pendingHistoryItemRef, setPendingHistoryItem, cancelAllToolCalls], + [addItem, pendingHistoryItemRef, setPendingHistoryItem], ); const handleErrorEvent = useCallback( @@ -365,6 +365,7 @@ export const useGeminiStream = ( async ( stream: AsyncIterable<GeminiEvent>, userMessageTimestamp: number, + signal: AbortSignal, ): Promise<StreamProcessingStatus> => { let geminiMessageBuffer = ''; const toolCallRequests: ToolCallRequestInfo[] = []; @@ -401,7 +402,7 @@ export const useGeminiStream = ( } } if (toolCallRequests.length > 0) { - scheduleToolCalls(toolCallRequests); + scheduleToolCalls(toolCallRequests, signal); } return StreamProcessingStatus.Completed; }, @@ -453,6 +454,7 @@ export const useGeminiStream = ( const processingStatus = await processGeminiStreamEvents( stream, userMessageTimestamp, + abortSignal, ); if (processingStatus === StreamProcessingStatus.UserCancelled) { @@ -476,7 +478,6 @@ export const useGeminiStream = ( ); } } finally { - abortControllerRef.current = null; // Always reset setIsResponding(false); } }, |
