summary | refs | log | tree | commit | diff
path: root/packages/cli/src
diff options
context:
space:
mode:
author: Abhi <[email protected]> 2025-06-22 01:35:36 -0400
committer: GitHub <[email protected]> 2025-06-22 01:35:36 -0400
commit: c9950b3cb273246d801a5cbb04cf421d4c5e39c4 (patch)
tree: 0acd0de4ef11c6031c70489bba6063bbba4ca8f1 /packages/cli/src
parent: 5cf8dc4f0784408f4c2fcfc56d6e834facccf4a3 (diff)
feat: Add client-initiated tool call handling (#1292)
Diffstat (limited to 'packages/cli/src')
-rw-r--r-- packages/cli/src/nonInteractiveCli.ts | 1
-rw-r--r-- packages/cli/src/ui/App.tsx | 1
-rw-r--r-- packages/cli/src/ui/hooks/useGeminiStream.test.tsx | 324
-rw-r--r-- packages/cli/src/ui/hooks/useGeminiStream.ts | 115
4 files changed, 319 insertions, 122 deletions
diff --git a/packages/cli/src/nonInteractiveCli.ts b/packages/cli/src/nonInteractiveCli.ts
index c5a89575..01ec62c8 100644
--- a/packages/cli/src/nonInteractiveCli.ts
+++ b/packages/cli/src/nonInteractiveCli.ts
@@ -89,6 +89,7 @@ export async function runNonInteractive(
callId,
name: fc.name as string,
args: (fc.args ?? {}) as Record<string, unknown>,
+ isClientInitiated: false,
};
const toolResponse = await executeToolCall(
diff --git a/packages/cli/src/ui/App.tsx b/packages/cli/src/ui/App.tsx
index 48d045e3..43936778 100644
--- a/packages/cli/src/ui/App.tsx
+++ b/packages/cli/src/ui/App.tsx
@@ -362,6 +362,7 @@ const App = ({ config, settings, startupWarnings = [] }: AppProps) => {
shellModeActive,
getPreferredEditor,
onAuthError,
+ performMemoryRefresh,
);
pendingHistoryItems.push(...pendingGeminiHistoryItems);
const { elapsedTime, currentLoadingPhrase } =
diff --git a/packages/cli/src/ui/hooks/useGeminiStream.test.tsx b/packages/cli/src/ui/hooks/useGeminiStream.test.tsx
index ac168dcd..f8cc61bc 100644
--- a/packages/cli/src/ui/hooks/useGeminiStream.test.tsx
+++ b/packages/cli/src/ui/hooks/useGeminiStream.test.tsx
@@ -371,6 +371,7 @@ describe('useGeminiStream', () => {
props.shellModeActive,
() => 'vscode' as EditorType,
() => {},
+ () => Promise.resolve(),
);
},
{
@@ -389,6 +390,7 @@ describe('useGeminiStream', () => {
>,
shellModeActive: false,
loadedSettings: mockLoadedSettings,
+ toolCalls: initialToolCalls,
},
},
);
@@ -404,7 +406,12 @@ describe('useGeminiStream', () => {
it('should not submit tool responses if not all tool calls are completed', () => {
const toolCalls: TrackedToolCall[] = [
{
- request: { callId: 'call1', name: 'tool1', args: {} },
+ request: {
+ callId: 'call1',
+ name: 'tool1',
+ args: {},
+ isClientInitiated: false,
+ },
status: 'success',
responseSubmittedToGemini: false,
response: {
@@ -452,133 +459,138 @@ describe('useGeminiStream', () => {
const toolCall2ResponseParts: PartListUnion = [
{ text: 'tool 2 final response' },
];
-
- // Simplified toolCalls to ensure the filter logic is the focus
- const simplifiedToolCalls: TrackedToolCall[] = [
+ const completedToolCalls: TrackedToolCall[] = [
{
- request: { callId: 'call1', name: 'tool1', args: {} },
- status: 'success',
- responseSubmittedToGemini: false,
- response: {
+ request: {
callId: 'call1',
- responseParts: toolCall1ResponseParts,
- error: undefined,
- resultDisplay: 'Tool 1 success display',
- },
- tool: {
name: 'tool1',
- description: 'desc',
- getDescription: vi.fn(),
- } as any,
- startTime: Date.now(),
- endTime: Date.now(),
+ args: {},
+ isClientInitiated: false,
+ },
+ status: 'success',
+ responseSubmittedToGemini: false,
+ response: { callId: 'call1', responseParts: toolCall1ResponseParts },
} as TrackedCompletedToolCall,
{
- request: { callId: 'call2', name: 'tool2', args: {} },
- status: 'cancelled',
- responseSubmittedToGemini: false,
- response: {
+ request: {
callId: 'call2',
- responseParts: toolCall2ResponseParts,
- error: undefined,
- resultDisplay: 'Tool 2 cancelled display',
- },
- tool: {
name: 'tool2',
- description: 'desc',
- getDescription: vi.fn(),
- } as any,
- startTime: Date.now(),
- endTime: Date.now(),
- reason: 'test cancellation',
- } as TrackedCancelledToolCall,
+ args: {},
+ isClientInitiated: false,
+ },
+ status: 'error',
+ responseSubmittedToGemini: false,
+ response: { callId: 'call2', responseParts: toolCall2ResponseParts },
+ } as TrackedCompletedToolCall, // Treat error as a form of completion for submission
];
- const {
- rerender,
+ // 1. On the first render, there are no tool calls.
+ mockUseReactToolScheduler.mockReturnValue([
+ [],
+ mockScheduleToolCalls,
mockMarkToolsAsSubmitted,
- mockSendMessageStream: localMockSendMessageStream,
- client,
- } = renderTestHook(simplifiedToolCalls);
+ ]);
+ const { rerender } = renderHook(() =>
+ useGeminiStream(
+ new MockedGeminiClientClass(mockConfig),
+ [],
+ mockAddItem,
+ mockSetShowHelp,
+ mockConfig,
+ mockOnDebugMessage,
+ mockHandleSlashCommand,
+ false,
+ () => 'vscode' as EditorType,
+ () => {},
+ () => Promise.resolve(),
+ ),
+ );
+ // 2. Before the second render, change the mock to return the completed tools.
+ mockUseReactToolScheduler.mockReturnValue([
+ completedToolCalls,
+ mockScheduleToolCalls,
+ mockMarkToolsAsSubmitted,
+ ]);
+
+ // 3. Trigger a re-render. The hook will now receive the completed tools, causing the effect to run.
act(() => {
- rerender({
- client,
- history: [],
- addItem: mockAddItem,
- setShowHelp: mockSetShowHelp,
- config: mockConfig,
- onDebugMessage: mockOnDebugMessage,
- handleSlashCommand:
- mockHandleSlashCommand as unknown as typeof mockHandleSlashCommand,
- shellModeActive: false,
- loadedSettings: mockLoadedSettings,
- });
+ rerender();
});
await waitFor(() => {
- expect(mockMarkToolsAsSubmitted).toHaveBeenCalledTimes(0);
- expect(localMockSendMessageStream).toHaveBeenCalledTimes(0);
+ expect(mockMarkToolsAsSubmitted).toHaveBeenCalledTimes(1);
+ expect(mockSendMessageStream).toHaveBeenCalledTimes(1);
});
const expectedMergedResponse = mergePartListUnions([
toolCall1ResponseParts,
toolCall2ResponseParts,
]);
- expect(localMockSendMessageStream).toHaveBeenCalledWith(
+ expect(mockSendMessageStream).toHaveBeenCalledWith(
expectedMergedResponse,
expect.any(AbortSignal),
);
});
it('should handle all tool calls being cancelled', async () => {
- const toolCalls: TrackedToolCall[] = [
+ const cancelledToolCalls: TrackedToolCall[] = [
{
- request: { callId: '1', name: 'testTool', args: {} },
- status: 'cancelled',
- response: {
+ request: {
callId: '1',
- responseParts: [{ text: 'cancelled' }],
- error: undefined,
- resultDisplay: 'Tool 1 cancelled display',
+ name: 'testTool',
+ args: {},
+ isClientInitiated: false,
},
+ status: 'cancelled',
+ response: { callId: '1', responseParts: [{ text: 'cancelled' }] },
responseSubmittedToGemini: false,
- tool: {
- name: 'testTool',
- description: 'desc',
- getDescription: vi.fn(),
- } as any,
- },
+ } as TrackedCancelledToolCall,
];
-
const client = new MockedGeminiClientClass(mockConfig);
- const { mockMarkToolsAsSubmitted, rerender } = renderTestHook(
- toolCalls,
- client,
+
+ // 1. First render: no tool calls.
+ mockUseReactToolScheduler.mockReturnValue([
+ [],
+ mockScheduleToolCalls,
+ mockMarkToolsAsSubmitted,
+ ]);
+ const { rerender } = renderHook(() =>
+ useGeminiStream(
+ client,
+ [],
+ mockAddItem,
+ mockSetShowHelp,
+ mockConfig,
+ mockOnDebugMessage,
+ mockHandleSlashCommand,
+ false,
+ () => 'vscode' as EditorType,
+ () => {},
+ () => Promise.resolve(),
+ ),
);
+ // 2. Second render: tool calls are now cancelled.
+ mockUseReactToolScheduler.mockReturnValue([
+ cancelledToolCalls,
+ mockScheduleToolCalls,
+ mockMarkToolsAsSubmitted,
+ ]);
+
+ // 3. Trigger the re-render.
act(() => {
- rerender({
- client,
- history: [],
- addItem: mockAddItem,
- setShowHelp: mockSetShowHelp,
- config: mockConfig,
- onDebugMessage: mockOnDebugMessage,
- handleSlashCommand:
- mockHandleSlashCommand as unknown as typeof mockHandleSlashCommand,
- shellModeActive: false,
- loadedSettings: mockLoadedSettings,
- });
+ rerender();
});
await waitFor(() => {
- expect(mockMarkToolsAsSubmitted).toHaveBeenCalledTimes(0);
- expect(client.addHistory).toHaveBeenCalledTimes(2);
+ expect(mockMarkToolsAsSubmitted).toHaveBeenCalledWith(['1']);
expect(client.addHistory).toHaveBeenCalledWith({
role: 'user',
parts: [{ text: 'cancelled' }],
});
+ // Ensure we do NOT call back to the API
+ expect(mockSendMessageStream).not.toHaveBeenCalled();
});
});
@@ -708,7 +720,6 @@ describe('useGeminiStream', () => {
loadedSettings: mockLoadedSettings,
// This is the key part of the test: update the toolCalls array
// to simulate the tool finishing.
- // @ts-expect-error - we are adding a property to the props object
toolCalls: completedToolCalls,
});
});
@@ -874,4 +885,145 @@ describe('useGeminiStream', () => {
expect(abortSpy).not.toHaveBeenCalled();
});
});
+
+ describe('Client-Initiated Tool Calls', () => {
+ it('should execute a client-initiated tool without sending a response to Gemini', async () => {
+ const clientToolRequest = {
+ shouldScheduleTool: true,
+ toolName: 'save_memory',
+ toolArgs: { fact: 'test fact' },
+ };
+ mockHandleSlashCommand.mockResolvedValue(clientToolRequest);
+
+ const completedToolCall: TrackedCompletedToolCall = {
+ request: {
+ callId: 'client-call-1',
+ name: clientToolRequest.toolName,
+ args: clientToolRequest.toolArgs,
+ isClientInitiated: true,
+ },
+ status: 'success',
+ responseSubmittedToGemini: false,
+ response: {
+ callId: 'client-call-1',
+ responseParts: [{ text: 'Memory saved' }],
+ resultDisplay: 'Success: Memory saved',
+ error: undefined,
+ },
+ tool: {
+ name: clientToolRequest.toolName,
+ description: 'Saves memory',
+ getDescription: vi.fn(),
+ } as any,
+ };
+
+ // 1. Initial render state: no tool calls
+ mockUseReactToolScheduler.mockReturnValue([
+ [],
+ mockScheduleToolCalls,
+ mockMarkToolsAsSubmitted,
+ ]);
+
+ const { result, rerender } = renderHook(() =>
+ useGeminiStream(
+ new MockedGeminiClientClass(mockConfig),
+ [],
+ mockAddItem,
+ mockSetShowHelp,
+ mockConfig,
+ mockOnDebugMessage,
+ mockHandleSlashCommand,
+ false,
+ () => 'vscode' as EditorType,
+ () => {},
+ () => Promise.resolve(),
+ ),
+ );
+
+ // --- User runs the slash command ---
+ await act(async () => {
+ await result.current.submitQuery('/memory add "test fact"');
+ });
+
+ // The command handler schedules the tool. Now we simulate the tool completing.
+ // 2. Before the next render, set the mock to return the completed tool.
+ mockUseReactToolScheduler.mockReturnValue([
+ [completedToolCall],
+ mockScheduleToolCalls,
+ mockMarkToolsAsSubmitted,
+ ]);
+
+ // 3. Trigger a re-render to process the completed tool.
+ act(() => {
+ rerender();
+ });
+
+ // --- Assert the outcome ---
+ await waitFor(() => {
+ // The tool should be marked as submitted locally
+ expect(mockMarkToolsAsSubmitted).toHaveBeenCalledWith([
+ 'client-call-1',
+ ]);
+ // Crucially, no message should be sent to the Gemini API
+ expect(mockSendMessageStream).not.toHaveBeenCalled();
+ });
+ });
+ });
+
+ describe('Memory Refresh on save_memory', () => {
+ it('should call performMemoryRefresh when a save_memory tool call completes successfully', async () => {
+ const mockPerformMemoryRefresh = vi.fn();
+ const completedToolCall: TrackedCompletedToolCall = {
+ request: {
+ callId: 'save-mem-call-1',
+ name: 'save_memory',
+ args: { fact: 'test' },
+ isClientInitiated: true,
+ },
+ status: 'success',
+ responseSubmittedToGemini: false,
+ response: {
+ callId: 'save-mem-call-1',
+ responseParts: [{ text: 'Memory saved' }],
+ resultDisplay: 'Success: Memory saved',
+ error: undefined,
+ },
+ tool: {
+ name: 'save_memory',
+ description: 'Saves memory',
+ getDescription: vi.fn(),
+ } as any,
+ };
+
+ mockUseReactToolScheduler.mockReturnValue([
+ [completedToolCall],
+ mockScheduleToolCalls,
+ mockMarkToolsAsSubmitted,
+ ]);
+
+ const { rerender } = renderHook(() =>
+ useGeminiStream(
+ new MockedGeminiClientClass(mockConfig),
+ [],
+ mockAddItem,
+ mockSetShowHelp,
+ mockConfig,
+ mockOnDebugMessage,
+ mockHandleSlashCommand,
+ false,
+ () => 'vscode' as EditorType,
+ () => {},
+ mockPerformMemoryRefresh,
+ ),
+ );
+
+ act(() => {
+ rerender();
+ });
+
+ await waitFor(() => {
+ expect(mockPerformMemoryRefresh).toHaveBeenCalledTimes(1);
+ });
+ });
+ });
});
diff --git a/packages/cli/src/ui/hooks/useGeminiStream.ts b/packages/cli/src/ui/hooks/useGeminiStream.ts
index fcfa1c57..09b14666 100644
--- a/packages/cli/src/ui/hooks/useGeminiStream.ts
+++ b/packages/cli/src/ui/hooks/useGeminiStream.ts
@@ -89,6 +89,7 @@ export const useGeminiStream = (
shellModeActive: boolean,
getPreferredEditor: () => EditorType | undefined,
onAuthError: () => void,
+ performMemoryRefresh: () => Promise<void>,
) => {
const [initError, setInitError] = useState<string | null>(null);
const abortControllerRef = useRef<AbortController | null>(null);
@@ -97,6 +98,7 @@ export const useGeminiStream = (
const [thought, setThought] = useState<ThoughtSummary | null>(null);
const [pendingHistoryItemRef, setPendingHistoryItem] =
useStateAndRef<HistoryItemWithoutId | null>(null);
+ const processedMemoryToolsRef = useRef<Set<string>>(new Set());
const logger = useLogger();
const { startNewTurn, addUsage } = useSessionStats();
const gitService = useMemo(() => {
@@ -234,6 +236,7 @@ export const useGeminiStream = (
callId: `${toolName}-${Date.now()}-${Math.random().toString(16).slice(2)}`,
name: toolName,
args: toolArgs,
+ isClientInitiated: true,
};
scheduleToolCalls([toolCallRequest], abortSignal);
}
@@ -566,38 +569,77 @@ export const useGeminiStream = (
* is not already generating a response.
*/
useEffect(() => {
- if (isResponding) {
- return;
- }
+ const run = async () => {
+ if (isResponding) {
+ return;
+ }
- const completedAndReadyToSubmitTools = toolCalls.filter(
- (
- tc: TrackedToolCall,
- ): tc is TrackedCompletedToolCall | TrackedCancelledToolCall => {
- const isTerminalState =
- tc.status === 'success' ||
- tc.status === 'error' ||
- tc.status === 'cancelled';
+ const completedAndReadyToSubmitTools = toolCalls.filter(
+ (
+ tc: TrackedToolCall,
+ ): tc is TrackedCompletedToolCall | TrackedCancelledToolCall => {
+ const isTerminalState =
+ tc.status === 'success' ||
+ tc.status === 'error' ||
+ tc.status === 'cancelled';
- if (isTerminalState) {
- const completedOrCancelledCall = tc as
- | TrackedCompletedToolCall
- | TrackedCancelledToolCall;
- return (
- !completedOrCancelledCall.responseSubmittedToGemini &&
- completedOrCancelledCall.response?.responseParts !== undefined
- );
- }
- return false;
- },
- );
+ if (isTerminalState) {
+ const completedOrCancelledCall = tc as
+ | TrackedCompletedToolCall
+ | TrackedCancelledToolCall;
+ return (
+ !completedOrCancelledCall.responseSubmittedToGemini &&
+ completedOrCancelledCall.response?.responseParts !== undefined
+ );
+ }
+ return false;
+ },
+ );
+
+ // Finalize any client-initiated tools as soon as they are done.
+ const clientTools = completedAndReadyToSubmitTools.filter(
+ (t) => t.request.isClientInitiated,
+ );
+ if (clientTools.length > 0) {
+ markToolsAsSubmitted(clientTools.map((t) => t.request.callId));
+ }
+
+ // Identify new, successful save_memory calls that we haven't processed yet.
+ const newSuccessfulMemorySaves = completedAndReadyToSubmitTools.filter(
+ (t) =>
+ t.request.name === 'save_memory' &&
+ t.status === 'success' &&
+ !processedMemoryToolsRef.current.has(t.request.callId),
+ );
+
+ if (newSuccessfulMemorySaves.length > 0) {
+ // Perform the refresh only if there are new ones.
+ void performMemoryRefresh();
+ // Mark them as processed so we don't do this again on the next render.
+ newSuccessfulMemorySaves.forEach((t) =>
+ processedMemoryToolsRef.current.add(t.request.callId),
+ );
+ }
+
+ // Only proceed with submitting to Gemini if ALL tools are complete.
+ const allToolsAreComplete =
+ toolCalls.length > 0 &&
+ toolCalls.length === completedAndReadyToSubmitTools.length;
+
+ if (!allToolsAreComplete) {
+ return;
+ }
+
+ const geminiTools = completedAndReadyToSubmitTools.filter(
+ (t) => !t.request.isClientInitiated,
+ );
+
+ if (geminiTools.length === 0) {
+ return;
+ }
- if (
- completedAndReadyToSubmitTools.length > 0 &&
- completedAndReadyToSubmitTools.length === toolCalls.length
- ) {
// If all the tools were cancelled, don't submit a response to Gemini.
- const allToolsCancelled = completedAndReadyToSubmitTools.every(
+ const allToolsCancelled = geminiTools.every(
(tc) => tc.status === 'cancelled',
);
@@ -605,7 +647,7 @@ export const useGeminiStream = (
if (geminiClient) {
// We need to manually add the function responses to the history
// so the model knows the tools were cancelled.
- const responsesToAdd = completedAndReadyToSubmitTools.flatMap(
+ const responsesToAdd = geminiTools.flatMap(
(toolCall) => toolCall.response.responseParts,
);
for (const response of responsesToAdd) {
@@ -624,18 +666,17 @@ export const useGeminiStream = (
}
}
- const callIdsToMarkAsSubmitted = completedAndReadyToSubmitTools.map(
+ const callIdsToMarkAsSubmitted = geminiTools.map(
(toolCall) => toolCall.request.callId,
);
markToolsAsSubmitted(callIdsToMarkAsSubmitted);
return;
}
- const responsesToSend: PartListUnion[] =
- completedAndReadyToSubmitTools.map(
- (toolCall) => toolCall.response.responseParts,
- );
- const callIdsToMarkAsSubmitted = completedAndReadyToSubmitTools.map(
+ const responsesToSend: PartListUnion[] = geminiTools.map(
+ (toolCall) => toolCall.response.responseParts,
+ );
+ const callIdsToMarkAsSubmitted = geminiTools.map(
(toolCall) => toolCall.request.callId,
);
@@ -643,7 +684,8 @@ export const useGeminiStream = (
submitQuery(mergePartListUnions(responsesToSend), {
isContinuation: true,
});
- }
+ };
+ void run();
}, [
toolCalls,
isResponding,
@@ -651,6 +693,7 @@ export const useGeminiStream = (
markToolsAsSubmitted,
addItem,
geminiClient,
+ performMemoryRefresh,
]);
const pendingHistoryItems = [