diff options
| author | Abhi <[email protected]> | 2025-06-22 01:35:36 -0400 |
|---|---|---|
| committer | GitHub <[email protected]> | 2025-06-22 01:35:36 -0400 |
| commit | c9950b3cb273246d801a5cbb04cf421d4c5e39c4 (patch) | |
| tree | 0acd0de4ef11c6031c70489bba6063bbba4ca8f1 /packages/cli/src | |
| parent | 5cf8dc4f0784408f4c2fcfc56d6e834facccf4a3 (diff) | |
feat: Add client-initiated tool call handling (#1292)
Diffstat (limited to 'packages/cli/src')
| mode | file | lines changed |
|---|---|---|
| -rw-r--r-- | packages/cli/src/nonInteractiveCli.ts | 1 |
| -rw-r--r-- | packages/cli/src/ui/App.tsx | 1 |
| -rw-r--r-- | packages/cli/src/ui/hooks/useGeminiStream.test.tsx | 324 |
| -rw-r--r-- | packages/cli/src/ui/hooks/useGeminiStream.ts | 115 |
4 files changed, 319 insertions, 122 deletions
diff --git a/packages/cli/src/nonInteractiveCli.ts b/packages/cli/src/nonInteractiveCli.ts index c5a89575..01ec62c8 100644 --- a/packages/cli/src/nonInteractiveCli.ts +++ b/packages/cli/src/nonInteractiveCli.ts @@ -89,6 +89,7 @@ export async function runNonInteractive( callId, name: fc.name as string, args: (fc.args ?? {}) as Record<string, unknown>, + isClientInitiated: false, }; const toolResponse = await executeToolCall( diff --git a/packages/cli/src/ui/App.tsx b/packages/cli/src/ui/App.tsx index 48d045e3..43936778 100644 --- a/packages/cli/src/ui/App.tsx +++ b/packages/cli/src/ui/App.tsx @@ -362,6 +362,7 @@ const App = ({ config, settings, startupWarnings = [] }: AppProps) => { shellModeActive, getPreferredEditor, onAuthError, + performMemoryRefresh, ); pendingHistoryItems.push(...pendingGeminiHistoryItems); const { elapsedTime, currentLoadingPhrase } = diff --git a/packages/cli/src/ui/hooks/useGeminiStream.test.tsx b/packages/cli/src/ui/hooks/useGeminiStream.test.tsx index ac168dcd..f8cc61bc 100644 --- a/packages/cli/src/ui/hooks/useGeminiStream.test.tsx +++ b/packages/cli/src/ui/hooks/useGeminiStream.test.tsx @@ -371,6 +371,7 @@ describe('useGeminiStream', () => { props.shellModeActive, () => 'vscode' as EditorType, () => {}, + () => Promise.resolve(), ); }, { @@ -389,6 +390,7 @@ describe('useGeminiStream', () => { >, shellModeActive: false, loadedSettings: mockLoadedSettings, + toolCalls: initialToolCalls, }, }, ); @@ -404,7 +406,12 @@ describe('useGeminiStream', () => { it('should not submit tool responses if not all tool calls are completed', () => { const toolCalls: TrackedToolCall[] = [ { - request: { callId: 'call1', name: 'tool1', args: {} }, + request: { + callId: 'call1', + name: 'tool1', + args: {}, + isClientInitiated: false, + }, status: 'success', responseSubmittedToGemini: false, response: { @@ -452,133 +459,138 @@ describe('useGeminiStream', () => { const toolCall2ResponseParts: PartListUnion = [ { text: 'tool 2 final response' }, ]; - - // 
Simplified toolCalls to ensure the filter logic is the focus - const simplifiedToolCalls: TrackedToolCall[] = [ + const completedToolCalls: TrackedToolCall[] = [ { - request: { callId: 'call1', name: 'tool1', args: {} }, - status: 'success', - responseSubmittedToGemini: false, - response: { + request: { callId: 'call1', - responseParts: toolCall1ResponseParts, - error: undefined, - resultDisplay: 'Tool 1 success display', - }, - tool: { name: 'tool1', - description: 'desc', - getDescription: vi.fn(), - } as any, - startTime: Date.now(), - endTime: Date.now(), + args: {}, + isClientInitiated: false, + }, + status: 'success', + responseSubmittedToGemini: false, + response: { callId: 'call1', responseParts: toolCall1ResponseParts }, } as TrackedCompletedToolCall, { - request: { callId: 'call2', name: 'tool2', args: {} }, - status: 'cancelled', - responseSubmittedToGemini: false, - response: { + request: { callId: 'call2', - responseParts: toolCall2ResponseParts, - error: undefined, - resultDisplay: 'Tool 2 cancelled display', - }, - tool: { name: 'tool2', - description: 'desc', - getDescription: vi.fn(), - } as any, - startTime: Date.now(), - endTime: Date.now(), - reason: 'test cancellation', - } as TrackedCancelledToolCall, + args: {}, + isClientInitiated: false, + }, + status: 'error', + responseSubmittedToGemini: false, + response: { callId: 'call2', responseParts: toolCall2ResponseParts }, + } as TrackedCompletedToolCall, // Treat error as a form of completion for submission ]; - const { - rerender, + // 1. On the first render, there are no tool calls. 
+ mockUseReactToolScheduler.mockReturnValue([ + [], + mockScheduleToolCalls, mockMarkToolsAsSubmitted, - mockSendMessageStream: localMockSendMessageStream, - client, - } = renderTestHook(simplifiedToolCalls); + ]); + const { rerender } = renderHook(() => + useGeminiStream( + new MockedGeminiClientClass(mockConfig), + [], + mockAddItem, + mockSetShowHelp, + mockConfig, + mockOnDebugMessage, + mockHandleSlashCommand, + false, + () => 'vscode' as EditorType, + () => {}, + () => Promise.resolve(), + ), + ); + // 2. Before the second render, change the mock to return the completed tools. + mockUseReactToolScheduler.mockReturnValue([ + completedToolCalls, + mockScheduleToolCalls, + mockMarkToolsAsSubmitted, + ]); + + // 3. Trigger a re-render. The hook will now receive the completed tools, causing the effect to run. act(() => { - rerender({ - client, - history: [], - addItem: mockAddItem, - setShowHelp: mockSetShowHelp, - config: mockConfig, - onDebugMessage: mockOnDebugMessage, - handleSlashCommand: - mockHandleSlashCommand as unknown as typeof mockHandleSlashCommand, - shellModeActive: false, - loadedSettings: mockLoadedSettings, - }); + rerender(); }); await waitFor(() => { - expect(mockMarkToolsAsSubmitted).toHaveBeenCalledTimes(0); - expect(localMockSendMessageStream).toHaveBeenCalledTimes(0); + expect(mockMarkToolsAsSubmitted).toHaveBeenCalledTimes(1); + expect(mockSendMessageStream).toHaveBeenCalledTimes(1); }); const expectedMergedResponse = mergePartListUnions([ toolCall1ResponseParts, toolCall2ResponseParts, ]); - expect(localMockSendMessageStream).toHaveBeenCalledWith( + expect(mockSendMessageStream).toHaveBeenCalledWith( expectedMergedResponse, expect.any(AbortSignal), ); }); it('should handle all tool calls being cancelled', async () => { - const toolCalls: TrackedToolCall[] = [ + const cancelledToolCalls: TrackedToolCall[] = [ { - request: { callId: '1', name: 'testTool', args: {} }, - status: 'cancelled', - response: { + request: { callId: '1', - 
responseParts: [{ text: 'cancelled' }], - error: undefined, - resultDisplay: 'Tool 1 cancelled display', + name: 'testTool', + args: {}, + isClientInitiated: false, }, + status: 'cancelled', + response: { callId: '1', responseParts: [{ text: 'cancelled' }] }, responseSubmittedToGemini: false, - tool: { - name: 'testTool', - description: 'desc', - getDescription: vi.fn(), - } as any, - }, + } as TrackedCancelledToolCall, ]; - const client = new MockedGeminiClientClass(mockConfig); - const { mockMarkToolsAsSubmitted, rerender } = renderTestHook( - toolCalls, - client, + + // 1. First render: no tool calls. + mockUseReactToolScheduler.mockReturnValue([ + [], + mockScheduleToolCalls, + mockMarkToolsAsSubmitted, + ]); + const { rerender } = renderHook(() => + useGeminiStream( + client, + [], + mockAddItem, + mockSetShowHelp, + mockConfig, + mockOnDebugMessage, + mockHandleSlashCommand, + false, + () => 'vscode' as EditorType, + () => {}, + () => Promise.resolve(), + ), ); + // 2. Second render: tool calls are now cancelled. + mockUseReactToolScheduler.mockReturnValue([ + cancelledToolCalls, + mockScheduleToolCalls, + mockMarkToolsAsSubmitted, + ]); + + // 3. Trigger the re-render. 
act(() => { - rerender({ - client, - history: [], - addItem: mockAddItem, - setShowHelp: mockSetShowHelp, - config: mockConfig, - onDebugMessage: mockOnDebugMessage, - handleSlashCommand: - mockHandleSlashCommand as unknown as typeof mockHandleSlashCommand, - shellModeActive: false, - loadedSettings: mockLoadedSettings, - }); + rerender(); }); await waitFor(() => { - expect(mockMarkToolsAsSubmitted).toHaveBeenCalledTimes(0); - expect(client.addHistory).toHaveBeenCalledTimes(2); + expect(mockMarkToolsAsSubmitted).toHaveBeenCalledWith(['1']); expect(client.addHistory).toHaveBeenCalledWith({ role: 'user', parts: [{ text: 'cancelled' }], }); + // Ensure we do NOT call back to the API + expect(mockSendMessageStream).not.toHaveBeenCalled(); }); }); @@ -708,7 +720,6 @@ describe('useGeminiStream', () => { loadedSettings: mockLoadedSettings, // This is the key part of the test: update the toolCalls array // to simulate the tool finishing. - // @ts-expect-error - we are adding a property to the props object toolCalls: completedToolCalls, }); }); @@ -874,4 +885,145 @@ describe('useGeminiStream', () => { expect(abortSpy).not.toHaveBeenCalled(); }); }); + + describe('Client-Initiated Tool Calls', () => { + it('should execute a client-initiated tool without sending a response to Gemini', async () => { + const clientToolRequest = { + shouldScheduleTool: true, + toolName: 'save_memory', + toolArgs: { fact: 'test fact' }, + }; + mockHandleSlashCommand.mockResolvedValue(clientToolRequest); + + const completedToolCall: TrackedCompletedToolCall = { + request: { + callId: 'client-call-1', + name: clientToolRequest.toolName, + args: clientToolRequest.toolArgs, + isClientInitiated: true, + }, + status: 'success', + responseSubmittedToGemini: false, + response: { + callId: 'client-call-1', + responseParts: [{ text: 'Memory saved' }], + resultDisplay: 'Success: Memory saved', + error: undefined, + }, + tool: { + name: clientToolRequest.toolName, + description: 'Saves memory', + 
getDescription: vi.fn(), + } as any, + }; + + // 1. Initial render state: no tool calls + mockUseReactToolScheduler.mockReturnValue([ + [], + mockScheduleToolCalls, + mockMarkToolsAsSubmitted, + ]); + + const { result, rerender } = renderHook(() => + useGeminiStream( + new MockedGeminiClientClass(mockConfig), + [], + mockAddItem, + mockSetShowHelp, + mockConfig, + mockOnDebugMessage, + mockHandleSlashCommand, + false, + () => 'vscode' as EditorType, + () => {}, + () => Promise.resolve(), + ), + ); + + // --- User runs the slash command --- + await act(async () => { + await result.current.submitQuery('/memory add "test fact"'); + }); + + // The command handler schedules the tool. Now we simulate the tool completing. + // 2. Before the next render, set the mock to return the completed tool. + mockUseReactToolScheduler.mockReturnValue([ + [completedToolCall], + mockScheduleToolCalls, + mockMarkToolsAsSubmitted, + ]); + + // 3. Trigger a re-render to process the completed tool. + act(() => { + rerender(); + }); + + // --- Assert the outcome --- + await waitFor(() => { + // The tool should be marked as submitted locally + expect(mockMarkToolsAsSubmitted).toHaveBeenCalledWith([ + 'client-call-1', + ]); + // Crucially, no message should be sent to the Gemini API + expect(mockSendMessageStream).not.toHaveBeenCalled(); + }); + }); + }); + + describe('Memory Refresh on save_memory', () => { + it('should call performMemoryRefresh when a save_memory tool call completes successfully', async () => { + const mockPerformMemoryRefresh = vi.fn(); + const completedToolCall: TrackedCompletedToolCall = { + request: { + callId: 'save-mem-call-1', + name: 'save_memory', + args: { fact: 'test' }, + isClientInitiated: true, + }, + status: 'success', + responseSubmittedToGemini: false, + response: { + callId: 'save-mem-call-1', + responseParts: [{ text: 'Memory saved' }], + resultDisplay: 'Success: Memory saved', + error: undefined, + }, + tool: { + name: 'save_memory', + description: 
'Saves memory', + getDescription: vi.fn(), + } as any, + }; + + mockUseReactToolScheduler.mockReturnValue([ + [completedToolCall], + mockScheduleToolCalls, + mockMarkToolsAsSubmitted, + ]); + + const { rerender } = renderHook(() => + useGeminiStream( + new MockedGeminiClientClass(mockConfig), + [], + mockAddItem, + mockSetShowHelp, + mockConfig, + mockOnDebugMessage, + mockHandleSlashCommand, + false, + () => 'vscode' as EditorType, + () => {}, + mockPerformMemoryRefresh, + ), + ); + + act(() => { + rerender(); + }); + + await waitFor(() => { + expect(mockPerformMemoryRefresh).toHaveBeenCalledTimes(1); + }); + }); + }); }); diff --git a/packages/cli/src/ui/hooks/useGeminiStream.ts b/packages/cli/src/ui/hooks/useGeminiStream.ts index fcfa1c57..09b14666 100644 --- a/packages/cli/src/ui/hooks/useGeminiStream.ts +++ b/packages/cli/src/ui/hooks/useGeminiStream.ts @@ -89,6 +89,7 @@ export const useGeminiStream = ( shellModeActive: boolean, getPreferredEditor: () => EditorType | undefined, onAuthError: () => void, + performMemoryRefresh: () => Promise<void>, ) => { const [initError, setInitError] = useState<string | null>(null); const abortControllerRef = useRef<AbortController | null>(null); @@ -97,6 +98,7 @@ export const useGeminiStream = ( const [thought, setThought] = useState<ThoughtSummary | null>(null); const [pendingHistoryItemRef, setPendingHistoryItem] = useStateAndRef<HistoryItemWithoutId | null>(null); + const processedMemoryToolsRef = useRef<Set<string>>(new Set()); const logger = useLogger(); const { startNewTurn, addUsage } = useSessionStats(); const gitService = useMemo(() => { @@ -234,6 +236,7 @@ export const useGeminiStream = ( callId: `${toolName}-${Date.now()}-${Math.random().toString(16).slice(2)}`, name: toolName, args: toolArgs, + isClientInitiated: true, }; scheduleToolCalls([toolCallRequest], abortSignal); } @@ -566,38 +569,77 @@ export const useGeminiStream = ( * is not already generating a response. 
*/ useEffect(() => { - if (isResponding) { - return; - } + const run = async () => { + if (isResponding) { + return; + } - const completedAndReadyToSubmitTools = toolCalls.filter( - ( - tc: TrackedToolCall, - ): tc is TrackedCompletedToolCall | TrackedCancelledToolCall => { - const isTerminalState = - tc.status === 'success' || - tc.status === 'error' || - tc.status === 'cancelled'; + const completedAndReadyToSubmitTools = toolCalls.filter( + ( + tc: TrackedToolCall, + ): tc is TrackedCompletedToolCall | TrackedCancelledToolCall => { + const isTerminalState = + tc.status === 'success' || + tc.status === 'error' || + tc.status === 'cancelled'; - if (isTerminalState) { - const completedOrCancelledCall = tc as - | TrackedCompletedToolCall - | TrackedCancelledToolCall; - return ( - !completedOrCancelledCall.responseSubmittedToGemini && - completedOrCancelledCall.response?.responseParts !== undefined - ); - } - return false; - }, - ); + if (isTerminalState) { + const completedOrCancelledCall = tc as + | TrackedCompletedToolCall + | TrackedCancelledToolCall; + return ( + !completedOrCancelledCall.responseSubmittedToGemini && + completedOrCancelledCall.response?.responseParts !== undefined + ); + } + return false; + }, + ); + + // Finalize any client-initiated tools as soon as they are done. + const clientTools = completedAndReadyToSubmitTools.filter( + (t) => t.request.isClientInitiated, + ); + if (clientTools.length > 0) { + markToolsAsSubmitted(clientTools.map((t) => t.request.callId)); + } + + // Identify new, successful save_memory calls that we haven't processed yet. + const newSuccessfulMemorySaves = completedAndReadyToSubmitTools.filter( + (t) => + t.request.name === 'save_memory' && + t.status === 'success' && + !processedMemoryToolsRef.current.has(t.request.callId), + ); + + if (newSuccessfulMemorySaves.length > 0) { + // Perform the refresh only if there are new ones. 
+ void performMemoryRefresh(); + // Mark them as processed so we don't do this again on the next render. + newSuccessfulMemorySaves.forEach((t) => + processedMemoryToolsRef.current.add(t.request.callId), + ); + } + + // Only proceed with submitting to Gemini if ALL tools are complete. + const allToolsAreComplete = + toolCalls.length > 0 && + toolCalls.length === completedAndReadyToSubmitTools.length; + + if (!allToolsAreComplete) { + return; + } + + const geminiTools = completedAndReadyToSubmitTools.filter( + (t) => !t.request.isClientInitiated, + ); + + if (geminiTools.length === 0) { + return; + } - if ( - completedAndReadyToSubmitTools.length > 0 && - completedAndReadyToSubmitTools.length === toolCalls.length - ) { // If all the tools were cancelled, don't submit a response to Gemini. - const allToolsCancelled = completedAndReadyToSubmitTools.every( + const allToolsCancelled = geminiTools.every( (tc) => tc.status === 'cancelled', ); @@ -605,7 +647,7 @@ export const useGeminiStream = ( if (geminiClient) { // We need to manually add the function responses to the history // so the model knows the tools were cancelled. 
- const responsesToAdd = completedAndReadyToSubmitTools.flatMap( + const responsesToAdd = geminiTools.flatMap( (toolCall) => toolCall.response.responseParts, ); for (const response of responsesToAdd) { @@ -624,18 +666,17 @@ export const useGeminiStream = ( } } - const callIdsToMarkAsSubmitted = completedAndReadyToSubmitTools.map( + const callIdsToMarkAsSubmitted = geminiTools.map( (toolCall) => toolCall.request.callId, ); markToolsAsSubmitted(callIdsToMarkAsSubmitted); return; } - const responsesToSend: PartListUnion[] = - completedAndReadyToSubmitTools.map( - (toolCall) => toolCall.response.responseParts, - ); - const callIdsToMarkAsSubmitted = completedAndReadyToSubmitTools.map( + const responsesToSend: PartListUnion[] = geminiTools.map( + (toolCall) => toolCall.response.responseParts, + ); + const callIdsToMarkAsSubmitted = geminiTools.map( (toolCall) => toolCall.request.callId, ); @@ -643,7 +684,8 @@ export const useGeminiStream = ( submitQuery(mergePartListUnions(responsesToSend), { isContinuation: true, }); - } + }; + void run(); }, [ toolCalls, isResponding, @@ -651,6 +693,7 @@ export const useGeminiStream = ( markToolsAsSubmitted, addItem, geminiClient, + performMemoryRefresh, ]); const pendingHistoryItems = [ |
