summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTommaso Sciortino <[email protected]>2025-08-21 16:49:12 -0700
committerGitHub <[email protected]>2025-08-21 23:49:12 +0000
commit15c62bade317495063c30248001694a743ad5159 (patch)
treea08014debf9100e86da2a1ab655e75915fe3355f
parent29699274bb0e8f70b9bedad40ca2d03739318853 (diff)
Reuse CoreToolScheduler for nonInteractiveToolExecutor (#6714)
-rw-r--r--packages/cli/src/nonInteractiveCli.ts28
-rw-r--r--packages/cli/src/ui/hooks/useReactToolScheduler.ts1
-rw-r--r--packages/core/src/core/coreToolScheduler.test.ts24
-rw-r--r--packages/core/src/core/coreToolScheduler.ts5
-rw-r--r--packages/core/src/core/nonInteractiveToolExecutor.test.ts122
-rw-r--r--packages/core/src/core/nonInteractiveToolExecutor.ts171
6 files changed, 93 insertions, 258 deletions
diff --git a/packages/cli/src/nonInteractiveCli.ts b/packages/cli/src/nonInteractiveCli.ts
index 36337c8f..d986c1eb 100644
--- a/packages/cli/src/nonInteractiveCli.ts
+++ b/packages/cli/src/nonInteractiveCli.ts
@@ -13,7 +13,7 @@ import {
GeminiEventType,
parseAndFormatApiError,
} from '@google/gemini-cli-core';
-import { Content, Part, FunctionCall } from '@google/genai';
+import { Content, Part } from '@google/genai';
import { ConsolePatcher } from './ui/utils/ConsolePatcher.js';
import { handleAtCommand } from './ui/hooks/atCommandProcessor.js';
@@ -74,7 +74,7 @@ export async function runNonInteractive(
);
return;
}
- const functionCalls: FunctionCall[] = [];
+ const toolCallRequests: ToolCallRequestInfo[] = [];
const responseStream = geminiClient.sendMessageStream(
currentMessages[0]?.parts || [],
@@ -91,29 +91,13 @@ export async function runNonInteractive(
if (event.type === GeminiEventType.Content) {
process.stdout.write(event.value);
} else if (event.type === GeminiEventType.ToolCallRequest) {
- const toolCallRequest = event.value;
- const fc: FunctionCall = {
- name: toolCallRequest.name,
- args: toolCallRequest.args,
- id: toolCallRequest.callId,
- };
- functionCalls.push(fc);
+ toolCallRequests.push(event.value);
}
}
- if (functionCalls.length > 0) {
+ if (toolCallRequests.length > 0) {
const toolResponseParts: Part[] = [];
-
- for (const fc of functionCalls) {
- const callId = fc.id ?? `${fc.name}-${Date.now()}`;
- const requestInfo: ToolCallRequestInfo = {
- callId,
- name: fc.name as string,
- args: (fc.args ?? {}) as Record<string, unknown>,
- isClientInitiated: false,
- prompt_id,
- };
-
+ for (const requestInfo of toolCallRequests) {
const toolResponse = await executeToolCall(
config,
requestInfo,
@@ -122,7 +106,7 @@ export async function runNonInteractive(
if (toolResponse.error) {
console.error(
- `Error executing tool ${fc.name}: ${toolResponse.resultDisplay || toolResponse.error.message}`,
+ `Error executing tool ${requestInfo.name}: ${toolResponse.resultDisplay || toolResponse.error.message}`,
);
}
diff --git a/packages/cli/src/ui/hooks/useReactToolScheduler.ts b/packages/cli/src/ui/hooks/useReactToolScheduler.ts
index 93e05387..e4238f99 100644
--- a/packages/cli/src/ui/hooks/useReactToolScheduler.ts
+++ b/packages/cli/src/ui/hooks/useReactToolScheduler.ts
@@ -134,7 +134,6 @@ export function useReactToolScheduler(
const scheduler = useMemo(
() =>
new CoreToolScheduler({
- toolRegistry: config.getToolRegistry(),
outputUpdateHandler,
onAllToolCallsComplete: allToolCallsCompleteHandler,
onToolCallsUpdate: toolCallsUpdateHandler,
diff --git a/packages/core/src/core/coreToolScheduler.test.ts b/packages/core/src/core/coreToolScheduler.test.ts
index 1c400d52..291c1862 100644
--- a/packages/core/src/core/coreToolScheduler.test.ts
+++ b/packages/core/src/core/coreToolScheduler.test.ts
@@ -129,11 +129,11 @@ describe('CoreToolScheduler', () => {
model: 'test-model',
authType: 'oauth-personal',
}),
+ getToolRegistry: () => mockToolRegistry,
} as unknown as Config;
const scheduler = new CoreToolScheduler({
config: mockConfig,
- toolRegistry: mockToolRegistry,
onAllToolCallsComplete,
onToolCallsUpdate,
getPreferredEditor: () => 'vscode',
@@ -189,11 +189,11 @@ describe('CoreToolScheduler with payload', () => {
model: 'test-model',
authType: 'oauth-personal',
}),
+ getToolRegistry: () => mockToolRegistry,
} as unknown as Config;
const scheduler = new CoreToolScheduler({
config: mockConfig,
- toolRegistry: mockToolRegistry,
onAllToolCallsComplete,
onToolCallsUpdate,
getPreferredEditor: () => 'vscode',
@@ -462,15 +462,14 @@ class MockEditTool extends BaseDeclarativeTool<
describe('CoreToolScheduler edit cancellation', () => {
it('should preserve diff when an edit is cancelled', async () => {
const mockEditTool = new MockEditTool();
- const declarativeTool = mockEditTool;
const mockToolRegistry = {
- getTool: () => declarativeTool,
+ getTool: () => mockEditTool,
getFunctionDeclarations: () => [],
tools: new Map(),
discovery: {},
registerTool: () => {},
- getToolByName: () => declarativeTool,
- getToolByDisplayName: () => declarativeTool,
+ getToolByName: () => mockEditTool,
+ getToolByDisplayName: () => mockEditTool,
getTools: () => [],
discoverTools: async () => {},
getAllTools: () => [],
@@ -489,11 +488,11 @@ describe('CoreToolScheduler edit cancellation', () => {
model: 'test-model',
authType: 'oauth-personal',
}),
+ getToolRegistry: () => mockToolRegistry,
} as unknown as Config;
const scheduler = new CoreToolScheduler({
config: mockConfig,
- toolRegistry: mockToolRegistry,
onAllToolCallsComplete,
onToolCallsUpdate,
getPreferredEditor: () => 'vscode',
@@ -581,11 +580,11 @@ describe('CoreToolScheduler YOLO mode', () => {
model: 'test-model',
authType: 'oauth-personal',
}),
+ getToolRegistry: () => mockToolRegistry,
} as unknown as Config;
const scheduler = new CoreToolScheduler({
config: mockConfig,
- toolRegistry: mockToolRegistry,
onAllToolCallsComplete,
onToolCallsUpdate,
getPreferredEditor: () => 'vscode',
@@ -670,11 +669,11 @@ describe('CoreToolScheduler request queueing', () => {
model: 'test-model',
authType: 'oauth-personal',
}),
+ getToolRegistry: () => mockToolRegistry,
} as unknown as Config;
const scheduler = new CoreToolScheduler({
config: mockConfig,
- toolRegistry: mockToolRegistry,
onAllToolCallsComplete,
onToolCallsUpdate,
getPreferredEditor: () => 'vscode',
@@ -783,11 +782,11 @@ describe('CoreToolScheduler request queueing', () => {
model: 'test-model',
authType: 'oauth-personal',
}),
+ getToolRegistry: () => mockToolRegistry,
} as unknown as Config;
const scheduler = new CoreToolScheduler({
config: mockConfig,
- toolRegistry: mockToolRegistry,
onAllToolCallsComplete,
onToolCallsUpdate,
getPreferredEditor: () => 'vscode',
@@ -864,7 +863,9 @@ describe('CoreToolScheduler request queueing', () => {
getTools: () => [],
discoverTools: async () => {},
discovery: {},
- };
+ } as unknown as ToolRegistry;
+
+ mockConfig.getToolRegistry = () => toolRegistry;
const onAllToolCallsComplete = vi.fn();
const onToolCallsUpdate = vi.fn();
@@ -874,7 +875,6 @@ describe('CoreToolScheduler request queueing', () => {
const scheduler = new CoreToolScheduler({
config: mockConfig,
- toolRegistry: toolRegistry as unknown as ToolRegistry,
onAllToolCallsComplete,
onToolCallsUpdate: (toolCalls) => {
onToolCallsUpdate(toolCalls);
diff --git a/packages/core/src/core/coreToolScheduler.ts b/packages/core/src/core/coreToolScheduler.ts
index 5a2bb85d..a7923647 100644
--- a/packages/core/src/core/coreToolScheduler.ts
+++ b/packages/core/src/core/coreToolScheduler.ts
@@ -226,12 +226,11 @@ const createErrorResponse = (
});
interface CoreToolSchedulerOptions {
- toolRegistry: ToolRegistry;
+ config: Config;
outputUpdateHandler?: OutputUpdateHandler;
onAllToolCallsComplete?: AllToolCallsCompleteHandler;
onToolCallsUpdate?: ToolCallsUpdateHandler;
getPreferredEditor: () => EditorType | undefined;
- config: Config;
onEditorClose: () => void;
}
@@ -255,7 +254,7 @@ export class CoreToolScheduler {
constructor(options: CoreToolSchedulerOptions) {
this.config = options.config;
- this.toolRegistry = options.toolRegistry;
+ this.toolRegistry = options.config.getToolRegistry();
this.outputUpdateHandler = options.outputUpdateHandler;
this.onAllToolCallsComplete = options.onAllToolCallsComplete;
this.onToolCallsUpdate = options.onToolCallsUpdate;
diff --git a/packages/core/src/core/nonInteractiveToolExecutor.test.ts b/packages/core/src/core/nonInteractiveToolExecutor.test.ts
index 38afa697..8f16aaa7 100644
--- a/packages/core/src/core/nonInteractiveToolExecutor.test.ts
+++ b/packages/core/src/core/nonInteractiveToolExecutor.test.ts
@@ -12,6 +12,7 @@ import {
ToolResult,
Config,
ToolErrorType,
+ ApprovalMode,
} from '../index.js';
import { Part } from '@google/genai';
import { MockTool } from '../test-utils/tools.js';
@@ -27,10 +28,11 @@ describe('executeToolCall', () => {
mockToolRegistry = {
getTool: vi.fn(),
- // Add other ToolRegistry methods if needed, or use a more complete mock
} as unknown as ToolRegistry;
mockConfig = {
+ getToolRegistry: () => mockToolRegistry,
+ getApprovalMode: () => ApprovalMode.DEFAULT,
getSessionId: () => 'test-session-id',
getUsageStatisticsEnabled: () => true,
getDebugMode: () => false,
@@ -38,7 +40,6 @@ describe('executeToolCall', () => {
model: 'test-model',
authType: 'oauth-personal',
}),
- getToolRegistry: () => mockToolRegistry,
} as unknown as Config;
abortController = new AbortController();
@@ -57,7 +58,7 @@ describe('executeToolCall', () => {
returnDisplay: 'Success!',
};
vi.mocked(mockToolRegistry.getTool).mockReturnValue(mockTool);
- vi.spyOn(mockTool, 'validateBuildAndExecute').mockResolvedValue(toolResult);
+ mockTool.executeFn.mockReturnValue(toolResult);
const response = await executeToolCall(
mockConfig,
@@ -66,18 +67,18 @@ describe('executeToolCall', () => {
);
expect(mockToolRegistry.getTool).toHaveBeenCalledWith('testTool');
- expect(mockTool.validateBuildAndExecute).toHaveBeenCalledWith(
- request.args,
- abortController.signal,
- );
- expect(response.callId).toBe('call1');
- expect(response.error).toBeUndefined();
- expect(response.resultDisplay).toBe('Success!');
- expect(response.responseParts).toEqual({
- functionResponse: {
- name: 'testTool',
- id: 'call1',
- response: { output: 'Tool executed successfully' },
+ expect(mockTool.executeFn).toHaveBeenCalledWith(request.args);
+ expect(response).toStrictEqual({
+ callId: 'call1',
+ error: undefined,
+ errorType: undefined,
+ resultDisplay: 'Success!',
+ responseParts: {
+ functionResponse: {
+ name: 'testTool',
+ id: 'call1',
+ response: { output: 'Tool executed successfully' },
+ },
},
});
});
@@ -98,23 +99,19 @@ describe('executeToolCall', () => {
abortController.signal,
);
- expect(response.callId).toBe('call2');
- expect(response.error).toBeInstanceOf(Error);
- expect(response.error?.message).toBe(
- 'Tool "nonexistentTool" not found in registry.',
- );
- expect(response.resultDisplay).toBe(
- 'Tool "nonexistentTool" not found in registry.',
- );
- expect(response.responseParts).toEqual([
- {
+ expect(response).toStrictEqual({
+ callId: 'call2',
+ error: new Error('Tool "nonexistentTool" not found in registry.'),
+ errorType: ToolErrorType.TOOL_NOT_REGISTERED,
+ resultDisplay: 'Tool "nonexistentTool" not found in registry.',
+ responseParts: {
functionResponse: {
name: 'nonexistentTool',
id: 'call2',
response: { error: 'Tool "nonexistentTool" not found in registry.' },
},
},
- ]);
+ });
});
it('should return an error if tool validation fails', async () => {
@@ -125,24 +122,17 @@ describe('executeToolCall', () => {
isClientInitiated: false,
prompt_id: 'prompt-id-3',
};
- const validationErrorResult: ToolResult = {
- llmContent: 'Error: Invalid parameters',
- returnDisplay: 'Invalid parameters',
- error: {
- message: 'Invalid parameters',
- type: ToolErrorType.INVALID_TOOL_PARAMS,
- },
- };
vi.mocked(mockToolRegistry.getTool).mockReturnValue(mockTool);
- vi.spyOn(mockTool, 'validateBuildAndExecute').mockResolvedValue(
- validationErrorResult,
- );
+ vi.spyOn(mockTool, 'build').mockImplementation(() => {
+ throw new Error('Invalid parameters');
+ });
const response = await executeToolCall(
mockConfig,
request,
abortController.signal,
);
+
expect(response).toStrictEqual({
callId: 'call3',
error: new Error('Invalid parameters'),
@@ -152,7 +142,7 @@ describe('executeToolCall', () => {
id: 'call3',
name: 'testTool',
response: {
- output: 'Error: Invalid parameters',
+ error: 'Invalid parameters',
},
},
},
@@ -177,9 +167,7 @@ describe('executeToolCall', () => {
},
};
vi.mocked(mockToolRegistry.getTool).mockReturnValue(mockTool);
- vi.spyOn(mockTool, 'validateBuildAndExecute').mockResolvedValue(
- executionErrorResult,
- );
+ mockTool.executeFn.mockReturnValue(executionErrorResult);
const response = await executeToolCall(
mockConfig,
@@ -195,7 +183,7 @@ describe('executeToolCall', () => {
id: 'call4',
name: 'testTool',
response: {
- output: 'Error: Execution failed',
+ error: 'Execution failed',
},
},
},
@@ -211,11 +199,10 @@ describe('executeToolCall', () => {
isClientInitiated: false,
prompt_id: 'prompt-id-5',
};
- const executionError = new Error('Something went very wrong');
vi.mocked(mockToolRegistry.getTool).mockReturnValue(mockTool);
- vi.spyOn(mockTool, 'validateBuildAndExecute').mockRejectedValue(
- executionError,
- );
+ mockTool.executeFn.mockImplementation(() => {
+ throw new Error('Something went very wrong');
+ });
const response = await executeToolCall(
mockConfig,
@@ -223,19 +210,19 @@ describe('executeToolCall', () => {
abortController.signal,
);
- expect(response.callId).toBe('call5');
- expect(response.error).toBe(executionError);
- expect(response.errorType).toBe(ToolErrorType.UNHANDLED_EXCEPTION);
- expect(response.resultDisplay).toBe('Something went very wrong');
- expect(response.responseParts).toEqual([
- {
+ expect(response).toStrictEqual({
+ callId: 'call5',
+ error: new Error('Something went very wrong'),
+ errorType: ToolErrorType.UNHANDLED_EXCEPTION,
+ resultDisplay: 'Something went very wrong',
+ responseParts: {
functionResponse: {
name: 'testTool',
id: 'call5',
response: { error: 'Something went very wrong' },
},
},
- ]);
+ });
});
it('should correctly format llmContent with inlineData', async () => {
@@ -254,7 +241,7 @@ describe('executeToolCall', () => {
returnDisplay: 'Image processed',
};
vi.mocked(mockToolRegistry.getTool).mockReturnValue(mockTool);
- vi.spyOn(mockTool, 'validateBuildAndExecute').mockResolvedValue(toolResult);
+ mockTool.executeFn.mockReturnValue(toolResult);
const response = await executeToolCall(
mockConfig,
@@ -262,18 +249,23 @@ describe('executeToolCall', () => {
abortController.signal,
);
- expect(response.resultDisplay).toBe('Image processed');
- expect(response.responseParts).toEqual([
- {
- functionResponse: {
- name: 'testTool',
- id: 'call6',
- response: {
- output: 'Binary content of type image/png was processed.',
+ expect(response).toStrictEqual({
+ callId: 'call6',
+ error: undefined,
+ errorType: undefined,
+ resultDisplay: 'Image processed',
+ responseParts: [
+ {
+ functionResponse: {
+ name: 'testTool',
+ id: 'call6',
+ response: {
+ output: 'Binary content of type image/png was processed.',
+ },
},
},
- },
- imageDataPart,
- ]);
+ imageDataPart,
+ ],
+ });
});
});
diff --git a/packages/core/src/core/nonInteractiveToolExecutor.ts b/packages/core/src/core/nonInteractiveToolExecutor.ts
index c116ca33..46ca71d2 100644
--- a/packages/core/src/core/nonInteractiveToolExecutor.ts
+++ b/packages/core/src/core/nonInteractiveToolExecutor.ts
@@ -4,166 +4,27 @@
* SPDX-License-Identifier: Apache-2.0
*/
-import {
- FileDiff,
- logToolCall,
- ToolCallRequestInfo,
- ToolCallResponseInfo,
- ToolErrorType,
- ToolResult,
-} from '../index.js';
-import { DiscoveredMCPTool } from '../tools/mcp-tool.js';
-import { Config } from '../config/config.js';
-import { convertToFunctionResponse } from './coreToolScheduler.js';
-import { ToolCallDecision } from '../telemetry/tool-call-decision.js';
+import { ToolCallRequestInfo, ToolCallResponseInfo, Config } from '../index.js';
+import { CoreToolScheduler } from './coreToolScheduler.js';
/**
- * Executes a single tool call non-interactively.
- * It does not handle confirmations, multiple calls, or live updates.
+ * Executes a single tool call non-interactively by leveraging the CoreToolScheduler.
*/
export async function executeToolCall(
config: Config,
toolCallRequest: ToolCallRequestInfo,
- abortSignal?: AbortSignal,
+ abortSignal: AbortSignal,
): Promise<ToolCallResponseInfo> {
- const tool = config.getToolRegistry().getTool(toolCallRequest.name);
-
- const startTime = Date.now();
- if (!tool) {
- const error = new Error(
- `Tool "${toolCallRequest.name}" not found in registry.`,
- );
- const durationMs = Date.now() - startTime;
- logToolCall(config, {
- 'event.name': 'tool_call',
- 'event.timestamp': new Date().toISOString(),
- function_name: toolCallRequest.name,
- function_args: toolCallRequest.args,
- duration_ms: durationMs,
- success: false,
- error: error.message,
- prompt_id: toolCallRequest.prompt_id,
- tool_type: 'native',
- });
- // Ensure the response structure matches what the API expects for an error
- return {
- callId: toolCallRequest.callId,
- responseParts: [
- {
- functionResponse: {
- id: toolCallRequest.callId,
- name: toolCallRequest.name,
- response: { error: error.message },
- },
- },
- ],
- resultDisplay: error.message,
- error,
- errorType: ToolErrorType.TOOL_NOT_REGISTERED,
- };
- }
-
- try {
- // Directly execute without confirmation or live output handling
- const effectiveAbortSignal = abortSignal ?? new AbortController().signal;
- const toolResult: ToolResult = await tool.validateBuildAndExecute(
- toolCallRequest.args,
- effectiveAbortSignal,
- // No live output callback for non-interactive mode
- );
-
- const tool_output = toolResult.llmContent;
-
- const tool_display = toolResult.returnDisplay;
-
- // eslint-disable-next-line @typescript-eslint/no-explicit-any
- let metadata: { [key: string]: any } = {};
- if (
- toolResult.error === undefined &&
- typeof tool_display === 'object' &&
- tool_display !== null &&
- 'diffStat' in tool_display
- ) {
- const diffStat = (tool_display as FileDiff).diffStat;
- if (diffStat) {
- metadata = {
- ai_added_lines: diffStat.ai_added_lines,
- ai_removed_lines: diffStat.ai_removed_lines,
- user_added_lines: diffStat.user_added_lines,
- user_removed_lines: diffStat.user_removed_lines,
- };
- }
- }
- const durationMs = Date.now() - startTime;
- logToolCall(config, {
- 'event.name': 'tool_call',
- 'event.timestamp': new Date().toISOString(),
- function_name: toolCallRequest.name,
- function_args: toolCallRequest.args,
- duration_ms: durationMs,
- success: toolResult.error === undefined,
- error:
- toolResult.error === undefined ? undefined : toolResult.error.message,
- error_type:
- toolResult.error === undefined ? undefined : toolResult.error.type,
- prompt_id: toolCallRequest.prompt_id,
- metadata,
- decision: ToolCallDecision.AUTO_ACCEPT,
- tool_type:
- typeof tool !== 'undefined' && tool instanceof DiscoveredMCPTool
- ? 'mcp'
- : 'native',
- });
-
- const response = convertToFunctionResponse(
- toolCallRequest.name,
- toolCallRequest.callId,
- tool_output,
- );
-
- return {
- callId: toolCallRequest.callId,
- responseParts: response,
- resultDisplay: tool_display,
- error:
- toolResult.error === undefined
- ? undefined
- : new Error(toolResult.error.message),
- errorType:
- toolResult.error === undefined ? undefined : toolResult.error.type,
- };
- } catch (e) {
- const error = e instanceof Error ? e : new Error(String(e));
- const durationMs = Date.now() - startTime;
- logToolCall(config, {
- 'event.name': 'tool_call',
- 'event.timestamp': new Date().toISOString(),
- function_name: toolCallRequest.name,
- function_args: toolCallRequest.args,
- duration_ms: durationMs,
- success: false,
- error: error.message,
- error_type: ToolErrorType.UNHANDLED_EXCEPTION,
- prompt_id: toolCallRequest.prompt_id,
- tool_type:
- typeof tool !== 'undefined' && tool instanceof DiscoveredMCPTool
- ? 'mcp'
- : 'native',
- });
- return {
- callId: toolCallRequest.callId,
- responseParts: [
- {
- functionResponse: {
- id: toolCallRequest.callId,
- name: toolCallRequest.name,
- response: { error: error.message },
- },
- },
- ],
- resultDisplay: error.message,
- error,
- errorType: ToolErrorType.UNHANDLED_EXCEPTION,
- };
- }
+ return new Promise<ToolCallResponseInfo>((resolve, reject) => {
+ new CoreToolScheduler({
+ config,
+ getPreferredEditor: () => undefined,
+ onEditorClose: () => {},
+ onAllToolCallsComplete: async (completedToolCalls) => {
+ resolve(completedToolCalls[0].response);
+ },
+ })
+ .schedule(toolCallRequest, abortSignal)
+ .catch(reject);
+ });
}