summaryrefslogtreecommitdiff
path: root/packages/cli/src/ui/hooks
diff options
context:
space:
mode:
authorN. Taylor Mullen <[email protected]>2025-06-08 11:14:45 -0700
committerGitHub <[email protected]>2025-06-08 11:14:45 -0700
commit241c404573a8dd8c032dde5478a9bec95dd83a19 (patch)
tree7b0f6e812a19dfe1dbe09254781af605c370d12f /packages/cli/src/ui/hooks
parent9efca40dae2e75477af1a20df4e3e65bf8dfe93d (diff)
fix(cli): correctly handle tool invocation cancellation (#844)
Diffstat (limited to 'packages/cli/src/ui/hooks')
-rw-r--r--packages/cli/src/ui/hooks/useGeminiStream.test.tsx52
-rw-r--r--packages/cli/src/ui/hooks/useGeminiStream.ts46
2 files changed, 94 insertions, 4 deletions
diff --git a/packages/cli/src/ui/hooks/useGeminiStream.test.tsx b/packages/cli/src/ui/hooks/useGeminiStream.test.tsx
index bd0f0520..1335eb8e 100644
--- a/packages/cli/src/ui/hooks/useGeminiStream.test.tsx
+++ b/packages/cli/src/ui/hooks/useGeminiStream.test.tsx
@@ -30,6 +30,7 @@ const MockedGeminiClientClass = vi.hoisted(() =>
// _config
this.startChat = mockStartChat;
this.sendMessageStream = mockSendMessageStream;
+ this.addHistory = vi.fn();
}),
);
@@ -267,6 +268,7 @@ describe('useGeminiStream', () => {
() => ({ getToolSchemaList: vi.fn(() => []) }) as any,
),
getGeminiClient: mockGetGeminiClient,
+ addHistory: vi.fn(),
} as unknown as Config;
mockOnDebugMessage = vi.fn();
mockHandleSlashCommand = vi.fn().mockReturnValue(false);
@@ -294,7 +296,10 @@ describe('useGeminiStream', () => {
.mockReturnValue((async function* () {})());
});
- const renderTestHook = (initialToolCalls: TrackedToolCall[] = []) => {
+ const renderTestHook = (
+ initialToolCalls: TrackedToolCall[] = [],
+ geminiClient?: any,
+ ) => {
mockUseReactToolScheduler.mockReturnValue([
initialToolCalls,
mockScheduleToolCalls,
@@ -302,9 +307,11 @@ describe('useGeminiStream', () => {
mockMarkToolsAsSubmitted,
]);
+ const client = geminiClient || mockConfig.getGeminiClient();
+
const { result, rerender } = renderHook(() =>
useGeminiStream(
- mockConfig.getGeminiClient(),
+ client,
mockAddItem as unknown as UseHistoryManagerReturn['addItem'],
mockSetShowHelp,
mockConfig,
@@ -318,6 +325,7 @@ describe('useGeminiStream', () => {
rerender,
mockMarkToolsAsSubmitted,
mockSendMessageStream,
+ client,
// mockFilter removed
};
};
@@ -444,4 +452,44 @@ describe('useGeminiStream', () => {
expect.any(AbortSignal),
);
});
+
+ it('should handle all tool calls being cancelled', async () => {
+ const toolCalls: TrackedToolCall[] = [
+ {
+ request: { callId: '1', name: 'testTool', args: {} },
+ status: 'cancelled',
+ response: {
+ callId: '1',
+ responseParts: [{ text: 'cancelled' }],
+ error: undefined,
+ resultDisplay: 'Tool 1 cancelled display',
+ },
+ responseSubmittedToGemini: false,
+ tool: {
+ name: 'testTool',
+ description: 'desc',
+ getDescription: vi.fn(),
+ } as any,
+ },
+ ];
+
+ const client = new MockedGeminiClientClass(mockConfig);
+ const { mockMarkToolsAsSubmitted, rerender } = renderTestHook(
+ toolCalls,
+ client,
+ );
+
+ await act(async () => {
+ rerender({} as any);
+ });
+
+ await waitFor(() => {
+ expect(mockMarkToolsAsSubmitted).toHaveBeenCalledWith(['1']);
+ expect(client.addHistory).toHaveBeenCalledTimes(2);
+ expect(client.addHistory).toHaveBeenCalledWith({
+ role: 'user',
+ parts: [{ text: 'cancelled' }],
+ });
+ });
+ });
});
diff --git a/packages/cli/src/ui/hooks/useGeminiStream.ts b/packages/cli/src/ui/hooks/useGeminiStream.ts
index 5e741547..3b3d01e0 100644
--- a/packages/cli/src/ui/hooks/useGeminiStream.ts
+++ b/packages/cli/src/ui/hooks/useGeminiStream.ts
@@ -19,7 +19,7 @@ import {
ToolCallRequestInfo,
logUserPrompt,
} from '@gemini-cli/core';
-import { type PartListUnion } from '@google/genai';
+import { type Part, type PartListUnion } from '@google/genai';
import {
StreamingState,
HistoryItemWithoutId,
@@ -531,6 +531,41 @@ export const useGeminiStream = (
completedAndReadyToSubmitTools.length > 0 &&
completedAndReadyToSubmitTools.length === toolCalls.length
) {
+ // If all the tools were cancelled, don't submit a response to Gemini.
+ const allToolsCancelled = completedAndReadyToSubmitTools.every(
+ (tc) => tc.status === 'cancelled',
+ );
+
+ if (allToolsCancelled) {
+ if (geminiClient) {
+ // We need to manually add the function responses to the history
+ // so the model knows the tools were cancelled.
+ const responsesToAdd = completedAndReadyToSubmitTools.flatMap(
+ (toolCall) => toolCall.response.responseParts,
+ );
+ for (const response of responsesToAdd) {
+ let parts: Part[];
+ if (Array.isArray(response)) {
+ parts = response;
+ } else if (typeof response === 'string') {
+ parts = [{ text: response }];
+ } else {
+ parts = [response];
+ }
+ geminiClient.addHistory({
+ role: 'user',
+ parts,
+ });
+ }
+ }
+
+ const callIdsToMarkAsSubmitted = completedAndReadyToSubmitTools.map(
+ (toolCall) => toolCall.request.callId,
+ );
+ markToolsAsSubmitted(callIdsToMarkAsSubmitted);
+ return;
+ }
+
const responsesToSend: PartListUnion[] =
completedAndReadyToSubmitTools.map(
(toolCall) => toolCall.response.responseParts,
@@ -542,7 +577,14 @@ export const useGeminiStream = (
markToolsAsSubmitted(callIdsToMarkAsSubmitted);
submitQuery(mergePartListUnions(responsesToSend));
}
- }, [toolCalls, isResponding, submitQuery, markToolsAsSubmitted, addItem]);
+ }, [
+ toolCalls,
+ isResponding,
+ submitQuery,
+ markToolsAsSubmitted,
+ addItem,
+ geminiClient,
+ ]);
const pendingHistoryItems = [
pendingHistoryItemRef.current,