diff options
| author | N. Taylor Mullen <[email protected]> | 2025-06-08 11:14:45 -0700 |
|---|---|---|
| committer | GitHub <[email protected]> | 2025-06-08 11:14:45 -0700 |
| commit | 241c404573a8dd8c032dde5478a9bec95dd83a19 (patch) | |
| tree | 7b0f6e812a19dfe1dbe09254781af605c370d12f /packages/cli/src | |
| parent | 9efca40dae2e75477af1a20df4e3e65bf8dfe93d (diff) | |
fix(cli): correctly handle tool invocation cancellation (#844)
Diffstat (limited to 'packages/cli/src')
| -rw-r--r-- | packages/cli/src/ui/hooks/useGeminiStream.test.tsx | 52 | ||||
| -rw-r--r-- | packages/cli/src/ui/hooks/useGeminiStream.ts | 46 |
2 files changed, 94 insertions, 4 deletions
diff --git a/packages/cli/src/ui/hooks/useGeminiStream.test.tsx b/packages/cli/src/ui/hooks/useGeminiStream.test.tsx index bd0f0520..1335eb8e 100644 --- a/packages/cli/src/ui/hooks/useGeminiStream.test.tsx +++ b/packages/cli/src/ui/hooks/useGeminiStream.test.tsx @@ -30,6 +30,7 @@ const MockedGeminiClientClass = vi.hoisted(() => // _config this.startChat = mockStartChat; this.sendMessageStream = mockSendMessageStream; + this.addHistory = vi.fn(); }), ); @@ -267,6 +268,7 @@ describe('useGeminiStream', () => { () => ({ getToolSchemaList: vi.fn(() => []) }) as any, ), getGeminiClient: mockGetGeminiClient, + addHistory: vi.fn(), } as unknown as Config; mockOnDebugMessage = vi.fn(); mockHandleSlashCommand = vi.fn().mockReturnValue(false); @@ -294,7 +296,10 @@ describe('useGeminiStream', () => { .mockReturnValue((async function* () {})()); }); - const renderTestHook = (initialToolCalls: TrackedToolCall[] = []) => { + const renderTestHook = ( + initialToolCalls: TrackedToolCall[] = [], + geminiClient?: any, + ) => { mockUseReactToolScheduler.mockReturnValue([ initialToolCalls, mockScheduleToolCalls, @@ -302,9 +307,11 @@ describe('useGeminiStream', () => { mockMarkToolsAsSubmitted, ]); + const client = geminiClient || mockConfig.getGeminiClient(); + const { result, rerender } = renderHook(() => useGeminiStream( - mockConfig.getGeminiClient(), + client, mockAddItem as unknown as UseHistoryManagerReturn['addItem'], mockSetShowHelp, mockConfig, @@ -318,6 +325,7 @@ describe('useGeminiStream', () => { rerender, mockMarkToolsAsSubmitted, mockSendMessageStream, + client, // mockFilter removed }; }; @@ -444,4 +452,44 @@ describe('useGeminiStream', () => { expect.any(AbortSignal), ); }); + + it('should handle all tool calls being cancelled', async () => { + const toolCalls: TrackedToolCall[] = [ + { + request: { callId: '1', name: 'testTool', args: {} }, + status: 'cancelled', + response: { + callId: '1', + responseParts: [{ text: 'cancelled' }], + error: undefined, + resultDisplay: 'Tool 1 cancelled display', + }, + responseSubmittedToGemini: false, + tool: { + name: 'testTool', + description: 'desc', + getDescription: vi.fn(), + } as any, + }, + ]; + + const client = new MockedGeminiClientClass(mockConfig); + const { mockMarkToolsAsSubmitted, rerender } = renderTestHook( + toolCalls, + client, + ); + + await act(async () => { + rerender({} as any); + }); + + await waitFor(() => { + expect(mockMarkToolsAsSubmitted).toHaveBeenCalledWith(['1']); + expect(client.addHistory).toHaveBeenCalledTimes(2); + expect(client.addHistory).toHaveBeenCalledWith({ + role: 'user', + parts: [{ text: 'cancelled' }], + }); + }); + }); }); diff --git a/packages/cli/src/ui/hooks/useGeminiStream.ts b/packages/cli/src/ui/hooks/useGeminiStream.ts index 5e741547..3b3d01e0 100644 --- a/packages/cli/src/ui/hooks/useGeminiStream.ts +++ b/packages/cli/src/ui/hooks/useGeminiStream.ts @@ -19,7 +19,7 @@ import { ToolCallRequestInfo, logUserPrompt, } from '@gemini-cli/core'; -import { type PartListUnion } from '@google/genai'; +import { type Part, type PartListUnion } from '@google/genai'; import { StreamingState, HistoryItemWithoutId, @@ -531,6 +531,41 @@ export const useGeminiStream = ( completedAndReadyToSubmitTools.length > 0 && completedAndReadyToSubmitTools.length === toolCalls.length ) { + // If all the tools were cancelled, don't submit a response to Gemini. + const allToolsCancelled = completedAndReadyToSubmitTools.every( + (tc) => tc.status === 'cancelled', + ); + + if (allToolsCancelled) { + if (geminiClient) { + // We need to manually add the function responses to the history + // so the model knows the tools were cancelled. + const responsesToAdd = completedAndReadyToSubmitTools.flatMap( + (toolCall) => toolCall.response.responseParts, + ); + for (const response of responsesToAdd) { + let parts: Part[]; + if (Array.isArray(response)) { + parts = response; + } else if (typeof response === 'string') { + parts = [{ text: response }]; + } else { + parts = [response]; + } + geminiClient.addHistory({ + role: 'user', + parts, + }); + } + } + + const callIdsToMarkAsSubmitted = completedAndReadyToSubmitTools.map( + (toolCall) => toolCall.request.callId, + ); + markToolsAsSubmitted(callIdsToMarkAsSubmitted); + return; + } + const responsesToSend: PartListUnion[] = completedAndReadyToSubmitTools.map( (toolCall) => toolCall.response.responseParts, @@ -542,7 +577,14 @@ export const useGeminiStream = ( markToolsAsSubmitted(callIdsToMarkAsSubmitted); submitQuery(mergePartListUnions(responsesToSend)); } - }, [toolCalls, isResponding, submitQuery, markToolsAsSubmitted, addItem]); + }, [ + toolCalls, + isResponding, + submitQuery, + markToolsAsSubmitted, + addItem, + geminiClient, + ]); const pendingHistoryItems = [ pendingHistoryItemRef.current, |
